mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-02 14:09:56 +00:00
Compare commits
241 Commits
test/perft
...
poc-token-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b016a1f0d0 | ||
|
|
c009055693 | ||
|
|
14181c909c | ||
|
|
a05dc3823d | ||
|
|
7d19bdf085 | ||
|
|
a1b048f2ad | ||
|
|
0bd227196e | ||
|
|
eea7687ddf | ||
|
|
57d3ee5aac | ||
|
|
cfdfdecc14 | ||
|
|
ac995bae6d | ||
|
|
41a5509ce0 | ||
|
|
db5e26db94 | ||
|
|
fe975fb834 | ||
|
|
e368d2995b | ||
|
|
a3241d8376 | ||
|
|
6dfc5772ba | ||
|
|
f70925178c | ||
|
|
9554934b92 | ||
|
|
7fdb824a37 | ||
|
|
412407adc0 | ||
|
|
e0874d7de7 | ||
|
|
8df1536cbb | ||
|
|
fcbacc62ec | ||
|
|
ee2ae45653 | ||
|
|
6f2f0f9ae4 | ||
|
|
c37ebc6fb3 | ||
|
|
23abb5743c | ||
|
|
b87aa0bc15 | ||
|
|
f1a65d732d | ||
|
|
a3c0ea3e71 | ||
|
|
abaf061c2a | ||
|
|
e531fb54b1 | ||
|
|
5fcfed5b16 | ||
|
|
5f43449f67 | ||
|
|
6796601aa6 | ||
|
|
1fc25c301b | ||
|
|
08ae281b2d | ||
|
|
bd47f44c63 | ||
|
|
381260911b | ||
|
|
38db42e7d6 | ||
|
|
5d606d909d | ||
|
|
d689718b50 | ||
|
|
54a73c6649 | ||
|
|
418377842e | ||
|
|
15ef56e03d | ||
|
|
917035f8e8 | ||
|
|
963e3f5457 | ||
|
|
e20b969188 | ||
|
|
1c7059ee67 | ||
|
|
22a3365658 | ||
|
|
08ab1e3478 | ||
|
|
ebb1f4007d | ||
|
|
acb53ece93 | ||
|
|
e020950cfd | ||
|
|
9dba262a20 | ||
|
|
5bcdf36377 | ||
|
|
1ffe8deb10 | ||
|
|
d069145bd1 | ||
|
|
f3493ee042 | ||
|
|
bf48044e5c | ||
|
|
fb4cc37a4a | ||
|
|
55b8d89a79 | ||
|
|
6968a32a5a | ||
|
|
cfe6753349 | ||
|
|
5ae15b3af3 | ||
|
|
b79adb706c | ||
|
|
f22497d5da | ||
|
|
95d672c9df | ||
|
|
7d08a609e6 | ||
|
|
eea6120cd0 | ||
|
|
0cb02bd906 | ||
|
|
08d3867f41 | ||
|
|
b16d63643c | ||
|
|
940d01bdea | ||
|
|
ba9158d159 | ||
|
|
ca9a7e11ef | ||
|
|
a803f47685 | ||
|
|
79fed32f01 | ||
|
|
6b00bb0a66 | ||
|
|
e2adef1eea | ||
|
|
9e5fa11792 | ||
|
|
1ff75acb31 | ||
|
|
1754160686 | ||
|
|
423f6266fb | ||
|
|
16d1b4a14a | ||
|
|
7c14056faf | ||
|
|
62e37dc2e2 | ||
|
|
6a08695ee8 | ||
|
|
9a67a8e427 | ||
|
|
73aa0785ba | ||
|
|
53c1016a8e | ||
|
|
fd442138e6 | ||
|
|
be5f30225a | ||
|
|
7467e9fb8c | ||
|
|
2390c2e46e | ||
|
|
778c223176 | ||
|
|
36cd0dd85c | ||
|
|
09a1d5a02d | ||
|
|
7c996ac9b5 | ||
|
|
cf9fd5d960 | ||
|
|
1c5ab7cb8f | ||
|
|
aaad3b25a7 | ||
|
|
9904235a2f | ||
|
|
780e9f57a5 | ||
|
|
a8db73285b | ||
|
|
3b43c00d12 | ||
|
|
2f390e1794 | ||
|
|
3630ebb3ae | ||
|
|
260c46df04 | ||
|
|
7f11e3205d | ||
|
|
1c8f92a96f | ||
|
|
7b6294b624 | ||
|
|
156d0b1fef | ||
|
|
2cf00dba58 | ||
|
|
d2a7f3ae36 | ||
|
|
6a64d4e4dd | ||
|
|
51e63c246b | ||
|
|
99e6b1eda4 | ||
|
|
dc26a5a436 | ||
|
|
3883b2fb41 | ||
|
|
ed58659a01 | ||
|
|
5190923c70 | ||
|
|
7c647dd160 | ||
|
|
07e59b2708 | ||
|
|
0a3a9f977d | ||
|
|
2f263bf7e6 | ||
|
|
f65f4fc280 | ||
|
|
adbd7ab4c3 | ||
|
|
0419834482 | ||
|
|
f797d2d9cb | ||
|
|
5ae7efe8f7 | ||
|
|
d6e35bd0fe | ||
|
|
0e00f1c8f7 | ||
|
|
4433f44a12 | ||
|
|
7504e718d7 | ||
|
|
9b0387e7ee | ||
|
|
5ccce1ab3f | ||
|
|
e366fe340e | ||
|
|
b01809f8e3 | ||
|
|
790ef39187 | ||
|
|
3af16cf333 | ||
|
|
d09c69f303 | ||
|
|
096d4ac529 | ||
|
|
8fafde614a | ||
|
|
694ae13418 | ||
|
|
b5b7dd4f53 | ||
|
|
476785b122 | ||
|
|
907677f835 | ||
|
|
7d844b9410 | ||
|
|
eeabc64a73 | ||
|
|
5da2b0fdcc | ||
|
|
a0005a604e | ||
|
|
a89bb807a6 | ||
|
|
28f3354ffa | ||
|
|
562923c600 | ||
|
|
0dd0c67b3b | ||
|
|
ca33849f31 | ||
|
|
18cd0f1480 | ||
|
|
b02982f6b1 | ||
|
|
4d89ae27ef | ||
|
|
733ea77c5c | ||
|
|
92f72bfce6 | ||
|
|
bffb25bea7 | ||
|
|
3af4543e80 | ||
|
|
146774860b | ||
|
|
5243481316 | ||
|
|
76a39c1dcb | ||
|
|
02ce918114 | ||
|
|
30cfc22cb6 | ||
|
|
3168afbfcb | ||
|
|
a73ee47557 | ||
|
|
fa6ff005f2 | ||
|
|
095379fa60 | ||
|
|
30572fe1b8 | ||
|
|
3a6f364b03 | ||
|
|
5345d716ee | ||
|
|
f882c36e0a | ||
|
|
e95cfa1a00 | ||
|
|
0d480071b6 | ||
|
|
8e0b7b6c25 | ||
|
|
f204da0d68 | ||
|
|
7d74904d62 | ||
|
|
760ac5e07d | ||
|
|
4352228797 | ||
|
|
74c770609c | ||
|
|
f4ca36ed7e | ||
|
|
c86da92fc6 | ||
|
|
3f0c577456 | ||
|
|
717da8c7b7 | ||
|
|
a0a61d4f47 | ||
|
|
5b1fced872 | ||
|
|
c98dcf5ef9 | ||
|
|
57cb6bfccb | ||
|
|
95bf97dc3c | ||
|
|
3d116c9d33 | ||
|
|
a9ce9f8d5a | ||
|
|
10b981a855 | ||
|
|
7700b4333d | ||
|
|
7d0131111e | ||
|
|
1daea35e4b | ||
|
|
f97544af0d | ||
|
|
231e80cc15 | ||
|
|
a4c1362bff | ||
|
|
b611d4a751 | ||
|
|
2c9decfa55 | ||
|
|
3c5ac17e2f | ||
|
|
ae42bbb898 | ||
|
|
b86722394b | ||
|
|
a103f69767 | ||
|
|
73fbb3fc62 | ||
|
|
7b3523e25e | ||
|
|
6e4e1386e7 | ||
|
|
671e9af6eb | ||
|
|
50f42caf94 | ||
|
|
b7eeefc102 | ||
|
|
8dd22f3a4f | ||
|
|
4b89427447 | ||
|
|
b71e2860cf | ||
|
|
160b27bc60 | ||
|
|
c084386b88 | ||
|
|
6889047350 | ||
|
|
245bbb4acf | ||
|
|
2b2fc02d83 | ||
|
|
703ef29199 | ||
|
|
b0b60b938a | ||
|
|
e3a026bf1c | ||
|
|
94503465ee | ||
|
|
8d959b0abc | ||
|
|
1d8390b935 | ||
|
|
2851e38a1f | ||
|
|
51261fe7a9 | ||
|
|
304321d019 | ||
|
|
f8c3295645 | ||
|
|
183619d1e1 | ||
|
|
3b832d1f21 | ||
|
|
fcb849698f | ||
|
|
7527e0ebdb | ||
|
|
ed5f98da5b | ||
|
|
12b38e25da | ||
|
|
626e892e3b |
@@ -39,7 +39,7 @@ jobs:
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
fi
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort)
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name ".git*" | sort)
|
||||
|
||||
echo ""
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
|
||||
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,5 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
||||
|
||||
|
||||
14
.github/workflows/golang-test-linux.yml
vendored
14
.github/workflows/golang-test-linux.yml
vendored
@@ -97,16 +97,6 @@ jobs:
|
||||
working-directory: relay
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||
|
||||
- name: Build combined
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: combined
|
||||
run: CGO_ENABLED=1 go build .
|
||||
|
||||
- name: Build combined 386
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: combined
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
|
||||
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -154,7 +144,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
@@ -214,7 +204,7 @@ jobs:
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
|
||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' })" >> $env:GITHUB_ENV
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
16
.github/workflows/release.yml
vendored
16
.github/workflows/release.yml
vendored
@@ -160,7 +160,7 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Log in to the GitHub container registry
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
@@ -176,7 +176,6 @@ jobs:
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
@@ -186,19 +185,6 @@ jobs:
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
- name: Tag and push PR images (amd64 only)
|
||||
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
|
||||
run: |
|
||||
PR_TAG="pr-${{ github.event.pull_request.number }}"
|
||||
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||
grep '^ghcr.io/' | while read -r SRC; do
|
||||
IMG_NAME="${SRC%%:*}"
|
||||
DST="${IMG_NAME}:${PR_TAG}"
|
||||
echo "Tagging ${SRC} -> ${DST}"
|
||||
docker tag "$SRC" "$DST"
|
||||
docker push "$DST"
|
||||
done
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
|
||||
@@ -140,20 +140,6 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-proxy
|
||||
dir: proxy/cmd/proxy
|
||||
env: [CGO_ENABLED=0]
|
||||
binary: netbird-proxy
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
@@ -603,55 +589,6 @@ dockers:
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
docker_manifests:
|
||||
- name_template: netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
@@ -832,30 +769,6 @@ docker_manifests:
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
brews:
|
||||
- ids:
|
||||
- default
|
||||
|
||||
2
LICENSE
2
LICENSE
@@ -1,4 +1,4 @@
|
||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/.
|
||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||
|
||||
BSD 3-Clause License
|
||||
|
||||
@@ -1,19 +1,10 @@
|
||||
package android
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||
|
||||
var (
|
||||
// EnvKeyNBForceRelay Exported for Android java client to force relay connections
|
||||
// EnvKeyNBForceRelay Exported for Android java client
|
||||
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
|
||||
|
||||
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
|
||||
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
|
||||
|
||||
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
|
||||
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
|
||||
)
|
||||
|
||||
// EnvList wraps a Go map for export to Java
|
||||
|
||||
@@ -42,8 +42,6 @@ const (
|
||||
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
||||
dnsPolicyConfigConfigOptionsValue = 0x8
|
||||
|
||||
nrptMaxDomainsPerRule = 50
|
||||
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
@@ -200,11 +198,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||
// Update count even on error to ensure cleanup covers partially created rules
|
||||
r.nrptEntryCount = count
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns match policy: %w", err)
|
||||
}
|
||||
r.nrptEntryCount = count
|
||||
} else {
|
||||
r.nrptEntryCount = 0
|
||||
}
|
||||
@@ -242,33 +239,23 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
|
||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||
for i, domain := range domains {
|
||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||
|
||||
// We need to batch domains into chunks and create one NRPT rule per batch.
|
||||
ruleIndex := 0
|
||||
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
|
||||
end := i + nrptMaxDomainsPerRule
|
||||
if end > len(domains) {
|
||||
end = len(domains)
|
||||
singleDomain := []string{domain}
|
||||
|
||||
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
||||
}
|
||||
batchDomains := domains[i:end]
|
||||
|
||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
|
||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
|
||||
|
||||
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
|
||||
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
|
||||
}
|
||||
|
||||
// Increment immediately so the caller's cleanup path knows about this rule
|
||||
ruleIndex++
|
||||
|
||||
if r.gpo {
|
||||
if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
|
||||
return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex-1, err)
|
||||
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("added NRPT rule %d with %d domains", ruleIndex-1, len(batchDomains))
|
||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
||||
}
|
||||
|
||||
if r.gpo {
|
||||
@@ -277,8 +264,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains)
|
||||
return ruleIndex, nil
|
||||
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
|
||||
return len(domains), nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||
// when the number of match domains decreases between configuration changes.
|
||||
// With batching enabled (50 domains per rule), we need enough domains to create multiple rules.
|
||||
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping registry integration test in short mode")
|
||||
@@ -38,60 +37,51 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
|
||||
domains125 := make([]DomainConfig, 125)
|
||||
for i := 0; i < 125; i++ {
|
||||
domains125[i] = DomainConfig{
|
||||
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||
MatchOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
config125 := HostDNSConfig{
|
||||
config5 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: domains125,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
{Domain: "domain3.com", MatchOnly: true},
|
||||
{Domain: "domain4.com", MatchOnly: true},
|
||||
{Domain: "domain5.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config125, nil)
|
||||
err = cfg.applyDNSConfig(config5, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify 3 NRPT rules exist
|
||||
assert.Equal(t, 3, cfg.nrptEntryCount, "Should create 3 NRPT rules for 125 domains")
|
||||
for i := 0; i < 3; i++ {
|
||||
// Verify all 5 entries exist
|
||||
for i := 0; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "NRPT rule %d should exist after first config", i)
|
||||
assert.True(t, exists, "Entry %d should exist after first config", i)
|
||||
}
|
||||
|
||||
// Reduce to 75 domains which will result in 2 NRPT rules (50+25)
|
||||
domains75 := make([]DomainConfig, 75)
|
||||
for i := 0; i < 75; i++ {
|
||||
domains75[i] = DomainConfig{
|
||||
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||
MatchOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
config75 := HostDNSConfig{
|
||||
config2 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: domains75,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config75, nil)
|
||||
err = cfg.applyDNSConfig(config2, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify first 2 NRPT rules exist
|
||||
assert.Equal(t, 2, cfg.nrptEntryCount, "Should create 2 NRPT rules for 75 domains")
|
||||
// Verify first 2 entries exist
|
||||
for i := 0; i < 2; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "NRPT rule %d should exist after second config", i)
|
||||
assert.True(t, exists, "Entry %d should exist after second config", i)
|
||||
}
|
||||
|
||||
// Verify rule 2 is cleaned up
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, 2))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "NRPT rule 2 should NOT exist after reducing to 75 domains")
|
||||
// Verify entries 2-4 are cleaned up
|
||||
for i := 2; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
|
||||
}
|
||||
}
|
||||
|
||||
func registryKeyExists(path string) (bool, error) {
|
||||
@@ -107,106 +97,6 @@ func registryKeyExists(path string) (bool, error) {
|
||||
}
|
||||
|
||||
func cleanupRegistryKeys(*testing.T) {
|
||||
// Clean up more entries to account for batching tests with many domains
|
||||
cfg := ®istryConfigurator{nrptEntryCount: 20}
|
||||
cfg := ®istryConfigurator{nrptEntryCount: 10}
|
||||
_ = cfg.removeDNSMatchPolicies()
|
||||
}
|
||||
|
||||
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules.
|
||||
func TestNRPTDomainBatching(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping registry integration test in short mode")
|
||||
}
|
||||
|
||||
defer cleanupRegistryKeys(t)
|
||||
cleanupRegistryKeys(t)
|
||||
|
||||
testIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
// Create a test interface registry key so updateSearchDomains doesn't fail
|
||||
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
|
||||
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
|
||||
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
|
||||
require.NoError(t, err, "Should create test interface registry key")
|
||||
testKey.Close()
|
||||
defer func() {
|
||||
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
domainCount int
|
||||
expectedRuleCount int
|
||||
}{
|
||||
{
|
||||
name: "Less than 50 domains (single rule)",
|
||||
domainCount: 30,
|
||||
expectedRuleCount: 1,
|
||||
},
|
||||
{
|
||||
name: "Exactly 50 domains (single rule)",
|
||||
domainCount: 50,
|
||||
expectedRuleCount: 1,
|
||||
},
|
||||
{
|
||||
name: "51 domains (two rules)",
|
||||
domainCount: 51,
|
||||
expectedRuleCount: 2,
|
||||
},
|
||||
{
|
||||
name: "100 domains (two rules)",
|
||||
domainCount: 100,
|
||||
expectedRuleCount: 2,
|
||||
},
|
||||
{
|
||||
name: "125 domains (three rules: 50+50+25)",
|
||||
domainCount: 125,
|
||||
expectedRuleCount: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Clean up before each subtest
|
||||
cleanupRegistryKeys(t)
|
||||
|
||||
// Generate domains
|
||||
domains := make([]DomainConfig, tc.domainCount)
|
||||
for i := 0; i < tc.domainCount; i++ {
|
||||
domains[i] = DomainConfig{
|
||||
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||
MatchOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: domains,
|
||||
}
|
||||
|
||||
err := cfg.applyDNSConfig(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify that exactly expectedRuleCount rules were created
|
||||
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
|
||||
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
|
||||
|
||||
// Verify all expected rules exist
|
||||
for i := 0; i < tc.expectedRuleCount; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "NRPT rule %d should exist", i)
|
||||
}
|
||||
|
||||
// Verify no extra rules were created
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,18 +84,3 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeginBatch mock implementation of BeginBatch from Server interface
|
||||
func (m *MockServer) BeginBatch() {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// EndBatch mock implementation of EndBatch from Server interface
|
||||
func (m *MockServer) EndBatch() {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// CancelBatch mock implementation of CancelBatch from Server interface
|
||||
func (m *MockServer) CancelBatch() {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
@@ -45,9 +45,6 @@ type IosDnsManager interface {
|
||||
type Server interface {
|
||||
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
||||
DeregisterHandler(domains domain.List, priority int)
|
||||
BeginBatch()
|
||||
EndBatch()
|
||||
CancelBatch()
|
||||
Initialize() error
|
||||
Stop()
|
||||
DnsIP() netip.Addr
|
||||
@@ -90,7 +87,6 @@ type DefaultServer struct {
|
||||
currentConfigHash uint64
|
||||
handlerChain *HandlerChain
|
||||
extraDomains map[domain.Domain]int
|
||||
batchMode bool
|
||||
|
||||
mgmtCacheResolver *mgmt.Resolver
|
||||
|
||||
@@ -238,9 +234,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
|
||||
// convert to zone with simple ref counter
|
||||
s.extraDomains[toZone(domain)]++
|
||||
}
|
||||
if !s.batchMode {
|
||||
s.applyHostConfig()
|
||||
}
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||
@@ -269,41 +263,9 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
||||
delete(s.extraDomains, zone)
|
||||
}
|
||||
}
|
||||
if !s.batchMode {
|
||||
s.applyHostConfig()
|
||||
}
|
||||
}
|
||||
|
||||
// BeginBatch starts batch mode for DNS handler registration/deregistration.
|
||||
// In batch mode, applyHostConfig() is not called after each handler operation,
|
||||
// allowing multiple handlers to be registered/deregistered efficiently.
|
||||
// Must be followed by EndBatch() to apply the accumulated changes.
|
||||
func (s *DefaultServer) BeginBatch() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log.Debugf("DNS batch mode enabled")
|
||||
s.batchMode = true
|
||||
}
|
||||
|
||||
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
|
||||
func (s *DefaultServer) EndBatch() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log.Debugf("DNS batch mode disabled, applying accumulated changes")
|
||||
s.batchMode = false
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
// CancelBatch cancels batch mode without applying accumulated changes.
|
||||
// This is useful when operations fail partway through and you want to
|
||||
// discard partial state rather than applying it.
|
||||
func (s *DefaultServer) CancelBatch() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log.Debugf("DNS batch mode cancelled, discarding accumulated changes")
|
||||
s.batchMode = false
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
log.Debugf("deregistering handler with priority %d for %v", priority, domains)
|
||||
|
||||
@@ -561,7 +523,6 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
s.currentConfig.RouteAll = false
|
||||
}
|
||||
|
||||
// Always apply host config for management updates, regardless of batch mode
|
||||
s.applyHostConfig()
|
||||
|
||||
s.shutdownWg.Add(1)
|
||||
@@ -926,7 +887,6 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
}
|
||||
}
|
||||
|
||||
// Always apply host config when nameserver goes down, regardless of batch mode
|
||||
s.applyHostConfig()
|
||||
|
||||
go func() {
|
||||
@@ -962,7 +922,6 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||
}
|
||||
|
||||
// Always apply host config when nameserver reactivates, regardless of batch mode
|
||||
s.applyHostConfig()
|
||||
|
||||
s.updateNSState(nsGroup, nil, true)
|
||||
|
||||
@@ -18,12 +18,7 @@ func TestGetServerDns(t *testing.T) {
|
||||
t.Errorf("invalid dns server instance: %s", err)
|
||||
}
|
||||
|
||||
mockSrvB, ok := srvB.(*MockServer)
|
||||
if !ok {
|
||||
t.Errorf("returned server is not a MockServer")
|
||||
}
|
||||
|
||||
if mockSrvB != srv {
|
||||
if srvB != srv {
|
||||
t.Errorf("mismatch dns instances")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -410,7 +410,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
}
|
||||
|
||||
func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
|
||||
func (conn *Conn) onICEStateDisconnected() {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
@@ -430,10 +430,6 @@ func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
|
||||
if conn.isReadyToUpgrade() {
|
||||
conn.Log.Infof("ICE disconnected, set Relay to active connection")
|
||||
conn.dumpState.SwitchToRelay()
|
||||
if sessionChanged {
|
||||
conn.resetEndpoint()
|
||||
}
|
||||
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
||||
@@ -761,17 +757,6 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
return wgProxy, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) resetEndpoint() {
|
||||
if !isController(conn.config) {
|
||||
return
|
||||
}
|
||||
conn.Log.Infof("reset wg endpoint")
|
||||
conn.wgWatcher.Reset()
|
||||
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
|
||||
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) isReadyToUpgrade() bool {
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
||||
}
|
||||
|
||||
@@ -66,10 +66,6 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) RemoveEndpointAddress() error {
|
||||
return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
||||
if e.cancelFunc == nil {
|
||||
return
|
||||
|
||||
@@ -32,8 +32,6 @@ type WGWatcher struct {
|
||||
|
||||
enabled bool
|
||||
muEnabled sync.RWMutex
|
||||
|
||||
resetCh chan struct{}
|
||||
}
|
||||
|
||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||
@@ -42,7 +40,6 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
wgIfaceStater: wgIfaceStater,
|
||||
peerKey: peerKey,
|
||||
stateDump: stateDump,
|
||||
resetCh: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,15 +76,6 @@ func (w *WGWatcher) IsEnabled() bool {
|
||||
return w.enabled
|
||||
}
|
||||
|
||||
// Reset signals the watcher that the WireGuard peer has been reset and a new
|
||||
// handshake is expected. This restarts the handshake timeout from scratch.
|
||||
func (w *WGWatcher) Reset() {
|
||||
select {
|
||||
case w.resetCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
||||
w.log.Infof("WireGuard watcher started")
|
||||
@@ -117,12 +105,6 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
|
||||
w.stateDump.WGcheckSuccess()
|
||||
|
||||
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
||||
case <-w.resetCh:
|
||||
w.log.Infof("WireGuard watcher received peer reset, restarting handshake timeout")
|
||||
lastHandshake = time.Time{}
|
||||
enabledTime = time.Now()
|
||||
timer.Stop()
|
||||
timer.Reset(wgHandshakeOvertime)
|
||||
case <-ctx.Done():
|
||||
w.log.Infof("WireGuard watcher stopped")
|
||||
return
|
||||
|
||||
@@ -52,9 +52,8 @@ type WorkerICE struct {
|
||||
// increase by one when disconnecting the agent
|
||||
// with it the remote peer can discard the already deprecated offer/answer
|
||||
// Without it the remote peer may recreate a workable ICE connection
|
||||
sessionID ICESessionID
|
||||
remoteSessionChanged bool
|
||||
muxAgent sync.Mutex
|
||||
sessionID ICESessionID
|
||||
muxAgent sync.Mutex
|
||||
|
||||
localUfrag string
|
||||
localPwd string
|
||||
@@ -107,7 +106,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
return
|
||||
}
|
||||
w.log.Debugf("agent already exists, recreate the connection")
|
||||
w.remoteSessionChanged = true
|
||||
w.agentDialerCancel()
|
||||
if w.agent != nil {
|
||||
if err := w.agent.Close(); err != nil {
|
||||
@@ -308,17 +306,13 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
||||
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
|
||||
}
|
||||
|
||||
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) bool {
|
||||
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) {
|
||||
cancel()
|
||||
if err := agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
|
||||
w.muxAgent.Lock()
|
||||
defer w.muxAgent.Unlock()
|
||||
|
||||
sessionChanged := w.remoteSessionChanged
|
||||
w.remoteSessionChanged = false
|
||||
|
||||
if w.agent == agent {
|
||||
// consider to remove from here and move to the OnNewOffer
|
||||
@@ -331,7 +325,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
||||
w.agentConnecting = false
|
||||
w.remoteSessionID = ""
|
||||
}
|
||||
return sessionChanged
|
||||
w.muxAgent.Unlock()
|
||||
}
|
||||
|
||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||
@@ -432,11 +426,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||
|
||||
sessionChanged := w.closeAgent(agent, dialerCancel)
|
||||
w.closeAgent(agent, dialerCancel)
|
||||
|
||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.onICEStateDisconnected(sessionChanged)
|
||||
w.conn.onICEStateDisconnected()
|
||||
}
|
||||
default:
|
||||
return
|
||||
|
||||
@@ -346,23 +346,6 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
// Begin batch mode to avoid calling applyHostConfig() after each DNS handler operation
|
||||
batchStarted := false
|
||||
if m.dnsServer != nil {
|
||||
m.dnsServer.BeginBatch()
|
||||
batchStarted = true
|
||||
defer func() {
|
||||
if merr != nil {
|
||||
// On error, cancel batch to discard partial DNS state
|
||||
m.dnsServer.CancelBatch()
|
||||
} else {
|
||||
// On success, apply accumulated DNS changes
|
||||
m.dnsServer.EndBatch()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for id, handler := range toRemove {
|
||||
if err := handler.RemoveRoute(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
|
||||
@@ -393,7 +376,6 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
||||
m.activeRoutes[id] = handler
|
||||
}
|
||||
|
||||
_ = batchStarted // Mark as used
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,10 +2,7 @@
|
||||
|
||||
package NetBirdSDK
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||
|
||||
// EnvList is an exported struct to be bound by gomobile
|
||||
type EnvList struct {
|
||||
@@ -35,13 +32,3 @@ func (el *EnvList) AllItems() map[string]string {
|
||||
func GetEnvKeyNBForceRelay() string {
|
||||
return peer.EnvKeyNBForceRelay
|
||||
}
|
||||
|
||||
// GetEnvKeyNBLazyConn Exports the environment variable for the iOS client
|
||||
func GetEnvKeyNBLazyConn() string {
|
||||
return lazyconn.EnvEnableLazyConn
|
||||
}
|
||||
|
||||
// GetEnvKeyNBInactivityThreshold Exports the environment variable for the iOS client
|
||||
func GetEnvKeyNBInactivityThreshold() string {
|
||||
return lazyconn.EnvInactivityThreshold
|
||||
}
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
FROM golang:1.25-bookworm AS builder
|
||||
WORKDIR /app
|
||||
|
||||
# Install build dependencies
|
||||
RUN apt-get update && apt-get install -y gcc libc6-dev git && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
# Build with version info from git (matching goreleaser ldflags)
|
||||
RUN CGO_ENABLED=1 GOOS=linux go build \
|
||||
-ldflags="-s -w \
|
||||
-X github.com/netbirdio/netbird/version.version=$(git describe --tags --always --dirty 2>/dev/null || echo 'dev') \
|
||||
-X main.commit=$(git rev-parse --short HEAD 2>/dev/null || echo 'unknown') \
|
||||
-X main.date=$(date -u +%Y-%m-%dT%H:%M:%SZ) \
|
||||
-X main.builtBy=docker" \
|
||||
-o netbird-server ./combined
|
||||
|
||||
FROM ubuntu:24.04
|
||||
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||
ENTRYPOINT [ "/go/bin/netbird-server" ]
|
||||
CMD ["--config", "/etc/netbird/config.yaml"]
|
||||
COPY --from=builder /app/netbird-server /go/bin/netbird-server
|
||||
661
combined/LICENSE
661
combined/LICENSE
@@ -1,661 +0,0 @@
|
||||
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||
Version 3, 19 November 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU Affero General Public License is a free, copyleft license for
|
||||
software and other kinds of works, specifically designed to ensure
|
||||
cooperation with the community in the case of network server software.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
our General Public Licenses are intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
Developers that use our General Public Licenses protect your rights
|
||||
with two steps: (1) assert copyright on the software, and (2) offer
|
||||
you this License which gives you legal permission to copy, distribute
|
||||
and/or modify the software.
|
||||
|
||||
A secondary benefit of defending all users' freedom is that
|
||||
improvements made in alternate versions of the program, if they
|
||||
receive widespread use, become available for other developers to
|
||||
incorporate. Many developers of free software are heartened and
|
||||
encouraged by the resulting cooperation. However, in the case of
|
||||
software used on network servers, this result may fail to come about.
|
||||
The GNU General Public License permits making a modified version and
|
||||
letting the public access it on a server without ever releasing its
|
||||
source code to the public.
|
||||
|
||||
The GNU Affero General Public License is designed specifically to
|
||||
ensure that, in such cases, the modified source code becomes available
|
||||
to the community. It requires the operator of a network server to
|
||||
provide the source code of the modified version running there to the
|
||||
users of that server. Therefore, public use of a modified version, on
|
||||
a publicly accessible server, gives the public access to the source
|
||||
code of the modified version.
|
||||
|
||||
An older license, called the Affero General Public License and
|
||||
published by Affero, was designed to accomplish similar goals. This is
|
||||
a different license, not a version of the Affero GPL, but Affero has
|
||||
released a new version of the Affero GPL which permits relicensing under
|
||||
this license.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, if you modify the
|
||||
Program, your modified version must prominently offer all users
|
||||
interacting with it remotely through a computer network (if your version
|
||||
supports such interaction) an opportunity to receive the Corresponding
|
||||
Source of your version by providing access to the Corresponding Source
|
||||
from a network server at no charge, through some standard or customary
|
||||
means of facilitating copying of software. This Corresponding Source
|
||||
shall include the Corresponding Source for any work covered by version 3
|
||||
of the GNU General Public License that is incorporated pursuant to the
|
||||
following paragraph.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the work with which it is combined will remain governed by version
|
||||
3 of the GNU General Public License.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU Affero General Public License from time to time. Such new versions
|
||||
will be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU Affero General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU Affero General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU Affero General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If your software can interact with users remotely through a computer
|
||||
network, you should also make sure that it provides a way for users to
|
||||
get its source. For example, if your program is a web application, its
|
||||
interface could display a "Source" link that leads users to an archive
|
||||
of the code. There are many ways you could offer source, and different
|
||||
solutions will be better for different programs; see section 13 for the
|
||||
specific requirements.
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
@@ -627,15 +627,7 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
||||
|
||||
// Set HTTP config fields for embedded IDP
|
||||
httpConfig.AuthIssuer = mgmt.Auth.Issuer
|
||||
httpConfig.AuthAudience = "netbird-dashboard"
|
||||
httpConfig.AuthClientID = httpConfig.AuthAudience
|
||||
httpConfig.CLIAuthAudience = "netbird-cli"
|
||||
httpConfig.AuthUserIDClaim = "sub"
|
||||
httpConfig.AuthKeysLocation = mgmt.Auth.Issuer + "/keys"
|
||||
httpConfig.OIDCConfigEndpoint = mgmt.Auth.Issuer + "/.well-known/openid-configuration"
|
||||
httpConfig.IdpSignKeyRefreshEnabled = mgmt.Auth.SignKeyRefreshEnabled
|
||||
callbackURL := strings.TrimSuffix(httpConfig.AuthIssuer, "/oauth2")
|
||||
httpConfig.AuthCallbackURL = callbackURL + types.ProxyCallbackEndpointFull
|
||||
|
||||
return &nbconfig.Config{
|
||||
Stuns: stuns,
|
||||
|
||||
@@ -62,8 +62,6 @@ Configuration is loaded from a YAML file specified with --config.`,
|
||||
func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
|
||||
_ = rootCmd.MarkPersistentFlagRequired("config")
|
||||
|
||||
rootCmd.AddCommand(newTokenCommands())
|
||||
}
|
||||
|
||||
func Execute() error {
|
||||
|
||||
@@ -4,49 +4,93 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"strconv"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// newTokenCommands creates the token command tree with combined-specific store opener.
|
||||
func newTokenCommands() *cobra.Command {
|
||||
return tokencmd.NewCommands(withTokenStore)
|
||||
var (
|
||||
tokenName string
|
||||
tokenExpireIn string
|
||||
tokenDatadir string
|
||||
|
||||
tokenCmd = &cobra.Command{
|
||||
Use: "token",
|
||||
Short: "Manage proxy access tokens",
|
||||
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
|
||||
}
|
||||
|
||||
tokenCreateCmd = &cobra.Command{
|
||||
Use: "create",
|
||||
Short: "Create a new proxy access token",
|
||||
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
|
||||
RunE: tokenCreateRun,
|
||||
}
|
||||
|
||||
tokenListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List all proxy access tokens",
|
||||
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
|
||||
RunE: tokenListRun,
|
||||
}
|
||||
|
||||
tokenRevokeCmd = &cobra.Command{
|
||||
Use: "revoke [token-id]",
|
||||
Short: "Revoke a proxy access token",
|
||||
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: tokenRevokeRun,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||
|
||||
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
|
||||
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
|
||||
tokenCreateCmd.MarkFlagRequired("name") //nolint
|
||||
|
||||
tokenCmd.AddCommand(tokenCreateCmd, tokenListCmd, tokenRevokeCmd)
|
||||
rootCmd.AddCommand(tokenCmd)
|
||||
}
|
||||
|
||||
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
|
||||
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
//nolint
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
|
||||
|
||||
// Load combined server YAML config
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
if dsn := cfg.Server.Store.DSN; dsn != "" {
|
||||
switch strings.ToLower(cfg.Server.Store.Engine) {
|
||||
case "postgres":
|
||||
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||
case "mysql":
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
// Get datadir from config or override
|
||||
datadir := cfg.Server.DataDir
|
||||
if tokenDatadir != "" {
|
||||
datadir = tokenDatadir
|
||||
}
|
||||
|
||||
datadir := cfg.Management.DataDir
|
||||
engine := types.Engine(cfg.Management.Store.Engine)
|
||||
// Get store engine from config
|
||||
storeEngine := types.Engine(cfg.Server.Store.Engine)
|
||||
if storeEngine == "" {
|
||||
storeEngine = "sqlite"
|
||||
}
|
||||
|
||||
s, err := store.NewStore(ctx, engine, datadir, nil, true)
|
||||
s, err := store.NewStore(ctx, storeEngine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
@@ -58,3 +102,118 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto
|
||||
|
||||
return fn(ctx, s)
|
||||
}
|
||||
|
||||
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
|
||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||
expiresIn, err := parseDuration(tokenExpireIn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse expiration: %w", err)
|
||||
}
|
||||
|
||||
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
|
||||
return fmt.Errorf("save token: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("Token created successfully!") //nolint:forbidigo
|
||||
fmt.Printf("Token: %s\n", generated.PlainToken) //nolint:forbidigo
|
||||
fmt.Println() //nolint:forbidigo
|
||||
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.") //nolint:forbidigo
|
||||
fmt.Printf("Token ID: %s\n", generated.ID) //nolint:forbidigo
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func tokenListRun(cmd *cobra.Command, _ []string) error {
|
||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list tokens: %w", err)
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
fmt.Println("No proxy access tokens found.") //nolint:forbidigo
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
|
||||
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
|
||||
|
||||
for _, t := range tokens {
|
||||
expires := "never"
|
||||
if t.ExpiresAt != nil {
|
||||
expires = t.ExpiresAt.Format("2006-01-02")
|
||||
}
|
||||
|
||||
lastUsed := "never"
|
||||
if t.LastUsed != nil {
|
||||
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
|
||||
}
|
||||
|
||||
revoked := "no"
|
||||
if t.Revoked {
|
||||
revoked = "yes"
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
t.ID,
|
||||
t.Name,
|
||||
t.CreatedAt.Format("2006-01-02"),
|
||||
expires,
|
||||
lastUsed,
|
||||
revoked,
|
||||
)
|
||||
}
|
||||
|
||||
w.Flush()
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
|
||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||
tokenID := args[0]
|
||||
|
||||
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
|
||||
return fmt.Errorf("revoke token: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Token %s revoked successfully.\n", tokenID) //nolint:forbidigo
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
|
||||
// An empty string returns zero duration (no expiration).
|
||||
func parseDuration(s string) (time.Duration, error) {
|
||||
if len(s) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if s[len(s)-1] == 'd' {
|
||||
d, err := strconv.Atoi(s[:len(s)-1])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid day format: %s", s)
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||
}
|
||||
return time.Duration(d) * 24 * time.Hour, nil
|
||||
}
|
||||
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
@@ -1,463 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
)
|
||||
|
||||
const (
|
||||
echoPort = 9000
|
||||
connectTimeout = 120 * time.Second
|
||||
startTimeout = 60 * time.Second
|
||||
stopTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type peerInfo struct {
|
||||
client *embed.Client
|
||||
tunnelIP string
|
||||
name string
|
||||
}
|
||||
|
||||
type pairStats struct {
|
||||
from string
|
||||
to string
|
||||
sent int64
|
||||
received int64
|
||||
lost int64
|
||||
rtts []time.Duration
|
||||
}
|
||||
|
||||
func (s *pairStats) summary() (avgRTT, minRTT, maxRTT time.Duration, lossPercent float64) {
|
||||
if len(s.rtts) == 0 {
|
||||
return 0, 0, 0, 100
|
||||
}
|
||||
minRTT = s.rtts[0]
|
||||
maxRTT = s.rtts[0]
|
||||
var total time.Duration
|
||||
for _, rtt := range s.rtts {
|
||||
total += rtt
|
||||
if rtt < minRTT {
|
||||
minRTT = rtt
|
||||
}
|
||||
if rtt > maxRTT {
|
||||
maxRTT = rtt
|
||||
}
|
||||
}
|
||||
avgRTT = total / time.Duration(len(s.rtts))
|
||||
if s.sent > 0 {
|
||||
lossPercent = float64(s.lost) / float64(s.sent) * 100
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func main() {
|
||||
managementURL := flag.String("management-url", "", "Management server URL (required)")
|
||||
setupKey := flag.String("setup-key", "", "Reusable setup key (required)")
|
||||
numPeers := flag.Int("peers", 5, "Number of peers to spawn")
|
||||
forceRelay := flag.Bool("force-relay", false, "Force relay connections (NB_FORCE_RELAY=true)")
|
||||
duration := flag.Duration("duration", 30*time.Second, "Traffic test duration")
|
||||
packetSize := flag.Int("packet-size", 512, "UDP packet size in bytes")
|
||||
logLevel := flag.String("log-level", "panic", "Client log level (trace, debug, info, warn, error, panic)")
|
||||
flag.Parse()
|
||||
|
||||
if *managementURL == "" || *setupKey == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: --management-url and --setup-key are required")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *numPeers < 2 {
|
||||
fmt.Fprintln(os.Stderr, "Error: --peers must be at least 2")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Minimum packet size: 8 bytes for timestamp + 8 bytes for sequence number
|
||||
if *packetSize < 16 {
|
||||
fmt.Fprintln(os.Stderr, "Error: --packet-size must be at least 16")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *forceRelay {
|
||||
os.Setenv("NB_FORCE_RELAY", "true")
|
||||
}
|
||||
os.Setenv("NB_USE_NETSTACK_MODE", "true")
|
||||
|
||||
fmt.Println("=== NetBird Performance Test ===")
|
||||
fmt.Printf("Management URL: %s\n", *managementURL)
|
||||
fmt.Printf("Peers: %d\n", *numPeers)
|
||||
fmt.Printf("Force relay: %v\n", *forceRelay)
|
||||
fmt.Printf("Duration: %s\n", *duration)
|
||||
fmt.Printf("Packet size: %d bytes\n", *packetSize)
|
||||
fmt.Println()
|
||||
|
||||
// Phase 1: Create peers
|
||||
fmt.Println("--- Phase 1: Creating peers ---")
|
||||
peers := make([]peerInfo, *numPeers)
|
||||
for i := 0; i < *numPeers; i++ {
|
||||
name := fmt.Sprintf("perf-peer-%d", i)
|
||||
port := 0
|
||||
c, err := embed.New(embed.Options{
|
||||
DeviceName: name,
|
||||
SetupKey: *setupKey,
|
||||
ManagementURL: *managementURL,
|
||||
WireguardPort: &port,
|
||||
LogLevel: *logLevel,
|
||||
LogOutput: io.Discard,
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating peer %s: %v\n", name, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
peers[i] = peerInfo{client: c, name: name}
|
||||
fmt.Printf(" Created %s\n", name)
|
||||
}
|
||||
|
||||
// Phase 2: Start peers in parallel
|
||||
fmt.Println("\n--- Phase 2: Starting peers ---")
|
||||
startTime := time.Now()
|
||||
var wg sync.WaitGroup
|
||||
startErrors := make([]error, *numPeers)
|
||||
|
||||
for i := range peers {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), startTimeout)
|
||||
defer cancel()
|
||||
t := time.Now()
|
||||
if err := peers[idx].client.Start(ctx); err != nil {
|
||||
startErrors[idx] = err
|
||||
return
|
||||
}
|
||||
fmt.Printf(" %s started in %s\n", peers[idx].name, time.Since(t).Round(time.Millisecond))
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Check for start errors
|
||||
var failed bool
|
||||
for i, err := range startErrors {
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, " Error starting %s: %v\n", peers[i].name, err)
|
||||
failed = true
|
||||
}
|
||||
}
|
||||
if failed {
|
||||
cleanup(peers)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf(" All peers started in %s\n", time.Since(startTime).Round(time.Millisecond))
|
||||
|
||||
// Get tunnel IPs
|
||||
for i := range peers {
|
||||
status, err := peers[i].client.Status()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, " Error getting status for %s: %v\n", peers[i].name, err)
|
||||
cleanup(peers)
|
||||
os.Exit(1)
|
||||
}
|
||||
ip := status.LocalPeerState.IP
|
||||
// Strip CIDR suffix if present (e.g. "100.64.0.1/16" -> "100.64.0.1")
|
||||
if idx := strings.Index(ip, "/"); idx != -1 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
peers[i].tunnelIP = ip
|
||||
fmt.Printf(" %s -> %s\n", peers[i].name, peers[i].tunnelIP)
|
||||
}
|
||||
|
||||
// Phase 3: Wait for connections
|
||||
fmt.Println("\n--- Phase 3: Waiting for peer connections ---")
|
||||
connStart := time.Now()
|
||||
expectedPeers := *numPeers - 1
|
||||
deadline := time.After(connectTimeout)
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
allConnected := false
|
||||
waitLoop:
|
||||
for {
|
||||
select {
|
||||
case <-deadline:
|
||||
fmt.Fprintf(os.Stderr, " Timeout waiting for connections after %s\n", connectTimeout)
|
||||
printConnectionStatus(peers)
|
||||
cleanup(peers)
|
||||
os.Exit(1)
|
||||
case <-ticker.C:
|
||||
allConnected = true
|
||||
for i := range peers {
|
||||
connected := countConnectedPeers(peers[i].client)
|
||||
if connected < expectedPeers {
|
||||
allConnected = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allConnected {
|
||||
break waitLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf(" All peers connected in %s\n", time.Since(connStart).Round(time.Millisecond))
|
||||
printConnectionStatus(peers)
|
||||
|
||||
// Phase 4: Traffic test
|
||||
fmt.Printf("\n--- Phase 4: Traffic test (%s) ---\n", *duration)
|
||||
|
||||
// Start echo listeners on all peers
|
||||
listeners := make([]net.PacketConn, *numPeers)
|
||||
for i := range peers {
|
||||
conn, err := peers[i].client.ListenUDP(fmt.Sprintf(":%d", echoPort))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, " Error creating listener on %s: %v\n", peers[i].name, err)
|
||||
cleanup(peers)
|
||||
os.Exit(1)
|
||||
}
|
||||
listeners[i] = conn
|
||||
go echoServer(conn, *packetSize)
|
||||
fmt.Printf(" Echo listener started on %s:%d\n", peers[i].tunnelIP, echoPort)
|
||||
}
|
||||
|
||||
// Run traffic between all pairs (i < j)
|
||||
var statsMu sync.Mutex
|
||||
var allStats []pairStats
|
||||
|
||||
var trafficWg sync.WaitGroup
|
||||
for i := 0; i < *numPeers; i++ {
|
||||
for j := i + 1; j < *numPeers; j++ {
|
||||
trafficWg.Add(1)
|
||||
go func(from, to int) {
|
||||
defer trafficWg.Done()
|
||||
stats := runTraffic(peers[from].client, peers[from].name, peers[to].tunnelIP, peers[to].name, *duration, *packetSize)
|
||||
statsMu.Lock()
|
||||
allStats = append(allStats, stats)
|
||||
statsMu.Unlock()
|
||||
}(i, j)
|
||||
}
|
||||
}
|
||||
trafficWg.Wait()
|
||||
|
||||
// Close listeners
|
||||
for _, l := range listeners {
|
||||
if l != nil {
|
||||
l.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 5: Report
|
||||
fmt.Println("\n--- Phase 5: Results ---")
|
||||
printReport(allStats)
|
||||
|
||||
// Cleanup
|
||||
fmt.Println("\n--- Cleanup ---")
|
||||
cleanup(peers)
|
||||
fmt.Println("Done.")
|
||||
}
|
||||
|
||||
func countConnectedPeers(c *embed.Client) int {
|
||||
status, err := c.Status()
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
count := 0
|
||||
for _, p := range status.Peers {
|
||||
if p.ConnStatus == embed.PeerStatusConnected {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func printConnectionStatus(peers []peerInfo) {
|
||||
for i := range peers {
|
||||
status, err := peers[i].client.Status()
|
||||
if err != nil {
|
||||
fmt.Printf(" %s: error getting status: %v\n", peers[i].name, err)
|
||||
continue
|
||||
}
|
||||
connected := 0
|
||||
relayed := 0
|
||||
for _, p := range status.Peers {
|
||||
if p.ConnStatus == embed.PeerStatusConnected {
|
||||
connected++
|
||||
if p.Relayed {
|
||||
relayed++
|
||||
}
|
||||
}
|
||||
}
|
||||
connType := "direct"
|
||||
if relayed > 0 {
|
||||
connType = fmt.Sprintf("%d direct, %d relayed", connected-relayed, relayed)
|
||||
}
|
||||
fmt.Printf(" %s: %d/%d connected (%s)\n", peers[i].name, connected, len(status.Peers), connType)
|
||||
}
|
||||
}
|
||||
|
||||
func echoServer(conn net.PacketConn, maxSize int) {
|
||||
buf := make([]byte, maxSize+100)
|
||||
for {
|
||||
n, addr, err := conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = conn.WriteTo(buf[:n], addr)
|
||||
}
|
||||
}
|
||||
|
||||
func runTraffic(client *embed.Client, fromName, toIP, toName string, duration time.Duration, packetSize int) pairStats {
|
||||
stats := pairStats{
|
||||
from: fromName,
|
||||
to: toName,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), duration+10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := client.Dial(ctx, "udp", fmt.Sprintf("%s:%d", toIP, echoPort))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, " Error dialing %s -> %s: %v\n", fromName, toName, err)
|
||||
return stats
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
deadline := time.Now().Add(duration)
|
||||
buf := make([]byte, packetSize)
|
||||
recvBuf := make([]byte, packetSize+100)
|
||||
var seq uint64
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
seq++
|
||||
// Encode timestamp and sequence number
|
||||
binary.BigEndian.PutUint64(buf[0:8], uint64(time.Now().UnixNano()))
|
||||
binary.BigEndian.PutUint64(buf[8:16], seq)
|
||||
|
||||
stats.sent++
|
||||
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
_, err := conn.Write(buf)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, " Error sending packet to %s: %v\n", toName, err)
|
||||
stats.lost++
|
||||
continue
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
n, err := conn.Read(recvBuf)
|
||||
if err != nil {
|
||||
stats.lost++
|
||||
continue
|
||||
}
|
||||
|
||||
if n >= 8 {
|
||||
sentNano := binary.BigEndian.Uint64(recvBuf[0:8])
|
||||
rtt := time.Since(time.Unix(0, int64(sentNano)))
|
||||
stats.received++
|
||||
stats.rtts = append(stats.rtts, rtt)
|
||||
} else {
|
||||
stats.received++
|
||||
}
|
||||
|
||||
// Small sleep to avoid flooding
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
func printReport(allStats []pairStats) {
|
||||
if len(allStats) == 0 {
|
||||
fmt.Println(" No traffic data collected.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf(" %-30s %8s %8s %8s %8s %10s %10s %10s\n",
|
||||
"Pair", "Sent", "Recv", "Lost", "Loss%", "Avg RTT", "Min RTT", "Max RTT")
|
||||
fmt.Println(" " + strings.Repeat("-", 108))
|
||||
|
||||
var totalSent, totalRecv, totalLost int64
|
||||
var totalRTTs []time.Duration
|
||||
|
||||
for _, s := range allStats {
|
||||
avg, min, max, loss := s.summary()
|
||||
pair := fmt.Sprintf("%s -> %s", s.from, s.to)
|
||||
fmt.Printf(" %-30s %8d %8d %8d %7.1f%% %10s %10s %10s\n",
|
||||
pair, s.sent, s.received, s.lost, loss,
|
||||
avg.Round(time.Microsecond), min.Round(time.Microsecond), max.Round(time.Microsecond))
|
||||
totalSent += s.sent
|
||||
totalRecv += s.received
|
||||
totalLost += s.lost
|
||||
totalRTTs = append(totalRTTs, s.rtts...)
|
||||
}
|
||||
|
||||
fmt.Println(" " + strings.Repeat("-", 108))
|
||||
|
||||
// Overall summary
|
||||
var overallLoss float64
|
||||
if totalSent > 0 {
|
||||
overallLoss = float64(totalLost) / float64(totalSent) * 100
|
||||
}
|
||||
|
||||
var avgRTT, minRTT, maxRTT time.Duration
|
||||
if len(totalRTTs) > 0 {
|
||||
minRTT = totalRTTs[0]
|
||||
maxRTT = totalRTTs[0]
|
||||
var total time.Duration
|
||||
for _, rtt := range totalRTTs {
|
||||
total += rtt
|
||||
if rtt < minRTT {
|
||||
minRTT = rtt
|
||||
}
|
||||
if rtt > maxRTT {
|
||||
maxRTT = rtt
|
||||
}
|
||||
}
|
||||
avgRTT = total / time.Duration(len(totalRTTs))
|
||||
}
|
||||
|
||||
fmt.Printf(" %-30s %8d %8d %8d %7.1f%% %10s %10s %10s\n",
|
||||
"TOTAL", totalSent, totalRecv, totalLost, overallLoss,
|
||||
avgRTT.Round(time.Microsecond), minRTT.Round(time.Microsecond), maxRTT.Round(time.Microsecond))
|
||||
|
||||
// Extra stats
|
||||
if len(totalRTTs) > 0 {
|
||||
fmt.Println()
|
||||
var sumSq float64
|
||||
avgNs := float64(avgRTT.Nanoseconds())
|
||||
for _, rtt := range totalRTTs {
|
||||
diff := float64(rtt.Nanoseconds()) - avgNs
|
||||
sumSq += diff * diff
|
||||
}
|
||||
stddev := time.Duration(math.Sqrt(sumSq / float64(len(totalRTTs))))
|
||||
|
||||
fmt.Printf(" RTT stddev: %s\n", stddev.Round(time.Microsecond))
|
||||
fmt.Printf(" Pairs tested: %d\n", len(allStats))
|
||||
}
|
||||
}
|
||||
|
||||
func cleanup(peers []peerInfo) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), stopTimeout)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := range peers {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
if err := peers[idx].client.Stop(ctx); err != nil {
|
||||
fmt.Fprintf(os.Stderr, " Error stopping %s: %v\n", peers[idx].name, err)
|
||||
} else {
|
||||
fmt.Printf(" Stopped %s\n", peers[idx].name)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
BINARY="$SCRIPT_DIR/perftest"
|
||||
|
||||
# Defaults
|
||||
MANAGEMENT_URL="${MANAGEMENT_URL:-}"
|
||||
SETUP_KEY="${SETUP_KEY:-}"
|
||||
PEERS="${PEERS:-5}"
|
||||
DURATION="${DURATION:-30s}"
|
||||
PACKET_SIZE="${PACKET_SIZE:-512}"
|
||||
FORCE_RELAY="${FORCE_RELAY:-false}"
|
||||
LOG_LEVEL="${LOG_LEVEL:-panic}"
|
||||
|
||||
usage() {
|
||||
cat <<EOF
|
||||
Usage: MANAGEMENT_URL=... SETUP_KEY=... $0 [options]
|
||||
|
||||
Environment variables (or flags):
|
||||
MANAGEMENT_URL Management server URL (required)
|
||||
SETUP_KEY Reusable setup key (required). Use ephemeral.
|
||||
PEERS Number of peers (default: 5)
|
||||
DURATION Traffic test duration (default: 30s)
|
||||
PACKET_SIZE UDP packet size in bytes (default: 512)
|
||||
FORCE_RELAY Force relay mode (default: false)
|
||||
LOG_LEVEL Client log level (default: panic)
|
||||
|
||||
All extra arguments are passed directly to the binary.
|
||||
EOF
|
||||
exit 1
|
||||
}
|
||||
|
||||
if [[ -z "$MANAGEMENT_URL" || -z "$SETUP_KEY" ]]; then
|
||||
echo "Error: MANAGEMENT_URL and SETUP_KEY must be set"
|
||||
echo
|
||||
usage
|
||||
fi
|
||||
|
||||
# Build
|
||||
echo "Building perftest..."
|
||||
cd "$SCRIPT_DIR"
|
||||
go build -o "$BINARY" .
|
||||
echo "Build OK: $BINARY"
|
||||
echo
|
||||
|
||||
# Run
|
||||
ARGS=(
|
||||
--management-url "$MANAGEMENT_URL"
|
||||
--setup-key "$SETUP_KEY"
|
||||
--peers "$PEERS"
|
||||
--duration "$DURATION"
|
||||
--packet-size "$PACKET_SIZE"
|
||||
--log-level "$LOG_LEVEL"
|
||||
)
|
||||
|
||||
if [[ "$FORCE_RELAY" == "true" ]]; then
|
||||
ARGS+=(--force-relay)
|
||||
fi
|
||||
|
||||
exec "$BINARY" "${ARGS[@]}" "$@"
|
||||
1
go.mod
1
go.mod
@@ -83,7 +83,6 @@ require (
|
||||
github.com/pion/stun/v3 v3.1.0
|
||||
github.com/pion/transport/v3 v3.1.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/pires/go-proxyproto v0.11.0
|
||||
github.com/pkg/sftp v1.13.9
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/quic-go/quic-go v0.55.0
|
||||
|
||||
2
go.sum
2
go.sum
@@ -474,8 +474,6 @@ github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
|
||||
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
||||
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
||||
github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8=
|
||||
github.com/pires/go-proxyproto v0.11.0 h1:gUQpS85X/VJMdUsYyEgyn59uLJvGqPhJV5YvG68wXH4=
|
||||
github.com/pires/go-proxyproto v0.11.0/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
||||
@@ -169,8 +169,7 @@ read_proxy_docker_network() {
|
||||
read_enable_proxy() {
|
||||
echo "" > /dev/stderr
|
||||
echo "Do you want to enable the NetBird Proxy service?" > /dev/stderr
|
||||
echo "The proxy allows you to selectively expose internal NetBird network resources" > /dev/stderr
|
||||
echo "to the internet. You control which resources are exposed through the dashboard." > /dev/stderr
|
||||
echo "The proxy exposes internal NetBird network resources to the internet." > /dev/stderr
|
||||
echo -n "Enable proxy? [y/N]: " > /dev/stderr
|
||||
read -r CHOICE < /dev/tty
|
||||
|
||||
@@ -182,44 +181,6 @@ read_enable_proxy() {
|
||||
return 0
|
||||
}
|
||||
|
||||
read_proxy_domain() {
|
||||
local suggested_proxy="proxy.${BASE_DOMAIN}"
|
||||
|
||||
echo "" > /dev/stderr
|
||||
echo "NOTE: The proxy domain must be different from the management domain ($NETBIRD_DOMAIN)" > /dev/stderr
|
||||
echo "to avoid TLS certificate conflicts." > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
echo "You also need to add a wildcard DNS record for the proxy domain," > /dev/stderr
|
||||
echo "e.g. *.${suggested_proxy} pointing to the same server domain as $NETBIRD_DOMAIN with a CNAME record." > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
echo -n "Enter the domain for the NetBird Proxy (e.g. ${suggested_proxy}): " > /dev/stderr
|
||||
read -r READ_PROXY_DOMAIN < /dev/tty
|
||||
|
||||
if [[ -z "$READ_PROXY_DOMAIN" ]]; then
|
||||
echo "The proxy domain cannot be empty." > /dev/stderr
|
||||
read_proxy_domain
|
||||
return
|
||||
fi
|
||||
|
||||
if [[ "$READ_PROXY_DOMAIN" == "$NETBIRD_DOMAIN" ]]; then
|
||||
echo "" > /dev/stderr
|
||||
echo "WARNING: The proxy domain cannot be the same as the management domain ($NETBIRD_DOMAIN)." > /dev/stderr
|
||||
read_proxy_domain
|
||||
return
|
||||
fi
|
||||
|
||||
echo ${READ_PROXY_DOMAIN} | grep ${NETBIRD_DOMAIN} > /dev/null
|
||||
if [[ $? -eq 0 ]]; then
|
||||
echo "" > /dev/stderr
|
||||
echo "WARNING: The proxy domain cannot be a subdomain of the management domain ($NETBIRD_DOMAIN)." > /dev/stderr
|
||||
read_proxy_domain
|
||||
return
|
||||
fi
|
||||
|
||||
echo "$READ_PROXY_DOMAIN"
|
||||
return 0
|
||||
}
|
||||
|
||||
read_traefik_acme_email() {
|
||||
echo "" > /dev/stderr
|
||||
echo "Enter your email for Let's Encrypt certificate notifications." > /dev/stderr
|
||||
@@ -316,7 +277,6 @@ initialize_default_values() {
|
||||
DASHBOARD_IMAGE="netbirdio/dashboard:latest"
|
||||
# Combined server replaces separate signal, relay, and management containers
|
||||
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:latest"
|
||||
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:latest"
|
||||
|
||||
# Reverse proxy configuration
|
||||
REVERSE_PROXY_TYPE="0"
|
||||
@@ -329,12 +289,8 @@ initialize_default_values() {
|
||||
BIND_LOCALHOST_ONLY="true"
|
||||
EXTERNAL_PROXY_NETWORK=""
|
||||
|
||||
# Traefik static IP within the internal bridge network
|
||||
TRAEFIK_IP="172.30.0.10"
|
||||
|
||||
# NetBird Proxy configuration
|
||||
ENABLE_PROXY="false"
|
||||
PROXY_DOMAIN=""
|
||||
PROXY_TOKEN=""
|
||||
return 0
|
||||
}
|
||||
@@ -346,12 +302,10 @@ configure_domain() {
|
||||
|
||||
if [[ "$NETBIRD_DOMAIN" == "use-ip" ]]; then
|
||||
NETBIRD_DOMAIN=$(get_main_ip_address)
|
||||
BASE_DOMAIN=$NETBIRD_DOMAIN
|
||||
else
|
||||
NETBIRD_PORT=443
|
||||
NETBIRD_HTTP_PROTOCOL="https"
|
||||
NETBIRD_RELAY_PROTO="rels"
|
||||
BASE_DOMAIN=$(echo $NETBIRD_DOMAIN | sed -E 's/^[^.]+\.//')
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
@@ -364,9 +318,6 @@ configure_reverse_proxy() {
|
||||
if [[ "$REVERSE_PROXY_TYPE" == "0" ]]; then
|
||||
TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email)
|
||||
ENABLE_PROXY=$(read_enable_proxy)
|
||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||
PROXY_DOMAIN=$(read_proxy_domain)
|
||||
fi
|
||||
fi
|
||||
|
||||
# Handle external Traefik-specific prompts (option 1)
|
||||
@@ -396,7 +347,7 @@ check_existing_installation() {
|
||||
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
||||
echo "You can use the following commands:"
|
||||
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
||||
echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
|
||||
echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
|
||||
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
||||
exit 1
|
||||
fi
|
||||
@@ -415,8 +366,6 @@ generate_configuration_files() {
|
||||
# This will be overwritten with the actual token after netbird-server starts
|
||||
echo "# Placeholder - will be updated with token after netbird-server starts" > proxy.env
|
||||
echo "NB_PROXY_TOKEN=placeholder" >> proxy.env
|
||||
# TCP ServersTransport for PROXY protocol v2 to the proxy backend
|
||||
render_traefik_dynamic > traefik-dynamic.yaml
|
||||
fi
|
||||
;;
|
||||
1)
|
||||
@@ -564,24 +513,21 @@ init_environment() {
|
||||
############################################
|
||||
|
||||
render_docker_compose_traefik_builtin() {
|
||||
# Generate proxy service section and Traefik dynamic config if enabled
|
||||
# Generate proxy service section if enabled
|
||||
local proxy_service=""
|
||||
local proxy_volumes=""
|
||||
local traefik_file_provider=""
|
||||
local traefik_dynamic_volume=""
|
||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||
traefik_file_provider=' - "--providers.file.filename=/etc/traefik/dynamic.yaml"'
|
||||
traefik_dynamic_volume=" - ./traefik-dynamic.yaml:/etc/traefik/dynamic.yaml:ro"
|
||||
proxy_service="
|
||||
# NetBird Proxy - exposes internal resources to the internet
|
||||
proxy:
|
||||
image: $NETBIRD_PROXY_IMAGE
|
||||
build:
|
||||
context: ../
|
||||
dockerfile: proxy/Dockerfile
|
||||
pull_policy: build
|
||||
container_name: netbird-proxy
|
||||
# Hairpin NAT fix: route domain back to traefik's static IP within Docker
|
||||
extra_hosts:
|
||||
- \"$NETBIRD_DOMAIN:$TRAEFIK_IP\"
|
||||
ports:
|
||||
- 51820:51820/udp
|
||||
- \"$NETBIRD_DOMAIN:172.30.0.10\"
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
depends_on:
|
||||
@@ -599,7 +545,6 @@ render_docker_compose_traefik_builtin() {
|
||||
- traefik.tcp.routers.proxy-passthrough.service=proxy-tls
|
||||
- traefik.tcp.routers.proxy-passthrough.priority=1
|
||||
- traefik.tcp.services.proxy-tls.loadbalancer.server.port=8443
|
||||
- traefik.tcp.services.proxy-tls.loadbalancer.serverstransport=pp-v2@file
|
||||
logging:
|
||||
driver: \"json-file\"
|
||||
options:
|
||||
@@ -619,7 +564,7 @@ services:
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
netbird:
|
||||
ipv4_address: $TRAEFIK_IP
|
||||
ipv4_address: 172.30.0.10
|
||||
command:
|
||||
# Logging
|
||||
- "--log.level=INFO"
|
||||
@@ -646,14 +591,12 @@ services:
|
||||
# gRPC transport settings
|
||||
- "--serverstransport.forwardingtimeouts.responseheadertimeout=0s"
|
||||
- "--serverstransport.forwardingtimeouts.idleconntimeout=0s"
|
||||
$traefik_file_provider
|
||||
ports:
|
||||
- '443:443'
|
||||
- '80:80'
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
- netbird_traefik_letsencrypt:/letsencrypt
|
||||
$traefik_dynamic_volume
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
@@ -685,7 +628,11 @@ $traefik_dynamic_volume
|
||||
|
||||
# Combined server (Management + Signal + Relay + STUN)
|
||||
netbird-server:
|
||||
image: $NETBIRD_SERVER_IMAGE
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: combined/Dockerfile.multistage
|
||||
pull_policy: build
|
||||
#image: $NETBIRD_SERVER_IMAGE
|
||||
container_name: netbird-server
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
@@ -763,10 +710,6 @@ server:
|
||||
cliRedirectURIs:
|
||||
- "http://localhost:53000/"
|
||||
|
||||
reverseProxy:
|
||||
trustedHTTPProxies:
|
||||
- "$TRAEFIK_IP/32"
|
||||
|
||||
store:
|
||||
engine: "sqlite"
|
||||
encryptionKey: "$DATASTORE_ENCRYPTION_KEY"
|
||||
@@ -796,17 +739,6 @@ EOF
|
||||
return 0
|
||||
}
|
||||
|
||||
render_traefik_dynamic() {
|
||||
cat <<'EOF'
|
||||
tcp:
|
||||
serversTransports:
|
||||
pp-v2:
|
||||
proxyProtocol:
|
||||
version: 2
|
||||
EOF
|
||||
return 0
|
||||
}
|
||||
|
||||
render_proxy_env() {
|
||||
cat <<EOF
|
||||
# NetBird Proxy Configuration
|
||||
@@ -816,7 +748,7 @@ NB_PROXY_MANAGEMENT_ADDRESS=http://netbird-server:80
|
||||
# Allow insecure gRPC connection to management (required for internal Docker network)
|
||||
NB_PROXY_ALLOW_INSECURE=true
|
||||
# Public URL where this proxy is reachable (used for cluster registration)
|
||||
NB_PROXY_DOMAIN=$PROXY_DOMAIN
|
||||
NB_PROXY_DOMAIN=$NETBIRD_DOMAIN
|
||||
NB_PROXY_ADDRESS=:8443
|
||||
NB_PROXY_TOKEN=$PROXY_TOKEN
|
||||
NB_PROXY_CERTIFICATE_DIRECTORY=/certs
|
||||
@@ -826,10 +758,6 @@ NB_PROXY_OIDC_CLIENT_ID=netbird-proxy
|
||||
NB_PROXY_OIDC_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
|
||||
NB_PROXY_OIDC_SCOPES=openid,profile,email
|
||||
NB_PROXY_FORWARDED_PROTO=https
|
||||
# Enable PROXY protocol to preserve client IPs through L4 proxies (Traefik TCP passthrough)
|
||||
NB_PROXY_PROXY_PROTOCOL=true
|
||||
# Trust Traefik's IP for PROXY protocol headers
|
||||
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
|
||||
EOF
|
||||
return 0
|
||||
}
|
||||
@@ -1188,30 +1116,23 @@ print_builtin_traefik_instructions() {
|
||||
echo " NETBIRD SETUP COMPLETE"
|
||||
echo "$MSG_SEPARATOR"
|
||||
echo ""
|
||||
echo "You can access the NetBird dashboard at:"
|
||||
echo " $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
||||
echo ""
|
||||
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
||||
echo "Follow the onboarding steps to set up your NetBird instance."
|
||||
echo ""
|
||||
echo "Traefik is handling TLS certificates automatically via Let's Encrypt."
|
||||
echo "If you see certificate warnings, wait a moment for certificate issuance to complete."
|
||||
echo ""
|
||||
echo "Open ports:"
|
||||
echo " - 443/tcp (HTTPS - all NetBird services)"
|
||||
echo " - 80/tcp (HTTP - redirects to HTTPS)"
|
||||
echo " - $NETBIRD_STUN_PORT/udp (STUN - required for NAT traversal)"
|
||||
echo " - 443/tcp (HTTPS - all NetBird services)"
|
||||
echo " - 80/tcp (HTTP - redirects to HTTPS)"
|
||||
echo " - $NETBIRD_STUN_PORT/udp (STUN - required for NAT traversal)"
|
||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||
echo " - 51820/udp (WIREGUARD - (optional) for P2P proxy connections)"
|
||||
echo ""
|
||||
echo "NetBird Proxy:"
|
||||
echo " The proxy service is enabled and running."
|
||||
echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy."
|
||||
echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge."
|
||||
echo " Point your proxy domain to this server's domain address like in the examples below:"
|
||||
echo ""
|
||||
echo " $PROXY_DOMAIN CNAME $NETBIRD_DOMAIN"
|
||||
echo " *.$PROXY_DOMAIN CNAME $NETBIRD_DOMAIN"
|
||||
echo ""
|
||||
echo " Point your proxy domains (CNAMEs) to this server's IP address."
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -81,7 +81,9 @@ func init() {
|
||||
|
||||
rootCmd.AddCommand(migrationCmd)
|
||||
|
||||
tc := newTokenCommands()
|
||||
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
rootCmd.AddCommand(tc)
|
||||
tokenCmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
tokenCmd.AddCommand(tokenCreateCmd)
|
||||
tokenCmd.AddCommand(tokenListCmd)
|
||||
tokenCmd.AddCommand(tokenRevokeCmd)
|
||||
rootCmd.AddCommand(tokenCmd)
|
||||
}
|
||||
|
||||
@@ -3,24 +3,62 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var tokenDatadir string
|
||||
var (
|
||||
tokenName string
|
||||
tokenExpireIn string
|
||||
tokenDatadir string
|
||||
|
||||
// newTokenCommands creates the token command tree with management-specific store opener.
|
||||
func newTokenCommands() *cobra.Command {
|
||||
cmd := tokencmd.NewCommands(withTokenStore)
|
||||
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||
return cmd
|
||||
tokenCmd = &cobra.Command{
|
||||
Use: "token",
|
||||
Short: "Manage proxy access tokens",
|
||||
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
|
||||
}
|
||||
|
||||
tokenCreateCmd = &cobra.Command{
|
||||
Use: "create",
|
||||
Short: "Create a new proxy access token",
|
||||
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
|
||||
RunE: tokenCreateRun,
|
||||
}
|
||||
|
||||
tokenListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List all proxy access tokens",
|
||||
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
|
||||
RunE: tokenListRun,
|
||||
}
|
||||
|
||||
tokenRevokeCmd = &cobra.Command{
|
||||
Use: "revoke [token-id]",
|
||||
Short: "Revoke a proxy access token",
|
||||
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: tokenRevokeRun,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||
|
||||
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
|
||||
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
|
||||
tokenCreateCmd.MarkFlagRequired("name") //nolint
|
||||
}
|
||||
|
||||
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
||||
@@ -29,9 +67,10 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
//nolint
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
|
||||
|
||||
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
config, err := loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
@@ -53,3 +92,118 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto
|
||||
|
||||
return fn(ctx, s)
|
||||
}
|
||||
|
||||
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
|
||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||
expiresIn, err := parseDuration(tokenExpireIn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse expiration: %w", err)
|
||||
}
|
||||
|
||||
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
|
||||
return fmt.Errorf("save token: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("Token created successfully!") //nolint:forbidigo
|
||||
fmt.Printf("Token: %s\n", generated.PlainToken) //nolint:forbidigo
|
||||
fmt.Println() //nolint:forbidigo
|
||||
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.") //nolint:forbidigo
|
||||
fmt.Printf("Token ID: %s\n", generated.ID) //nolint:forbidigo
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func tokenListRun(cmd *cobra.Command, _ []string) error {
|
||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list tokens: %w", err)
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
fmt.Println("No proxy access tokens found.") //nolint:forbidigo
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
|
||||
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
|
||||
|
||||
for _, t := range tokens {
|
||||
expires := "never"
|
||||
if t.ExpiresAt != nil {
|
||||
expires = t.ExpiresAt.Format("2006-01-02")
|
||||
}
|
||||
|
||||
lastUsed := "never"
|
||||
if t.LastUsed != nil {
|
||||
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
|
||||
}
|
||||
|
||||
revoked := "no"
|
||||
if t.Revoked {
|
||||
revoked = "yes"
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
t.ID,
|
||||
t.Name,
|
||||
t.CreatedAt.Format("2006-01-02"),
|
||||
expires,
|
||||
lastUsed,
|
||||
revoked,
|
||||
)
|
||||
}
|
||||
|
||||
w.Flush()
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
|
||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||
tokenID := args[0]
|
||||
|
||||
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
|
||||
return fmt.Errorf("revoke token: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Token %s revoked successfully.\n", tokenID) //nolint:forbidigo
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
|
||||
// An empty string returns zero duration (no expiration).
|
||||
func parseDuration(s string) (time.Duration, error) {
|
||||
if len(s) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if s[len(s)-1] == 'd' {
|
||||
d, err := strconv.Atoi(s[:len(s)-1])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid day format: %s", s)
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||
}
|
||||
return time.Duration(d) * 24 * time.Hour, nil
|
||||
}
|
||||
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
@@ -1,185 +0,0 @@
|
||||
// Package tokencmd provides reusable cobra commands for managing proxy access tokens.
|
||||
// Both the management and combined binaries use these commands, each providing
|
||||
// their own StoreOpener to handle config loading and store initialization.
|
||||
package tokencmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// StoreOpener initializes a store from the command context and calls fn.
|
||||
type StoreOpener func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error
|
||||
|
||||
// NewCommands creates the token command tree with the given store opener.
|
||||
// Returns the parent "token" command with create, list, and revoke subcommands.
|
||||
func NewCommands(opener StoreOpener) *cobra.Command {
|
||||
var (
|
||||
tokenName string
|
||||
tokenExpireIn string
|
||||
)
|
||||
|
||||
tokenCmd := &cobra.Command{
|
||||
Use: "token",
|
||||
Short: "Manage proxy access tokens",
|
||||
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
|
||||
}
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create",
|
||||
Short: "Create a new proxy access token",
|
||||
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, s store.Store) error {
|
||||
return runCreate(ctx, s, cmd.OutOrStdout(), tokenName, tokenExpireIn)
|
||||
})
|
||||
},
|
||||
}
|
||||
createCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
|
||||
createCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
|
||||
if err := createCmd.MarkFlagRequired("name"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List all proxy access tokens",
|
||||
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, s store.Store) error {
|
||||
return runList(ctx, s, cmd.OutOrStdout())
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
revokeCmd := &cobra.Command{
|
||||
Use: "revoke [token-id]",
|
||||
Short: "Revoke a proxy access token",
|
||||
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return opener(cmd, func(ctx context.Context, s store.Store) error {
|
||||
return runRevoke(ctx, s, cmd.OutOrStdout(), args[0])
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
tokenCmd.AddCommand(createCmd, listCmd, revokeCmd)
|
||||
return tokenCmd
|
||||
}
|
||||
|
||||
func runCreate(ctx context.Context, s store.Store, w io.Writer, name string, expireIn string) error {
|
||||
expiresIn, err := ParseDuration(expireIn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse expiration: %w", err)
|
||||
}
|
||||
|
||||
generated, err := types.CreateNewProxyAccessToken(name, expiresIn, nil, "CLI")
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
|
||||
return fmt.Errorf("save token: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(w, "Token created successfully!")
|
||||
_, _ = fmt.Fprintf(w, "Token: %s\n", generated.PlainToken)
|
||||
_, _ = fmt.Fprintln(w)
|
||||
_, _ = fmt.Fprintln(w, "IMPORTANT: Save this token now. It will not be shown again.")
|
||||
_, _ = fmt.Fprintf(w, "Token ID: %s\n", generated.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runList(ctx context.Context, s store.Store, out io.Writer) error {
|
||||
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list tokens: %w", err)
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
_, _ = fmt.Fprintln(out, "No proxy access tokens found.")
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0)
|
||||
_, _ = fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
|
||||
_, _ = fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
|
||||
|
||||
for _, t := range tokens {
|
||||
expires := "never"
|
||||
if t.ExpiresAt != nil {
|
||||
expires = t.ExpiresAt.Format("2006-01-02")
|
||||
}
|
||||
|
||||
lastUsed := "never"
|
||||
if t.LastUsed != nil {
|
||||
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
|
||||
}
|
||||
|
||||
revoked := "no"
|
||||
if t.Revoked {
|
||||
revoked = "yes"
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
t.ID,
|
||||
t.Name,
|
||||
t.CreatedAt.Format("2006-01-02"),
|
||||
expires,
|
||||
lastUsed,
|
||||
revoked,
|
||||
)
|
||||
}
|
||||
|
||||
w.Flush()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runRevoke(ctx context.Context, s store.Store, w io.Writer, tokenID string) error {
|
||||
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
|
||||
return fmt.Errorf("revoke token: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "Token %s revoked successfully.\n", tokenID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseDuration parses a duration string with support for days (e.g., "30d", "365d").
|
||||
// An empty string returns zero duration (no expiration).
|
||||
func ParseDuration(s string) (time.Duration, error) {
|
||||
if len(s) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if s[len(s)-1] == 'd' {
|
||||
d, err := strconv.Atoi(s[:len(s)-1])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid day format: %s", s)
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||
}
|
||||
return time.Duration(d) * 24 * time.Hour, nil
|
||||
}
|
||||
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokencmd
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -89,7 +89,7 @@ func TestParseDuration(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := ParseDuration(tt.input)
|
||||
result, err := parseDuration(tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
@@ -39,63 +39,78 @@ type AccessLogFilter struct {
|
||||
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
|
||||
queryParams := r.URL.Query()
|
||||
|
||||
f.Page = parsePositiveInt(queryParams.Get("page"), 1)
|
||||
f.PageSize = min(parsePositiveInt(queryParams.Get("page_size"), DefaultPageSize), MaxPageSize)
|
||||
f.Page = 1
|
||||
if pageStr := queryParams.Get("page"); pageStr != "" {
|
||||
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
|
||||
f.Page = page
|
||||
}
|
||||
}
|
||||
|
||||
f.Search = parseOptionalString(queryParams.Get("search"))
|
||||
f.SourceIP = parseOptionalString(queryParams.Get("source_ip"))
|
||||
f.Host = parseOptionalString(queryParams.Get("host"))
|
||||
f.Path = parseOptionalString(queryParams.Get("path"))
|
||||
f.UserID = parseOptionalString(queryParams.Get("user_id"))
|
||||
f.UserEmail = parseOptionalString(queryParams.Get("user_email"))
|
||||
f.UserName = parseOptionalString(queryParams.Get("user_name"))
|
||||
f.Method = parseOptionalString(queryParams.Get("method"))
|
||||
f.Status = parseOptionalString(queryParams.Get("status"))
|
||||
f.StatusCode = parseOptionalInt(queryParams.Get("status_code"))
|
||||
f.StartDate = parseOptionalRFC3339(queryParams.Get("start_date"))
|
||||
f.EndDate = parseOptionalRFC3339(queryParams.Get("end_date"))
|
||||
}
|
||||
f.PageSize = DefaultPageSize
|
||||
if pageSizeStr := queryParams.Get("page_size"); pageSizeStr != "" {
|
||||
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
|
||||
f.PageSize = pageSize
|
||||
if f.PageSize > MaxPageSize {
|
||||
f.PageSize = MaxPageSize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parsePositiveInt parses a positive integer from a string, returning defaultValue if invalid
|
||||
func parsePositiveInt(s string, defaultValue int) int {
|
||||
if s == "" {
|
||||
return defaultValue
|
||||
if search := queryParams.Get("search"); search != "" {
|
||||
f.Search = &search
|
||||
}
|
||||
if val, err := strconv.Atoi(s); err == nil && val > 0 {
|
||||
return val
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// parseOptionalString returns a pointer to the string if non-empty, otherwise nil
|
||||
func parseOptionalString(s string) *string {
|
||||
if s == "" {
|
||||
return nil
|
||||
if sourceIP := queryParams.Get("source_ip"); sourceIP != "" {
|
||||
f.SourceIP = &sourceIP
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
// parseOptionalInt parses an optional positive integer from a string
|
||||
func parseOptionalInt(s string) *int {
|
||||
if s == "" {
|
||||
return nil
|
||||
if host := queryParams.Get("host"); host != "" {
|
||||
f.Host = &host
|
||||
}
|
||||
if val, err := strconv.Atoi(s); err == nil && val > 0 {
|
||||
v := val
|
||||
return &v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseOptionalRFC3339 parses an optional RFC3339 timestamp from a string
|
||||
func parseOptionalRFC3339(s string) *time.Time {
|
||||
if s == "" {
|
||||
return nil
|
||||
if path := queryParams.Get("path"); path != "" {
|
||||
f.Path = &path
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
return &t
|
||||
|
||||
if userID := queryParams.Get("user_id"); userID != "" {
|
||||
f.UserID = &userID
|
||||
}
|
||||
|
||||
if userEmail := queryParams.Get("user_email"); userEmail != "" {
|
||||
f.UserEmail = &userEmail
|
||||
}
|
||||
|
||||
if userName := queryParams.Get("user_name"); userName != "" {
|
||||
f.UserName = &userName
|
||||
}
|
||||
|
||||
if method := queryParams.Get("method"); method != "" {
|
||||
f.Method = &method
|
||||
}
|
||||
|
||||
if status := queryParams.Get("status"); status != "" {
|
||||
f.Status = &status
|
||||
}
|
||||
|
||||
if statusCodeStr := queryParams.Get("status_code"); statusCodeStr != "" {
|
||||
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil && statusCode > 0 {
|
||||
f.StatusCode = &statusCode
|
||||
}
|
||||
}
|
||||
|
||||
if startDate := queryParams.Get("start_date"); startDate != "" {
|
||||
parsedStartDate, err := time.Parse(time.RFC3339, startDate)
|
||||
if err == nil {
|
||||
f.StartDate = &parsedStartDate
|
||||
}
|
||||
}
|
||||
|
||||
if endDate := queryParams.Get("end_date"); endDate != "" {
|
||||
parsedEndDate, err := time.Parse(time.RFC3339, endDate)
|
||||
if err == nil {
|
||||
f.EndDate = &parsedEndDate
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOffset calculates the database offset for pagination
|
||||
|
||||
@@ -4,10 +4,8 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccessLogFilter_ParseFromRequest(t *testing.T) {
|
||||
@@ -161,211 +159,3 @@ func TestAccessLogFilter_GetLimit(t *testing.T) {
|
||||
limit := filter.GetLimit()
|
||||
assert.Equal(t, 25, limit, "GetLimit should return PageSize")
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_ParseFromRequest_FilterParams(t *testing.T) {
|
||||
startDate := "2024-01-15T10:30:00Z"
|
||||
endDate := "2024-01-16T15:45:00Z"
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
q := req.URL.Query()
|
||||
q.Set("search", "test query")
|
||||
q.Set("source_ip", "192.168.1.1")
|
||||
q.Set("host", "example.com")
|
||||
q.Set("path", "/api/users")
|
||||
q.Set("user_id", "user123")
|
||||
q.Set("user_email", "user@example.com")
|
||||
q.Set("user_name", "John Doe")
|
||||
q.Set("method", "GET")
|
||||
q.Set("status", "success")
|
||||
q.Set("status_code", "200")
|
||||
q.Set("start_date", startDate)
|
||||
q.Set("end_date", endDate)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
require.NotNil(t, filter.Search)
|
||||
assert.Equal(t, "test query", *filter.Search)
|
||||
|
||||
require.NotNil(t, filter.SourceIP)
|
||||
assert.Equal(t, "192.168.1.1", *filter.SourceIP)
|
||||
|
||||
require.NotNil(t, filter.Host)
|
||||
assert.Equal(t, "example.com", *filter.Host)
|
||||
|
||||
require.NotNil(t, filter.Path)
|
||||
assert.Equal(t, "/api/users", *filter.Path)
|
||||
|
||||
require.NotNil(t, filter.UserID)
|
||||
assert.Equal(t, "user123", *filter.UserID)
|
||||
|
||||
require.NotNil(t, filter.UserEmail)
|
||||
assert.Equal(t, "user@example.com", *filter.UserEmail)
|
||||
|
||||
require.NotNil(t, filter.UserName)
|
||||
assert.Equal(t, "John Doe", *filter.UserName)
|
||||
|
||||
require.NotNil(t, filter.Method)
|
||||
assert.Equal(t, "GET", *filter.Method)
|
||||
|
||||
require.NotNil(t, filter.Status)
|
||||
assert.Equal(t, "success", *filter.Status)
|
||||
|
||||
require.NotNil(t, filter.StatusCode)
|
||||
assert.Equal(t, 200, *filter.StatusCode)
|
||||
|
||||
require.NotNil(t, filter.StartDate)
|
||||
expectedStart, _ := time.Parse(time.RFC3339, startDate)
|
||||
assert.Equal(t, expectedStart, *filter.StartDate)
|
||||
|
||||
require.NotNil(t, filter.EndDate)
|
||||
expectedEnd, _ := time.Parse(time.RFC3339, endDate)
|
||||
assert.Equal(t, expectedEnd, *filter.EndDate)
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_ParseFromRequest_EmptyFilters(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Nil(t, filter.Search)
|
||||
assert.Nil(t, filter.SourceIP)
|
||||
assert.Nil(t, filter.Host)
|
||||
assert.Nil(t, filter.Path)
|
||||
assert.Nil(t, filter.UserID)
|
||||
assert.Nil(t, filter.UserEmail)
|
||||
assert.Nil(t, filter.UserName)
|
||||
assert.Nil(t, filter.Method)
|
||||
assert.Nil(t, filter.Status)
|
||||
assert.Nil(t, filter.StatusCode)
|
||||
assert.Nil(t, filter.StartDate)
|
||||
assert.Nil(t, filter.EndDate)
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_ParseFromRequest_InvalidFilters(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
q := req.URL.Query()
|
||||
q.Set("status_code", "invalid")
|
||||
q.Set("start_date", "not-a-date")
|
||||
q.Set("end_date", "2024-99-99")
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Nil(t, filter.StatusCode, "invalid status_code should be nil")
|
||||
assert.Nil(t, filter.StartDate, "invalid start_date should be nil")
|
||||
assert.Nil(t, filter.EndDate, "invalid end_date should be nil")
|
||||
}
|
||||
|
||||
func TestParsePositiveInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
defaultValue int
|
||||
expected int
|
||||
}{
|
||||
{"empty string", "", 10, 10},
|
||||
{"valid positive int", "25", 10, 25},
|
||||
{"zero", "0", 10, 10},
|
||||
{"negative", "-5", 10, 10},
|
||||
{"invalid string", "abc", 10, 10},
|
||||
{"float", "3.14", 10, 10},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parsePositiveInt(tt.input, tt.defaultValue)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOptionalString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected *string
|
||||
}{
|
||||
{"empty string", "", nil},
|
||||
{"valid string", "hello", strPtr("hello")},
|
||||
{"whitespace", " ", strPtr(" ")},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseOptionalString(tt.input)
|
||||
if tt.expected == nil {
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, *tt.expected, *result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOptionalInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected *int
|
||||
}{
|
||||
{"empty string", "", nil},
|
||||
{"valid positive int", "42", intPtr(42)},
|
||||
{"zero", "0", nil},
|
||||
{"negative", "-10", nil},
|
||||
{"invalid string", "abc", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseOptionalInt(tt.input)
|
||||
if tt.expected == nil {
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, *tt.expected, *result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOptionalRFC3339(t *testing.T) {
|
||||
validDate := "2024-01-15T10:30:00Z"
|
||||
expectedTime, _ := time.Parse(time.RFC3339, validDate)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected *time.Time
|
||||
}{
|
||||
{"empty string", "", nil},
|
||||
{"valid RFC3339", validDate, &expectedTime},
|
||||
{"invalid format", "2024-01-15", nil},
|
||||
{"invalid date", "not-a-date", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseOptionalRFC3339(tt.input)
|
||||
if tt.expected == nil {
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, *tt.expected, *result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for creating pointers
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func intPtr(i int) *int {
|
||||
return &i
|
||||
}
|
||||
|
||||
@@ -135,11 +135,54 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
|
||||
return nil, err
|
||||
var proxyCluster string
|
||||
if m.clusterDeriver != nil {
|
||||
proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
|
||||
return nil, status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.persistNewService(ctx, accountID, service); err != nil {
|
||||
service.AccountID = accountID
|
||||
service.ProxyCluster = proxyCluster
|
||||
service.InitNewRecord()
|
||||
err = service.Auth.HashSecrets()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
// Generate session JWT signing keys
|
||||
keyPair, err := sessionkey.GenerateKeyPair()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate session keys: %w", err)
|
||||
}
|
||||
service.SessionPrivateKey = keyPair.PrivateKey
|
||||
service.SessionPublicKey = keyPair.PublicKey
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
// Check for duplicate domain
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing service: %w", err)
|
||||
}
|
||||
}
|
||||
if existingService != nil {
|
||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
|
||||
}
|
||||
|
||||
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.CreateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -157,67 +200,6 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
if m.clusterDeriver != nil {
|
||||
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
|
||||
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
|
||||
}
|
||||
service.ProxyCluster = proxyCluster
|
||||
}
|
||||
|
||||
service.AccountID = accountID
|
||||
service.InitNewRecord()
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
return fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
keyPair, err := sessionkey.GenerateKeyPair()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate session keys: %w", err)
|
||||
}
|
||||
service.SessionPrivateKey = keyPair.PrivateKey
|
||||
service.SessionPublicKey = keyPair.PublicKey
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing service: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if existingService != nil && existingService.ID != excludeServiceID {
|
||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||
if err != nil {
|
||||
@@ -227,122 +209,99 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
var oldCluster string
|
||||
var domainChanged bool
|
||||
var serviceEnabledChanged bool
|
||||
|
||||
err = service.Auth.HashSecrets()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldCluster = existingService.ProxyCluster
|
||||
|
||||
if existingService.Domain != service.Domain {
|
||||
domainChanged = true
|
||||
conflictService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("check existing service: %w", err)
|
||||
}
|
||||
}
|
||||
if conflictService != nil && conflictService.ID != service.ID {
|
||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
|
||||
}
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
|
||||
}
|
||||
service.ProxyCluster = newCluster
|
||||
}
|
||||
} else {
|
||||
service.ProxyCluster = existingService.ProxyCluster
|
||||
}
|
||||
|
||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||
service.Auth.PasswordAuth.Password == "" {
|
||||
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
||||
}
|
||||
|
||||
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
||||
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
||||
service.Auth.PinAuth.Pin == "" {
|
||||
service.Auth.PinAuth = existingService.Auth.PinAuth
|
||||
}
|
||||
|
||||
service.Meta = existingService.Meta
|
||||
service.SessionPrivateKey = existingService.SessionPrivateKey
|
||||
service.SessionPublicKey = existingService.SessionPublicKey
|
||||
serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
|
||||
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
|
||||
|
||||
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.sendServiceUpdateNotifications(service, updateInfo)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
type serviceUpdateInfo struct {
|
||||
oldCluster string
|
||||
domainChanged bool
|
||||
serviceEnabledChanged bool
|
||||
}
|
||||
|
||||
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
|
||||
var updateInfo serviceUpdateInfo
|
||||
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateInfo.oldCluster = existingService.ProxyCluster
|
||||
updateInfo.domainChanged = existingService.Domain != service.Domain
|
||||
|
||||
if updateInfo.domainChanged {
|
||||
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
service.ProxyCluster = existingService.ProxyCluster
|
||||
}
|
||||
|
||||
m.preserveExistingAuthSecrets(service, existingService)
|
||||
m.preserveServiceMetadata(service, existingService)
|
||||
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return &updateInfo, err
|
||||
}
|
||||
|
||||
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
|
||||
} else {
|
||||
service.ProxyCluster = newCluster
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
|
||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||
service.Auth.PasswordAuth.Password == "" {
|
||||
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
||||
}
|
||||
|
||||
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
||||
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
||||
service.Auth.PinAuth.Pin == "" {
|
||||
service.Auth.PinAuth = existingService.Auth.PinAuth
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
|
||||
service.Meta = existingService.Meta
|
||||
service.SessionPrivateKey = existingService.SessionPrivateKey
|
||||
service.SessionPublicKey = existingService.SessionPublicKey
|
||||
}
|
||||
|
||||
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
|
||||
switch {
|
||||
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster)
|
||||
case domainChanged && oldCluster != service.ProxyCluster:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), oldCluster)
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
||||
case !service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
case !service.Enabled && serviceEnabledChanged:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
|
||||
case service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
case service.Enabled && serviceEnabledChanged:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
||||
default:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
|
||||
}
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
||||
@@ -473,6 +432,8 @@ func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
}
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,375 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func TestInitializeServiceForCreate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("successful initialization without cluster deriver", func(t *testing.T) {
|
||||
mgr := &managerImpl{
|
||||
clusterDeriver: nil,
|
||||
}
|
||||
|
||||
service := &reverseproxy.Service{
|
||||
Domain: "example.com",
|
||||
Auth: reverseproxy.AuthConfig{},
|
||||
}
|
||||
|
||||
err := mgr.initializeServiceForCreate(ctx, accountID, service)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, accountID, service.AccountID)
|
||||
assert.Empty(t, service.ProxyCluster, "proxy cluster should be empty when no deriver")
|
||||
assert.NotEmpty(t, service.ID, "service ID should be initialized")
|
||||
assert.NotEmpty(t, service.SessionPrivateKey, "session private key should be generated")
|
||||
assert.NotEmpty(t, service.SessionPublicKey, "session public key should be generated")
|
||||
})
|
||||
|
||||
t.Run("verifies session keys are different", func(t *testing.T) {
|
||||
mgr := &managerImpl{
|
||||
clusterDeriver: nil,
|
||||
}
|
||||
|
||||
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
|
||||
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}
|
||||
|
||||
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
|
||||
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
|
||||
|
||||
assert.NoError(t, err1)
|
||||
assert.NoError(t, err2)
|
||||
assert.NotEqual(t, service1.SessionPrivateKey, service2.SessionPrivateKey, "private keys should be unique")
|
||||
assert.NotEqual(t, service1.SessionPublicKey, service2.SessionPublicKey, "public keys should be unique")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckDomainAvailable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
excludeServiceID string
|
||||
setupMock func(*store.MockStore)
|
||||
expectedError bool
|
||||
errorType status.Type
|
||||
}{
|
||||
{
|
||||
name: "domain available - not found",
|
||||
domain: "available.com",
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "available.com").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "domain already exists",
|
||||
domain: "exists.com",
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
errorType: status.AlreadyExists,
|
||||
},
|
||||
{
|
||||
name: "domain exists but excluded (same ID)",
|
||||
domain: "exists.com",
|
||||
excludeServiceID: "service-123",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "domain exists with different ID",
|
||||
domain: "exists.com",
|
||||
excludeServiceID: "service-456",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
errorType: status.AlreadyExists,
|
||||
},
|
||||
{
|
||||
name: "store error (non-NotFound)",
|
||||
domain: "error.com",
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "error.com").
|
||||
Return(nil, errors.New("database error"))
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
tt.setupMock(mockStore)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
|
||||
|
||||
if tt.expectedError {
|
||||
require.Error(t, err)
|
||||
if tt.errorType != 0 {
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok, "error should be a status error")
|
||||
assert.Equal(t, tt.errorType, sErr.Type())
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("empty domain", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty exclude ID with existing service", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "test.com").
|
||||
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
|
||||
|
||||
assert.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.AlreadyExists, sErr.Type())
|
||||
})
|
||||
|
||||
t.Run("nil existing service with nil error", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "nil.com").
|
||||
Return(nil, nil)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPersistNewService(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("successful service creation with no targets", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
service := &reverseproxy.Service{
|
||||
ID: "service-123",
|
||||
Domain: "new.com",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
}
|
||||
|
||||
// Mock ExecuteInTransaction to execute the function immediately
|
||||
mockStore.EXPECT().
|
||||
ExecuteInTransaction(ctx, gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
|
||||
// Create another mock for the transaction
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "new.com").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
txMock.EXPECT().
|
||||
CreateService(ctx, service).
|
||||
Return(nil)
|
||||
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{store: mockStore}
|
||||
err := mgr.persistNewService(ctx, accountID, service)
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("domain already exists", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
service := &reverseproxy.Service{
|
||||
ID: "service-123",
|
||||
Domain: "existing.com",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
ExecuteInTransaction(ctx, gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "existing.com").
|
||||
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)
|
||||
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{store: mockStore}
|
||||
err := mgr.persistNewService(ctx, accountID, service)
|
||||
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.AlreadyExists, sErr.Type())
|
||||
})
|
||||
}
|
||||
func TestPreserveExistingAuthSecrets(t *testing.T) {
|
||||
mgr := &managerImpl{}
|
||||
|
||||
t.Run("preserve password when empty", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "hashed-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr.preserveExistingAuthSecrets(updated, existing)
|
||||
|
||||
assert.Equal(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
|
||||
})
|
||||
|
||||
t.Run("preserve pin when empty", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PinAuth: &reverseproxy.PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "hashed-pin",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PinAuth: &reverseproxy.PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr.preserveExistingAuthSecrets(updated, existing)
|
||||
|
||||
assert.Equal(t, existing.Auth.PinAuth, updated.Auth.PinAuth)
|
||||
})
|
||||
|
||||
t.Run("do not preserve when password is provided", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "old-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "new-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr.preserveExistingAuthSecrets(updated, existing)
|
||||
|
||||
assert.Equal(t, "new-password", updated.Auth.PasswordAuth.Password)
|
||||
assert.NotEqual(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPreserveServiceMetadata(t *testing.T) {
|
||||
mgr := &managerImpl{}
|
||||
|
||||
existing := &reverseproxy.Service{
|
||||
Meta: reverseproxy.ServiceMeta{
|
||||
CertificateIssuedAt: time.Now(),
|
||||
Status: "active",
|
||||
},
|
||||
SessionPrivateKey: "private-key",
|
||||
SessionPublicKey: "public-key",
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Domain: "updated.com",
|
||||
}
|
||||
|
||||
mgr.preserveServiceMetadata(updated, existing)
|
||||
|
||||
assert.Equal(t, existing.Meta, updated.Meta)
|
||||
assert.Equal(t, existing.SessionPrivateKey, updated.SessionPrivateKey)
|
||||
assert.Equal(t, existing.SessionPublicKey, updated.SessionPublicKey)
|
||||
}
|
||||
@@ -140,11 +140,8 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
go metricsWorker.Run(srvCtx)
|
||||
}
|
||||
|
||||
// Eagerly create the gRPC server so that all AfterInit hooks are registered
|
||||
// before we iterate them. Lazy creation after the loop would miss hooks
|
||||
// registered during GRPCServer() construction (e.g., SetProxyManager).
|
||||
s.GRPCServer()
|
||||
|
||||
// Run afterInit hooks before starting any servers
|
||||
// This allows registering additional gRPC services (e.g., Signal) before Serve() is called
|
||||
for _, fn := range s.afterInit {
|
||||
if fn != nil {
|
||||
fn(s)
|
||||
|
||||
@@ -15,14 +15,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
@@ -520,11 +519,61 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
|
||||
}
|
||||
|
||||
authenticated, userId, method := s.authenticateRequest(ctx, req, service)
|
||||
var authenticated bool
|
||||
var userId string
|
||||
var method proxyauth.Method
|
||||
switch v := req.GetRequest().(type) {
|
||||
case *proto.AuthenticateRequest_Pin:
|
||||
auth := service.Auth.PinAuth
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", req.GetId())
|
||||
break
|
||||
}
|
||||
err = argon2id.Verify(v.Pin.GetPin(), auth.Pin)
|
||||
if err != nil {
|
||||
if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
||||
log.WithContext(ctx).Tracef("PIN authentication failed: invalid PIN")
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("PIN authentication error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
authenticated = true
|
||||
userId = "pin-user"
|
||||
method = proxyauth.MethodPIN
|
||||
case *proto.AuthenticateRequest_Password:
|
||||
auth := service.Auth.PasswordAuth
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", req.GetId())
|
||||
break
|
||||
}
|
||||
err = argon2id.Verify(v.Password.GetPassword(), auth.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
||||
log.WithContext(ctx).Tracef("Password authentication failed: invalid password")
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("Password authentication error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
authenticated = true
|
||||
userId = "password-user"
|
||||
method = proxyauth.MethodPassword
|
||||
}
|
||||
|
||||
token, err := s.generateSessionToken(ctx, authenticated, service, userId, method)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var token string
|
||||
if authenticated && service.SessionPrivateKey != "" {
|
||||
token, err = sessionkey.SignToken(
|
||||
service.SessionPrivateKey,
|
||||
userId,
|
||||
service.Domain,
|
||||
method,
|
||||
proxyauth.DefaultSessionExpiry,
|
||||
)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to sign session token")
|
||||
return nil, status.Errorf(codes.Internal, "sign session token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.AuthenticateResponse{
|
||||
@@ -533,73 +582,6 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) {
|
||||
switch v := req.GetRequest().(type) {
|
||||
case *proto.AuthenticateRequest_Pin:
|
||||
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
|
||||
case *proto.AuthenticateRequest_Password:
|
||||
return s.authenticatePassword(ctx, req.GetId(), v, service.Auth.PasswordAuth)
|
||||
default:
|
||||
return false, "", ""
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) {
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID)
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
if err := argon2id.Verify(req.Pin.GetPin(), auth.Pin); err != nil {
|
||||
s.logAuthenticationError(ctx, err, "PIN")
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
return true, "pin-user", proxyauth.MethodPIN
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) {
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID)
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
if err := argon2id.Verify(req.Password.GetPassword(), auth.Password); err != nil {
|
||||
s.logAuthenticationError(ctx, err, "Password")
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
return true, "password-user", proxyauth.MethodPassword
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err error, authType string) {
|
||||
if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
||||
log.WithContext(ctx).Tracef("%s authentication failed: invalid credentials", authType)
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("%s authentication error: %v", authType, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) {
|
||||
if !authenticated || service.SessionPrivateKey == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
token, err := sessionkey.SignToken(
|
||||
service.SessionPrivateKey,
|
||||
userId,
|
||||
service.Domain,
|
||||
method,
|
||||
proxyauth.DefaultSessionExpiry,
|
||||
)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to sign session token")
|
||||
return "", status.Errorf(codes.Internal, "sign session token: %v", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// SendStatusUpdate handles status updates from proxy clients
|
||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||
accountID := req.GetAccountId()
|
||||
|
||||
@@ -297,7 +297,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
var oldSettings *types.Settings
|
||||
var updateAccountPeers bool
|
||||
var groupChangesAffectPeers bool
|
||||
var reloadReverseProxy bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var groupsUpdated bool
|
||||
@@ -328,7 +327,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
||||
return err
|
||||
}
|
||||
reloadReverseProxy = true
|
||||
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
|
||||
}
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
@@ -393,11 +394,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
|
||||
}
|
||||
if reloadReverseProxy {
|
||||
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
|
||||
go am.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
@@ -3918,36 +3918,3 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
|
||||
assert.False(t, user.PendingApproval, "User should not be pending approval")
|
||||
assert.Equal(t, existingAccountID, user.AccountID)
|
||||
}
|
||||
|
||||
// TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange verifies that
|
||||
// changing NetworkRange via UpdateAccountSettings does not deadlock.
|
||||
// The deadlock occurs because ReloadAllServicesForAccount is called inside a DB
|
||||
// transaction but uses the main store connection, which blocks on the transaction lock.
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Use a channel to detect if the call completes or hangs
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
NetworkRange: netip.MustParsePrefix("10.100.0.0/16"),
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.NoError(t, err, "UpdateAccountSettings should complete without error")
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,8 +18,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
@@ -378,9 +376,9 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read)
|
||||
err = h.permissionsManager.ValidateAccountAccess(r.Context(), accountID, user, false)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
|
||||
util.WriteError(r.Context(), status.NewPermissionDeniedError(), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -390,12 +388,9 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed && !userAuth.IsChild {
|
||||
if account.Settings.RegularUsersViewBlocked {
|
||||
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
|
||||
return
|
||||
}
|
||||
|
||||
// If the user is regular user and does not own the peer
|
||||
// with the given peerID return an empty list
|
||||
if !user.HasAdminPower() && !user.IsServiceUser && !userAuth.IsChild {
|
||||
peer, ok := account.Peers[peerID]
|
||||
if !ok {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
|
||||
|
||||
@@ -22,8 +22,6 @@ import (
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -117,16 +115,6 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
ctrl2 := gomock.NewController(t)
|
||||
permissionsManager := permissions.NewMockManager(ctrl2)
|
||||
permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
permissionsManager.EXPECT().
|
||||
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)).
|
||||
DoAndReturn(func(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error) {
|
||||
user, ok := account.Users[userID]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("user not found")
|
||||
}
|
||||
return user.HasAdminPower() || user.IsServiceUser, nil
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
return &Handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
@@ -395,11 +383,12 @@ func TestGetAccessiblePeers(t *testing.T) {
|
||||
UserID: regularUser,
|
||||
}
|
||||
|
||||
p := initTestMetaData(t, peer1, peer2, peer3)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
peerID string
|
||||
callerUserID string
|
||||
viewBlocked bool
|
||||
expectedStatus int
|
||||
expectedPeers []string
|
||||
}{
|
||||
@@ -438,56 +427,10 @@ func TestGetAccessiblePeers(t *testing.T) {
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer1", "peer2"},
|
||||
},
|
||||
{
|
||||
name: "regular user gets empty for owned peer list when view blocked",
|
||||
peerID: "peer1",
|
||||
callerUserID: regularUser,
|
||||
viewBlocked: true,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{},
|
||||
},
|
||||
{
|
||||
name: "regular user gets empty list for unowned peer when view blocked",
|
||||
peerID: "peer2",
|
||||
callerUserID: regularUser,
|
||||
viewBlocked: true,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{},
|
||||
},
|
||||
{
|
||||
name: "admin user still sees accessible peers when view blocked",
|
||||
peerID: "peer2",
|
||||
callerUserID: adminUser,
|
||||
viewBlocked: true,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer1", "peer3"},
|
||||
},
|
||||
{
|
||||
name: "service user still sees accessible peers when view blocked",
|
||||
peerID: "peer3",
|
||||
callerUserID: serviceUser,
|
||||
viewBlocked: true,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer1", "peer2"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := initTestMetaData(t, peer1, peer2, peer3)
|
||||
|
||||
if tc.viewBlocked {
|
||||
mockAM := p.accountManager.(*mock_server.MockAccountManager)
|
||||
originalGetAccountByIDFunc := mockAM.GetAccountByIDFunc
|
||||
mockAM.GetAccountByIDFunc = func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
|
||||
account, err := originalGetAccountByIDFunc(ctx, accountID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
account.Settings.RegularUsersViewBlocked = true
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
|
||||
|
||||
@@ -561,99 +561,6 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri
|
||||
return account.Network.Copy(), err
|
||||
}
|
||||
|
||||
type peerAddAuthConfig struct {
|
||||
AccountID string
|
||||
SetupKeyID string
|
||||
SetupKeyName string
|
||||
GroupsToAdd []string
|
||||
AllowExtraDNSLabels bool
|
||||
Ephemeral bool
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) processPeerAddAuth(ctx context.Context, accountID, userID, encodedHashedKey string, peer *nbpeer.Peer, temporary, addedByUser, addedBySetupKey bool, opEvent *activity.Event) (*peerAddAuthConfig, error) {
|
||||
config := &peerAddAuthConfig{
|
||||
AccountID: accountID,
|
||||
Ephemeral: peer.Ephemeral,
|
||||
}
|
||||
|
||||
switch {
|
||||
case addedByUser:
|
||||
if err := am.handleUserAddedPeer(ctx, accountID, userID, temporary, opEvent, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case addedBySetupKey:
|
||||
if err := am.handleSetupKeyAddedPeer(ctx, encodedHashedKey, peer, opEvent, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
if peer.ProxyMeta.Embedded {
|
||||
log.WithContext(ctx).Debugf("adding peer for proxy embedded, accountID: %s", accountID)
|
||||
} else {
|
||||
log.WithContext(ctx).Warnf("adding peer without setup key or userID, accountID: %s", accountID)
|
||||
}
|
||||
}
|
||||
|
||||
opEvent.AccountID = config.AccountID
|
||||
if temporary {
|
||||
config.Ephemeral = true
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleUserAddedPeer(ctx context.Context, accountID, userID string, temporary bool, opEvent *activity.Event, config *peerAddAuthConfig) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "failed adding new peer: user not found")
|
||||
}
|
||||
if user.PendingApproval {
|
||||
return status.Errorf(status.PermissionDenied, "user pending approval cannot add peers")
|
||||
}
|
||||
|
||||
if temporary {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
} else {
|
||||
config.AccountID = user.AccountID
|
||||
config.GroupsToAdd = user.AutoGroups
|
||||
}
|
||||
|
||||
opEvent.InitiatorID = userID
|
||||
opEvent.Activity = activity.PeerAddedByUser
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleSetupKeyAddedPeer(ctx context.Context, encodedHashedKey string, peer *nbpeer.Peer, opEvent *activity.Event, config *peerAddAuthConfig) error {
|
||||
sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
if !sk.IsValid() {
|
||||
return status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
||||
}
|
||||
|
||||
opEvent.InitiatorID = sk.Id
|
||||
opEvent.Activity = activity.PeerAddedWithSetupKey
|
||||
config.GroupsToAdd = sk.AutoGroups
|
||||
config.Ephemeral = sk.Ephemeral
|
||||
config.SetupKeyID = sk.Id
|
||||
config.SetupKeyName = sk.Name
|
||||
config.AllowExtraDNSLabels = sk.AllowExtraDNSLabels
|
||||
config.AccountID = sk.AccountID
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeer adds a new peer to the Store.
|
||||
// Each Account has a list of pre-authorized SetupKey and if no Account has a given key err with a code status.PermissionDenied
|
||||
// will be returned, meaning the setup key is invalid or not found.
|
||||
@@ -689,12 +596,70 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
|
||||
var newPeer *nbpeer.Peer
|
||||
|
||||
peerAddConfig, err := am.processPeerAddAuth(ctx, accountID, userID, encodedHashedKey, peer, temporary, addedByUser, addedBySetupKey, opEvent)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
var setupKeyID string
|
||||
var setupKeyName string
|
||||
var groupsToAdd []string
|
||||
var allowExtraDNSLabels bool
|
||||
ephemeral := peer.Ephemeral
|
||||
switch {
|
||||
case addedByUser:
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: user not found")
|
||||
}
|
||||
if user.PendingApproval {
|
||||
return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers")
|
||||
}
|
||||
if temporary {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, nil, nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
} else {
|
||||
accountID = user.AccountID
|
||||
groupsToAdd = user.AutoGroups
|
||||
}
|
||||
opEvent.InitiatorID = userID
|
||||
opEvent.Activity = activity.PeerAddedByUser
|
||||
case addedBySetupKey:
|
||||
// Validate the setup key
|
||||
sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
// we will check key twice for early return
|
||||
if !sk.IsValid() {
|
||||
return nil, nil, nil, status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
opEvent.InitiatorID = sk.Id
|
||||
opEvent.Activity = activity.PeerAddedWithSetupKey
|
||||
groupsToAdd = sk.AutoGroups
|
||||
ephemeral = sk.Ephemeral
|
||||
setupKeyID = sk.Id
|
||||
setupKeyName = sk.Name
|
||||
allowExtraDNSLabels = sk.AllowExtraDNSLabels
|
||||
accountID = sk.AccountID
|
||||
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
||||
}
|
||||
default:
|
||||
if peer.ProxyMeta.Embedded {
|
||||
log.WithContext(ctx).Debugf("adding peer for proxy embedded, accountID: %s", accountID)
|
||||
} else {
|
||||
log.WithContext(ctx).Warnf("adding peer without setup key or userID, accountID: %s", accountID)
|
||||
}
|
||||
}
|
||||
opEvent.AccountID = accountID
|
||||
|
||||
if temporary {
|
||||
ephemeral = true
|
||||
}
|
||||
accountID = peerAddConfig.AccountID
|
||||
ephemeral := peerAddConfig.Ephemeral
|
||||
|
||||
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
||||
if am.idpManager != nil {
|
||||
@@ -728,7 +693,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
Location: peer.Location,
|
||||
InactivityExpirationEnabled: addedByUser && !temporary,
|
||||
ExtraDNSLabels: peer.ExtraDNSLabels,
|
||||
AllowExtraDNSLabels: peerAddConfig.AllowExtraDNSLabels,
|
||||
AllowExtraDNSLabels: allowExtraDNSLabels,
|
||||
}
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
@@ -746,7 +711,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
}
|
||||
}
|
||||
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, peerAddConfig.GroupsToAdd, settings.Extra, temporary)
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary)
|
||||
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
@@ -782,8 +747,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
return err
|
||||
}
|
||||
|
||||
if len(peerAddConfig.GroupsToAdd) > 0 {
|
||||
for _, g := range peerAddConfig.GroupsToAdd {
|
||||
if len(groupsToAdd) > 0 {
|
||||
for _, g := range groupsToAdd {
|
||||
err = transaction.AddPeerToGroup(ctx, newPeer.AccountID, newPeer.ID, g)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -815,7 +780,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
err = transaction.IncrementSetupKeyUsage(ctx, peerAddConfig.SetupKeyID)
|
||||
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
||||
}
|
||||
@@ -856,7 +821,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
opEvent.TargetID = newPeer.ID
|
||||
opEvent.Meta = newPeer.EventMeta(am.networkMapController.GetDNSDomain(settings))
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName
|
||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
@@ -2489,252 +2489,3 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
|
||||
_, _, _, err = manager.LoginPeer(context.Background(), login)
|
||||
require.NoError(t, err, "Regular user should be able to login peers")
|
||||
}
|
||||
|
||||
func TestHandleUserAddedPeer(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("regular user can add peer", func(t *testing.T) {
|
||||
regularUser := types.NewRegularUser("regular-user-1", "", "")
|
||||
regularUser.AccountID = account.Id
|
||||
regularUser.AutoGroups = []string{"group1", "group2"}
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
|
||||
err = manager.handleUserAddedPeer(context.Background(), account.Id, regularUser.Id, false, opEvent, config)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, account.Id, config.AccountID)
|
||||
assert.Equal(t, regularUser.AutoGroups, config.GroupsToAdd)
|
||||
assert.Equal(t, regularUser.Id, opEvent.InitiatorID)
|
||||
assert.Equal(t, activity.PeerAddedByUser, opEvent.Activity)
|
||||
})
|
||||
|
||||
t.Run("pending approval user cannot add peer", func(t *testing.T) {
|
||||
pendingUser := types.NewRegularUser("pending-user", "", "")
|
||||
pendingUser.AccountID = account.Id
|
||||
pendingUser.PendingApproval = true
|
||||
err = manager.Store.SaveUser(context.Background(), pendingUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
|
||||
err = manager.handleUserAddedPeer(context.Background(), account.Id, pendingUser.Id, false, opEvent, config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user pending approval cannot add peers")
|
||||
})
|
||||
|
||||
t.Run("user not found", func(t *testing.T) {
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
|
||||
err = manager.handleUserAddedPeer(context.Background(), account.Id, "non-existent-user", false, opEvent, config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user not found")
|
||||
})
|
||||
|
||||
t.Run("temporary peer requires permissions", func(t *testing.T) {
|
||||
regularUser := types.NewRegularUser("regular-user-2", "", "")
|
||||
regularUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
|
||||
// Should fail because user doesn't have permissions for temporary peers
|
||||
err = manager.handleUserAddedPeer(context.Background(), account.Id, regularUser.Id, true, opEvent, config)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleSetupKeyAddedPeer(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create admin user for setup key creation
|
||||
adminUser := types.NewAdminUser("admin-user")
|
||||
adminUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), adminUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("valid setup key", func(t *testing.T) {
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, []string{}, 0, adminUser.Id, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
upperKey := strings.ToUpper(setupKey.Key)
|
||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
peer := &nbpeer.Peer{ExtraDNSLabels: []string{}}
|
||||
|
||||
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, setupKey.Id, config.SetupKeyID)
|
||||
assert.Equal(t, setupKey.Name, config.SetupKeyName)
|
||||
assert.Equal(t, setupKey.AutoGroups, config.GroupsToAdd)
|
||||
assert.Equal(t, setupKey.Ephemeral, config.Ephemeral)
|
||||
assert.Equal(t, setupKey.Id, opEvent.InitiatorID)
|
||||
assert.Equal(t, activity.PeerAddedWithSetupKey, opEvent.Activity)
|
||||
})
|
||||
|
||||
t.Run("invalid setup key", func(t *testing.T) {
|
||||
invalidKey := "invalid-key"
|
||||
hashedKey := sha256.Sum256([]byte(invalidKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
peer := &nbpeer.Peer{}
|
||||
|
||||
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "setup key is invalid")
|
||||
})
|
||||
|
||||
t.Run("expired setup key", func(t *testing.T) {
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "expired-key", types.SetupKeyReusable, time.Millisecond, []string{}, 0, adminUser.Id, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for key to expire
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
upperKey := strings.ToUpper(setupKey.Key)
|
||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
peer := &nbpeer.Peer{}
|
||||
|
||||
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "setup key is invalid")
|
||||
})
|
||||
|
||||
t.Run("extra DNS labels not allowed", func(t *testing.T) {
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "no-dns-key", types.SetupKeyReusable, time.Hour, []string{}, 0, adminUser.Id, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
upperKey := strings.ToUpper(setupKey.Key)
|
||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
peer := &nbpeer.Peer{ExtraDNSLabels: []string{"custom.label"}}
|
||||
|
||||
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "doesn't allow extra DNS labels")
|
||||
})
|
||||
|
||||
t.Run("extra DNS labels allowed", func(t *testing.T) {
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "dns-key", types.SetupKeyReusable, time.Hour, []string{}, 0, adminUser.Id, false, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
upperKey := strings.ToUpper(setupKey.Key)
|
||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
peer := &nbpeer.Peer{ExtraDNSLabels: []string{"custom.label"}}
|
||||
|
||||
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, config.AllowExtraDNSLabels)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProcessPeerAddAuth(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
adminUser := types.NewAdminUser("admin")
|
||||
adminUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), adminUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("user authentication flow", func(t *testing.T) {
|
||||
regularUser := types.NewRegularUser("user-auth-test", "", "")
|
||||
regularUser.AccountID = account.Id
|
||||
regularUser.AutoGroups = []string{"group1"}
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
opEvent := &activity.Event{Timestamp: time.Now()}
|
||||
peer := &nbpeer.Peer{Ephemeral: false}
|
||||
|
||||
config, err := manager.processPeerAddAuth(context.Background(), account.Id, regularUser.Id, "", peer, false, true, false, opEvent)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, account.Id, config.AccountID)
|
||||
assert.False(t, config.Ephemeral)
|
||||
assert.Equal(t, regularUser.AutoGroups, config.GroupsToAdd)
|
||||
assert.Equal(t, account.Id, opEvent.AccountID)
|
||||
})
|
||||
|
||||
t.Run("setup key authentication flow", func(t *testing.T) {
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "auth-test-key", types.SetupKeyReusable, time.Hour, []string{}, 0, adminUser.Id, true, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
upperKey := strings.ToUpper(setupKey.Key)
|
||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
opEvent := &activity.Event{Timestamp: time.Now()}
|
||||
peer := &nbpeer.Peer{Ephemeral: false}
|
||||
|
||||
config, err := manager.processPeerAddAuth(context.Background(), account.Id, "", encodedHashedKey, peer, false, false, true, opEvent)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, account.Id, config.AccountID)
|
||||
assert.True(t, config.Ephemeral) // setupKey.Ephemeral is true
|
||||
assert.Equal(t, setupKey.AutoGroups, config.GroupsToAdd)
|
||||
assert.Equal(t, account.Id, opEvent.AccountID)
|
||||
})
|
||||
|
||||
t.Run("temporary flag overrides ephemeral", func(t *testing.T) {
|
||||
regularUser := types.NewRegularUser("temp-user", "", "")
|
||||
regularUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
opEvent := &activity.Event{Timestamp: time.Now()}
|
||||
peer := &nbpeer.Peer{Ephemeral: false}
|
||||
|
||||
config, err := manager.processPeerAddAuth(context.Background(), account.Id, regularUser.Id, "", peer, true, true, false, opEvent)
|
||||
require.Error(t, err) // Will fail permission check but that's expected
|
||||
_ = config // avoid unused warning
|
||||
})
|
||||
|
||||
t.Run("proxy embedded peer (no auth)", func(t *testing.T) {
|
||||
opEvent := &activity.Event{Timestamp: time.Now()}
|
||||
peer := &nbpeer.Peer{
|
||||
Ephemeral: false,
|
||||
ProxyMeta: nbpeer.ProxyMeta{Embedded: true},
|
||||
}
|
||||
|
||||
config, err := manager.processPeerAddAuth(context.Background(), account.Id, "", "", peer, false, false, false, opEvent)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, account.Id, config.AccountID)
|
||||
assert.False(t, config.Ephemeral)
|
||||
assert.Empty(t, config.GroupsToAdd)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
package store
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package store -destination=store_mock.go -source=./store.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -374,6 +374,74 @@ func (a *Account) GetPeerNetworkMap(
|
||||
return nm
|
||||
}
|
||||
|
||||
// GetProxyConnectionResources returns ACL peers for the proxy-embedded peer based on exposed services.
|
||||
// No firewall rules are generated here; the proxy peer is always a new on-demand client with a stateful
|
||||
// firewall, so OUT rules are unnecessary. Inbound rules are handled on the target/router peer side.
|
||||
func (a *Account) GetProxyConnectionResources(ctx context.Context, exposedServices map[string][]*reverseproxy.Service) []*nbpeer.Peer {
|
||||
var aclPeers []*nbpeer.Peer
|
||||
|
||||
for _, peerServices := range exposedServices {
|
||||
for _, service := range peerServices {
|
||||
if !service.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, target := range service.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
}
|
||||
if target.TargetType == reverseproxy.TargetTypePeer {
|
||||
tpeer := a.GetPeer(target.TargetId)
|
||||
if tpeer == nil {
|
||||
continue
|
||||
}
|
||||
aclPeers = append(aclPeers, tpeer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return aclPeers
|
||||
}
|
||||
|
||||
// GetPeerProxyResources returns ACL peers and inbound firewall rules for a peer that is targeted by reverse proxy services.
|
||||
// Only IN rules are generated; OUT rules are omitted since proxy peers are always new clients with stateful firewalls.
|
||||
// Rules use PortRange only (not the legacy Port field) as this feature only targets current peer versions.
|
||||
func (a *Account) GetPeerProxyResources(peerID string, services []*reverseproxy.Service, proxyPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
var aclPeers []*nbpeer.Peer
|
||||
var firewallRules []*FirewallRule
|
||||
|
||||
for _, service := range services {
|
||||
if !service.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, target := range service.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
aclPeers = proxyPeers
|
||||
|
||||
needsPeerRules := (target.TargetType == reverseproxy.TargetTypePeer && target.TargetId == peerID) ||
|
||||
(target.TargetType == reverseproxy.TargetTypeHost || target.TargetType == reverseproxy.TargetTypeSubnet || target.TargetType == reverseproxy.TargetTypeDomain)
|
||||
|
||||
if needsPeerRules {
|
||||
for _, proxyPeer := range proxyPeers {
|
||||
firewallRules = append(firewallRules, &FirewallRule{
|
||||
PolicyID: "proxy-" + service.ID,
|
||||
PeerIP: proxyPeer.IP.String(),
|
||||
Direction: FirewallRuleDirectionIN,
|
||||
Action: "allow",
|
||||
Protocol: string(PolicyRuleProtocolTCP),
|
||||
PortRange: RulePortRange{Start: uint16(target.Port), End: uint16(target.Port)},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return aclPeers, firewallRules
|
||||
}
|
||||
|
||||
func (a *Account) addNetworksRoutingPeers(
|
||||
networkResourcesRoutes []*route.Route,
|
||||
peer *nbpeer.Peer,
|
||||
@@ -1796,6 +1864,71 @@ func (a *Account) GetProxyPeers() map[string][]*nbpeer.Peer {
|
||||
return proxyPeers
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerProxyRoutes(ctx context.Context, peer *nbpeer.Peer, proxies map[string][]*reverseproxy.Service, resourcesMap map[string]*resourceTypes.NetworkResource, routers map[string]map[string]*routerTypes.NetworkRouter, proxyPeers []*nbpeer.Peer) ([]*route.Route, []*RouteFirewallRule, []*nbpeer.Peer) {
|
||||
sourceRanges := make([]string, 0, len(proxyPeers))
|
||||
for _, proxyPeer := range proxyPeers {
|
||||
sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, proxyPeer.IP))
|
||||
}
|
||||
peers := make(map[string]*nbpeer.Peer, len(resourcesMap))
|
||||
|
||||
var routes []*route.Route
|
||||
var firewallRules []*RouteFirewallRule
|
||||
for _, proxyPerResource := range proxies {
|
||||
for _, proxy := range proxyPerResource {
|
||||
for _, target := range proxy.Targets {
|
||||
if target.TargetType == reverseproxy.TargetTypeHost || target.TargetType == reverseproxy.TargetTypeSubnet || target.TargetType == reverseproxy.TargetTypeDomain {
|
||||
resource, ok := resourcesMap[target.TargetId]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Warnf("proxy target %s not found in resources map", target.TargetId)
|
||||
continue
|
||||
}
|
||||
networkRouters, ok := routers[resource.NetworkID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Warnf("proxy target %s not found in routers map", target.TargetId)
|
||||
continue
|
||||
}
|
||||
for peerID, router := range networkRouters {
|
||||
routePeer := a.GetPeer(peerID)
|
||||
route := resource.ToRoute(routePeer, router)
|
||||
routes = append(routes, route)
|
||||
rule := RouteFirewallRule{
|
||||
PolicyID: fmt.Sprintf("proxy-%s-%s", proxy.ID, route.ID),
|
||||
RouteID: route.ID,
|
||||
SourceRanges: sourceRanges,
|
||||
Action: string(PolicyTrafficActionAccept),
|
||||
Destination: route.Network.String(),
|
||||
Protocol: string(PolicyRuleProtocolTCP),
|
||||
Domains: route.Domains,
|
||||
IsDynamic: route.IsDynamic(),
|
||||
PortRange: RulePortRange{
|
||||
Start: uint16(target.Port),
|
||||
End: uint16(target.Port),
|
||||
},
|
||||
}
|
||||
firewallRules = append(firewallRules, &rule)
|
||||
peers[peerID] = routePeer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resultPeers := make([]*nbpeer.Peer, 0, len(peers))
|
||||
for _, peer := range peers {
|
||||
resultPeers = append(resultPeers, peer)
|
||||
}
|
||||
|
||||
return routes, firewallRules, resultPeers
|
||||
}
|
||||
|
||||
func (a *Account) GetResourcesMap() map[string]*resourceTypes.NetworkResource {
|
||||
resourcesMap := make(map[string]*resourceTypes.NetworkResource, len(a.NetworkResources))
|
||||
for _, resource := range a.NetworkResources {
|
||||
resourcesMap[resource.ID] = resource
|
||||
}
|
||||
return resourcesMap
|
||||
}
|
||||
|
||||
func (a *Account) InjectProxyPolicies(ctx context.Context) {
|
||||
if len(a.Services) == 0 {
|
||||
return
|
||||
@@ -1810,83 +1943,62 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) {
|
||||
if !service.Enabled {
|
||||
continue
|
||||
}
|
||||
a.injectServiceProxyPolicies(ctx, service, proxyPeersByCluster)
|
||||
}
|
||||
}
|
||||
for _, target := range service.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
|
||||
for _, target := range service.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
}
|
||||
a.injectTargetProxyPolicies(ctx, service, target, proxyPeersByCluster[service.ProxyCluster])
|
||||
}
|
||||
}
|
||||
for _, proxyPeer := range proxyPeersByCluster[service.ProxyCluster] {
|
||||
port := target.Port
|
||||
if port == 0 {
|
||||
switch target.Protocol {
|
||||
case "https":
|
||||
port = 443
|
||||
case "http":
|
||||
port = 80
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("unsupported protocol %s for proxy target %s, skipping policy injection", target.Protocol, target.TargetId)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) {
|
||||
port, ok := a.resolveTargetPort(ctx, target)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
path := ""
|
||||
if target.Path != nil {
|
||||
path = *target.Path
|
||||
}
|
||||
|
||||
for _, proxyPeer := range proxyPeers {
|
||||
policy := a.createProxyPolicy(service, target, proxyPeer, port, path)
|
||||
a.Policies = append(a.Policies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) {
|
||||
if target.Port != 0 {
|
||||
return target.Port, true
|
||||
}
|
||||
|
||||
switch target.Protocol {
|
||||
case "https":
|
||||
return 443, true
|
||||
case "http":
|
||||
return 80, true
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("unsupported protocol %s for proxy target %s, skipping policy injection", target.Protocol, target.TargetId)
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
|
||||
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
|
||||
return &Policy{
|
||||
ID: policyID,
|
||||
Name: fmt.Sprintf("Proxy Access to %s", service.Name),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: policyID,
|
||||
PolicyID: policyID,
|
||||
Name: fmt.Sprintf("Allow access to %s", service.Name),
|
||||
Enabled: true,
|
||||
SourceResource: Resource{
|
||||
ID: proxyPeer.ID,
|
||||
Type: ResourceTypePeer,
|
||||
},
|
||||
DestinationResource: Resource{
|
||||
ID: target.TargetId,
|
||||
Type: ResourceType(target.TargetType),
|
||||
},
|
||||
Bidirectional: false,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
PortRanges: []RulePortRange{
|
||||
{
|
||||
Start: uint16(port),
|
||||
End: uint16(port),
|
||||
path := ""
|
||||
if target.Path != nil {
|
||||
path = *target.Path
|
||||
}
|
||||
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
|
||||
a.Policies = append(a.Policies, &Policy{
|
||||
ID: policyID,
|
||||
Name: fmt.Sprintf("Proxy Access to %s", service.Name),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: policyID,
|
||||
PolicyID: policyID,
|
||||
Name: fmt.Sprintf("Allow access to %s", service.Name),
|
||||
Enabled: true,
|
||||
SourceResource: Resource{
|
||||
ID: proxyPeer.ID,
|
||||
Type: ResourceTypePeer,
|
||||
},
|
||||
DestinationResource: Resource{
|
||||
ID: target.TargetId,
|
||||
Type: ResourceType(target.TargetType),
|
||||
},
|
||||
Bidirectional: false,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
PortRanges: []RulePortRange{
|
||||
{
|
||||
Start: uint16(port),
|
||||
End: uint16(port),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,19 +1,25 @@
|
||||
FROM golang:1.25-alpine AS builder
|
||||
WORKDIR /app
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o netbird-proxy ./proxy/cmd/proxy
|
||||
|
||||
RUN echo "netbird:x:1000:1000:netbird:/var/lib/netbird:/sbin/nologin" > /tmp/passwd && \
|
||||
echo "netbird:x:1000:netbird" > /tmp/group && \
|
||||
mkdir -p /tmp/var/lib/netbird && \
|
||||
mkdir -p /tmp/certs
|
||||
|
||||
FROM gcr.io/distroless/base:debug
|
||||
COPY netbird-proxy /go/bin/netbird-proxy
|
||||
COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy
|
||||
COPY --from=builder /tmp/passwd /etc/passwd
|
||||
COPY --from=builder /tmp/group /etc/group
|
||||
COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird
|
||||
COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs
|
||||
COPY --from=builder --chown=1000:1000 /tmp/certs /certs
|
||||
USER netbird:netbird
|
||||
ENV HOME=/var/lib/netbird
|
||||
ENV NB_PROXY_ADDRESS=":8443"
|
||||
EXPOSE 8443
|
||||
ENTRYPOINT ["/go/bin/netbird-proxy"]
|
||||
ENTRYPOINT ["/usr/bin/netbird-proxy"]
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
FROM golang:1.25-alpine AS builder
|
||||
WORKDIR /app
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY client ./client
|
||||
COPY dns ./dns
|
||||
COPY encryption ./encryption
|
||||
COPY flow ./flow
|
||||
COPY formatter ./formatter
|
||||
COPY monotime ./monotime
|
||||
COPY proxy ./proxy
|
||||
COPY route ./route
|
||||
COPY shared ./shared
|
||||
COPY sharedsock ./sharedsock
|
||||
COPY upload-server ./upload-server
|
||||
COPY util ./util
|
||||
COPY version ./version
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o netbird-proxy ./proxy/cmd/proxy
|
||||
|
||||
RUN echo "netbird:x:1000:1000:netbird:/var/lib/netbird:/sbin/nologin" > /tmp/passwd && \
|
||||
echo "netbird:x:1000:netbird" > /tmp/group && \
|
||||
mkdir -p /tmp/var/lib/netbird && \
|
||||
mkdir -p /tmp/certs
|
||||
|
||||
FROM gcr.io/distroless/base:debug
|
||||
COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy
|
||||
COPY --from=builder /tmp/passwd /etc/passwd
|
||||
COPY --from=builder /tmp/group /etc/group
|
||||
COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird
|
||||
COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs
|
||||
USER netbird:netbird
|
||||
ENV HOME=/var/lib/netbird
|
||||
ENV NB_PROXY_ADDRESS=":8443"
|
||||
EXPOSE 8443
|
||||
ENTRYPOINT ["/usr/bin/netbird-proxy"]
|
||||
@@ -39,10 +39,10 @@ var (
|
||||
addr string
|
||||
proxyDomain string
|
||||
certDir string
|
||||
acmeCerts bool
|
||||
acmeAddr string
|
||||
acmeDir string
|
||||
acmeChallengeType string
|
||||
acmeCerts bool
|
||||
acmeAddr string
|
||||
acmeDir string
|
||||
acmeChallengeType string
|
||||
debugEndpoint bool
|
||||
debugEndpointAddr string
|
||||
healthAddr string
|
||||
@@ -56,7 +56,6 @@ var (
|
||||
certKeyFile string
|
||||
certLockMethod string
|
||||
wgPort int
|
||||
proxyProtocol bool
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
@@ -91,7 +90,6 @@ func init() {
|
||||
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
|
||||
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
|
||||
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
|
||||
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
|
||||
}
|
||||
|
||||
// Execute runs the root command.
|
||||
@@ -125,7 +123,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
|
||||
_ = util.InitLogger(logger, level, util.LogConsole)
|
||||
|
||||
logger.Infof("configured log level: %s", level)
|
||||
log.Infof("configured log level: %s", level)
|
||||
|
||||
switch forwardedProto {
|
||||
case "auto", "http", "https":
|
||||
@@ -167,14 +165,13 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
TrustedProxies: parsedTrustedProxies,
|
||||
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
||||
WireguardPort: wgPort,
|
||||
ProxyProtocol: proxyProtocol,
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||
defer stop()
|
||||
|
||||
if err := srv.ListenAndServe(ctx, addr); err != nil {
|
||||
logger.Error(err)
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -84,7 +84,7 @@ func (l *flockLocker) Lock(ctx context.Context, domain string) (func(), error) {
|
||||
|
||||
// nil lockFile means locking is not supported (non-unix).
|
||||
if lockFile == nil {
|
||||
return func() { /* no-op: locking unsupported on this platform */ }, nil
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
return func() {
|
||||
@@ -98,5 +98,5 @@ type noopLocker struct{}
|
||||
|
||||
// Lock is a no-op that always succeeds immediately.
|
||||
func (noopLocker) Lock(context.Context, string) (func(), error) {
|
||||
return func() { /* no-op: locker disabled */ }, nil
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
@@ -90,8 +90,10 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
if err != nil {
|
||||
host = r.Host
|
||||
}
|
||||
mw.domainsMux.RLock()
|
||||
config, exists := mw.domains[host]
|
||||
mw.domainsMux.RUnlock()
|
||||
|
||||
config, exists := mw.getDomainConfig(host)
|
||||
mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)
|
||||
|
||||
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
|
||||
@@ -101,160 +103,115 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// Set account and service IDs in captured data for access logging.
|
||||
setCapturedIDs(r, config)
|
||||
|
||||
if mw.handleOAuthCallbackError(w, r) {
|
||||
return
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetAccountId(types.AccountID(config.AccountID))
|
||||
cd.SetServiceId(config.ServiceID)
|
||||
}
|
||||
|
||||
if mw.forwardWithSessionCookie(w, r, host, config, next) {
|
||||
return
|
||||
}
|
||||
|
||||
mw.authenticateWithSchemes(w, r, host, config)
|
||||
})
|
||||
}
|
||||
|
||||
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
|
||||
mw.domainsMux.RLock()
|
||||
defer mw.domainsMux.RUnlock()
|
||||
config, exists := mw.domains[host]
|
||||
return config, exists
|
||||
}
|
||||
|
||||
func setCapturedIDs(r *http.Request, config DomainConfig) {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetAccountId(types.AccountID(config.AccountID))
|
||||
cd.SetServiceId(config.ServiceID)
|
||||
}
|
||||
}
|
||||
|
||||
// handleOAuthCallbackError checks for error query parameters from an OAuth
|
||||
// callback and renders the access denied page if present.
|
||||
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
|
||||
errCode := r.URL.Query().Get("error")
|
||||
if errCode == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(auth.MethodOIDC.String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
errDesc := r.URL.Query().Get("error_description")
|
||||
if errDesc == "" {
|
||||
errDesc = "An error occurred during authentication"
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardWithSessionCookie checks for a valid session cookie and, if found,
|
||||
// sets the user identity on the request context and forwards to the next handler.
|
||||
func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
|
||||
cookie, err := r.Cookie(auth.SessionCookieName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(userID)
|
||||
cd.SetAuthMethod(method)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
// authenticateWithSchemes tries each configured auth scheme in order.
|
||||
// On success it sets a session cookie and redirects; on failure it renders the login page.
|
||||
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
|
||||
methods := make(map[string]string)
|
||||
var attemptedMethod string
|
||||
|
||||
for _, scheme := range config.Schemes {
|
||||
token, promptData, err := scheme.Authenticate(r)
|
||||
if err != nil {
|
||||
mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
|
||||
// Check for error from OAuth callback (e.g., access denied)
|
||||
if errCode := r.URL.Query().Get("error"); errCode != "" {
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(auth.MethodOIDC.String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
|
||||
errDesc := r.URL.Query().Get("error_description")
|
||||
if errDesc == "" {
|
||||
errDesc = "An error occurred during authentication"
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
|
||||
return
|
||||
}
|
||||
|
||||
// Track if credentials were submitted but auth failed
|
||||
if token == "" && wasCredentialSubmitted(r, scheme.Type()) {
|
||||
attemptedMethod = scheme.Type().String()
|
||||
// Check for an existing session cookie (contains JWT)
|
||||
if cookie, err := r.Cookie(auth.SessionCookieName); err == nil {
|
||||
if userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey); err == nil {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(userID)
|
||||
cd.SetAuthMethod(method)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
mw.handleAuthenticatedToken(w, r, host, token, config, scheme)
|
||||
return
|
||||
}
|
||||
methods[scheme.Type().String()] = promptData
|
||||
}
|
||||
// Try to authenticate with each scheme.
|
||||
methods := make(map[string]string)
|
||||
var attemptedMethod string
|
||||
for _, scheme := range config.Schemes {
|
||||
token, promptData, err := scheme.Authenticate(r)
|
||||
if err != nil {
|
||||
mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
}
|
||||
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
if attemptedMethod != "" {
|
||||
cd.SetAuthMethod(attemptedMethod)
|
||||
}
|
||||
}
|
||||
web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized)
|
||||
}
|
||||
// Track if credentials were submitted but auth failed
|
||||
if token == "" && wasCredentialSubmitted(r, scheme.Type()) {
|
||||
attemptedMethod = scheme.Type().String()
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
|
||||
if err != nil {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !result.Valid {
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
|
||||
return
|
||||
}
|
||||
|
||||
expiration := config.SessionExpiration
|
||||
if expiration == 0 {
|
||||
expiration = auth.DefaultSessionExpiry
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.SessionCookieName,
|
||||
Value: token,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(expiration.Seconds()),
|
||||
})
|
||||
|
||||
// Redirect instead of forwarding the auth POST to the backend.
|
||||
// The browser will follow with a GET carrying the new session cookie.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
methods[scheme.Type().String()] = promptData
|
||||
}
|
||||
|
||||
// handleAuthenticatedToken validates the token, handles denied access, and on
|
||||
// success sets a session cookie and redirects to the original URL.
|
||||
func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Request, host, token string, config DomainConfig, scheme Scheme) {
|
||||
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
|
||||
if err != nil {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
if attemptedMethod != "" {
|
||||
cd.SetAuthMethod(attemptedMethod)
|
||||
}
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
|
||||
return
|
||||
}
|
||||
|
||||
expiration := config.SessionExpiration
|
||||
if expiration == 0 {
|
||||
expiration = auth.DefaultSessionExpiry
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.SessionCookieName,
|
||||
Value: token,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(expiration.Seconds()),
|
||||
web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
// Redirect instead of forwarding the auth POST to the backend.
|
||||
// The browser will follow with a GET carrying the new session cookie.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// wasCredentialSubmitted checks if credentials were submitted for the given auth method.
|
||||
|
||||
@@ -83,10 +83,6 @@ func (c *Client) printHealth(data map[string]any) {
|
||||
}
|
||||
}
|
||||
|
||||
c.printHealthClients(data)
|
||||
}
|
||||
|
||||
func (c *Client) printHealthClients(data map[string]any) {
|
||||
clients, ok := data["clients"].(map[string]any)
|
||||
if !ok || len(clients) == 0 {
|
||||
return
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPrintHealth_WithCertsAndClients(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
c := NewClient("localhost:8444", false, &buf)
|
||||
|
||||
data := map[string]any{
|
||||
"status": "ok",
|
||||
"uptime": "1h30m",
|
||||
"management_connected": true,
|
||||
"all_clients_healthy": true,
|
||||
"certs_total": float64(3),
|
||||
"certs_ready": float64(2),
|
||||
"certs_pending": float64(1),
|
||||
"certs_failed": float64(0),
|
||||
"certs_ready_domains": []any{"a.example.com", "b.example.com"},
|
||||
"certs_pending_domains": []any{"c.example.com"},
|
||||
"clients": map[string]any{
|
||||
"acc-1": map[string]any{
|
||||
"healthy": true,
|
||||
"management_connected": true,
|
||||
"signal_connected": true,
|
||||
"relays_connected": float64(1),
|
||||
"relays_total": float64(2),
|
||||
"peers_connected": float64(3),
|
||||
"peers_total": float64(5),
|
||||
"peers_p2p": float64(2),
|
||||
"peers_relayed": float64(1),
|
||||
"peers_degraded": float64(0),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
c.printHealth(data)
|
||||
out := buf.String()
|
||||
|
||||
assert.Contains(t, out, "Status: ok")
|
||||
assert.Contains(t, out, "Uptime: 1h30m")
|
||||
assert.Contains(t, out, "yes") // management_connected
|
||||
assert.Contains(t, out, "2 ready, 1 pending, 0 failed (3 total)")
|
||||
assert.Contains(t, out, "a.example.com")
|
||||
assert.Contains(t, out, "c.example.com")
|
||||
assert.Contains(t, out, "acc-1")
|
||||
}
|
||||
|
||||
func TestPrintHealth_Minimal(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
c := NewClient("localhost:8444", false, &buf)
|
||||
|
||||
data := map[string]any{
|
||||
"status": "ok",
|
||||
"uptime": "5m",
|
||||
"management_connected": false,
|
||||
"all_clients_healthy": false,
|
||||
}
|
||||
|
||||
c.printHealth(data)
|
||||
out := buf.String()
|
||||
|
||||
assert.Contains(t, out, "Status: ok")
|
||||
assert.Contains(t, out, "Uptime: 5m")
|
||||
assert.NotContains(t, out, "Certificates")
|
||||
assert.NotContains(t, out, "ACCOUNT ID")
|
||||
}
|
||||
@@ -17,11 +17,11 @@
|
||||
<h2>Client Control</h2>
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<span> </span>
|
||||
<label> </label>
|
||||
<button onclick="startClient()">Start</button>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<span> </span>
|
||||
<label> </label>
|
||||
<button onclick="stopClient()">Stop</button>
|
||||
</div>
|
||||
</div>
|
||||
@@ -30,7 +30,7 @@
|
||||
<h2>Log Level</h2>
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<label for="log-level">Level</label>
|
||||
<label>Level</label>
|
||||
<select id="log-level" style="width: 120px;">
|
||||
<option value="trace">trace</option>
|
||||
<option value="debug">debug</option>
|
||||
@@ -40,7 +40,7 @@
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<span> </span>
|
||||
<label> </label>
|
||||
<button onclick="setLogLevel()">Set Level</button>
|
||||
</div>
|
||||
</div>
|
||||
@@ -49,15 +49,15 @@
|
||||
<h2>TCP Ping</h2>
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<label for="tcp-host">Host</label>
|
||||
<label>Host</label>
|
||||
<input type="text" id="tcp-host" placeholder="100.0.0.1 or hostname.netbird.cloud" style="width: 300px;">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="tcp-port">Port</label>
|
||||
<label>Port</label>
|
||||
<input type="number" id="tcp-port" placeholder="80" style="width: 80px;">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<span> </span>
|
||||
<label> </label>
|
||||
<button onclick="doTcpPing()">Connect</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -323,7 +323,7 @@ func NewServer(addr string, checker *Checker, logger *log.Logger, metricsHandler
|
||||
if metricsHandler != nil {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/metrics", metricsHandler)
|
||||
mux.Handle("/", handler)
|
||||
mux.Handle("/", checker.Handler())
|
||||
handler = mux
|
||||
}
|
||||
|
||||
|
||||
@@ -404,70 +404,3 @@ func TestChecker_Handler_Full(t *testing.T) {
|
||||
// Clients may be empty map when no clients exist.
|
||||
assert.Empty(t, resp.Clients)
|
||||
}
|
||||
|
||||
func TestChecker_SetShuttingDown(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
|
||||
assert.True(t, checker.ReadinessProbe(), "should be ready before shutdown")
|
||||
|
||||
checker.SetShuttingDown()
|
||||
|
||||
assert.False(t, checker.ReadinessProbe(), "should not be ready after shutdown")
|
||||
}
|
||||
|
||||
func TestChecker_Handler_Readiness_ShuttingDown(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetShuttingDown()
|
||||
handler := checker.Handler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
|
||||
var resp ProbeResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
assert.Equal(t, "fail", resp.Status)
|
||||
}
|
||||
|
||||
func TestNewServer_WithMetricsHandler(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
|
||||
metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("metrics"))
|
||||
})
|
||||
|
||||
srv := NewServer(":0", checker, nil, metricsHandler)
|
||||
require.NotNil(t, srv)
|
||||
|
||||
// Verify health endpoint still works through the mux.
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/live", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.server.Handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Verify metrics endpoint is mounted.
|
||||
req = httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
srv.server.Handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "metrics", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestNewServer_WithoutMetricsHandler(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
|
||||
srv := NewServer(":0", checker, nil, nil)
|
||||
require.NotNil(t, srv)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/live", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.server.Handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
@@ -320,8 +320,7 @@ func getRequestID(r *http.Request) string {
|
||||
// status code, and component status based on the error type.
|
||||
func classifyProxyError(err error) (title, message string, code int, status web.ErrorStatus) {
|
||||
switch {
|
||||
case errors.Is(err, context.DeadlineExceeded),
|
||||
isNetTimeout(err):
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
return "Request Timeout",
|
||||
"The request timed out while trying to reach the service. Please refresh the page and try again.",
|
||||
http.StatusGatewayTimeout,
|
||||
@@ -346,12 +345,6 @@ func classifyProxyError(err error) (title, message string, code int, status web.
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: false, Destination: false}
|
||||
|
||||
case errors.Is(err, roundtrip.ErrTooManyInflight):
|
||||
return "Service Overloaded",
|
||||
"The service is currently handling too many requests. Please try again shortly.",
|
||||
http.StatusServiceUnavailable,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
|
||||
case isConnectionRefused(err):
|
||||
return "Service Unavailable",
|
||||
"The connection to the service was refused. Please verify that the service is running and try again.",
|
||||
@@ -363,6 +356,12 @@ func classifyProxyError(err error) (title, message string, code int, status web.
|
||||
"The connection to the peer could not be established. Please ensure the peer is running and connected to the NetBird network.",
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
|
||||
case isNetTimeout(err):
|
||||
return "Request Timeout",
|
||||
"The request timed out while trying to reach the service. Please refresh the page and try again.",
|
||||
http.StatusGatewayTimeout,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
}
|
||||
|
||||
return "Connection Error",
|
||||
|
||||
@@ -24,9 +24,6 @@ import (
|
||||
|
||||
const deviceNamePrefix = "ingress-proxy-"
|
||||
|
||||
// backendKey identifies a backend by its host:port from the target URL.
|
||||
type backendKey = string
|
||||
|
||||
var (
|
||||
// ErrNoAccountID is returned when a request context is missing the account ID.
|
||||
ErrNoAccountID = errors.New("no account ID in request context")
|
||||
@@ -34,8 +31,6 @@ var (
|
||||
ErrNoPeerConnection = errors.New("no peer connection found")
|
||||
// ErrClientStartFailed is returned when the embedded client fails to start.
|
||||
ErrClientStartFailed = errors.New("client start failed")
|
||||
// ErrTooManyInflight is returned when the per-backend in-flight limit is reached.
|
||||
ErrTooManyInflight = errors.New("too many in-flight requests")
|
||||
)
|
||||
|
||||
// domainInfo holds metadata about a registered domain.
|
||||
@@ -43,11 +38,6 @@ type domainInfo struct {
|
||||
serviceID string
|
||||
}
|
||||
|
||||
type domainNotification struct {
|
||||
domain domain.Domain
|
||||
serviceID string
|
||||
}
|
||||
|
||||
// clientEntry holds an embedded NetBird client and tracks which domains use it.
|
||||
type clientEntry struct {
|
||||
client *embed.Client
|
||||
@@ -55,35 +45,6 @@ type clientEntry struct {
|
||||
domains map[domain.Domain]domainInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
// Per-backend in-flight limiting keyed by target host:port.
|
||||
// TODO: clean up stale entries when backend targets change.
|
||||
inflightMu sync.Mutex
|
||||
inflightMap map[backendKey]chan struct{}
|
||||
maxInflight int
|
||||
}
|
||||
|
||||
// acquireInflight attempts to acquire an in-flight slot for the given backend.
|
||||
// It returns a release function that must always be called, and true on success.
|
||||
func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bool) {
|
||||
noop := func() {}
|
||||
if e.maxInflight <= 0 {
|
||||
return noop, true
|
||||
}
|
||||
|
||||
e.inflightMu.Lock()
|
||||
sem, exists := e.inflightMap[backend]
|
||||
if !exists {
|
||||
sem = make(chan struct{}, e.maxInflight)
|
||||
e.inflightMap[backend] = sem
|
||||
}
|
||||
e.inflightMu.Unlock()
|
||||
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
return func() { <-sem }, true
|
||||
default:
|
||||
return noop, false
|
||||
}
|
||||
}
|
||||
|
||||
type statusNotifier interface {
|
||||
@@ -98,13 +59,12 @@ type managementClient interface {
|
||||
// backed by underlying NetBird connections.
|
||||
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
|
||||
type NetBird struct {
|
||||
mgmtAddr string
|
||||
proxyID string
|
||||
proxyAddr string
|
||||
wgPort int
|
||||
logger *log.Logger
|
||||
mgmtClient managementClient
|
||||
transportCfg transportConfig
|
||||
mgmtAddr string
|
||||
proxyID string
|
||||
proxyAddr string
|
||||
wgPort int
|
||||
logger *log.Logger
|
||||
mgmtClient managementClient
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[types.AccountID]*clientEntry
|
||||
@@ -154,30 +114,6 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
|
||||
return nil
|
||||
}
|
||||
|
||||
entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
// retry on the first request via RoundTrip.
|
||||
go n.runClientStartup(ctx, accountID, entry.client)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": serviceID,
|
||||
@@ -185,7 +121,8 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
|
||||
privateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate wireguard private key: %w", err)
|
||||
n.clientsMux.Unlock()
|
||||
return fmt.Errorf("generate wireguard private key: %w", err)
|
||||
}
|
||||
publicKey := privateKey.PublicKey()
|
||||
|
||||
@@ -195,6 +132,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
"public_key": publicKey.String(),
|
||||
}).Debug("authenticating new proxy peer with management")
|
||||
|
||||
// Authenticate with management using the one-time token and send public key
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
|
||||
ServiceId: serviceID,
|
||||
AccountId: string(accountID),
|
||||
@@ -203,14 +141,16 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
Cluster: n.proxyAddr,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate proxy peer with management: %w", err)
|
||||
n.clientsMux.Unlock()
|
||||
return fmt.Errorf("authenticate proxy peer with management: %w", err)
|
||||
}
|
||||
if resp != nil && !resp.GetSuccess() {
|
||||
n.clientsMux.Unlock()
|
||||
errMsg := "unknown error"
|
||||
if resp.ErrorMessage != nil {
|
||||
errMsg = *resp.ErrorMessage
|
||||
}
|
||||
return nil, fmt.Errorf("proxy peer authentication failed: %s", errMsg)
|
||||
return fmt.Errorf("proxy peer authentication failed: %s", errMsg)
|
||||
}
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -236,80 +176,95 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
WireguardPort: &n.wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create netbird client: %w", err)
|
||||
n.clientsMux.Unlock()
|
||||
return fmt.Errorf("create netbird client: %w", err)
|
||||
}
|
||||
|
||||
// Create a transport using the client dialer. We do this instead of using
|
||||
// the client's HTTPClient to avoid issues with request validation that do
|
||||
// not work with reverse proxied requests.
|
||||
return &clientEntry{
|
||||
entry = &clientEntry{
|
||||
client: client,
|
||||
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
|
||||
transport: &http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: n.transportCfg.maxIdleConns,
|
||||
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
|
||||
IdleConnTimeout: n.transportCfg.idleConnTimeout,
|
||||
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
|
||||
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
|
||||
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
|
||||
WriteBufferSize: n.transportCfg.writeBufferSize,
|
||||
ReadBufferSize: n.transportCfg.readBufferSize,
|
||||
DisableCompression: n.transportCfg.disableCompression,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
started: false,
|
||||
inflightMap: make(map[backendKey]chan struct{}),
|
||||
maxInflight: n.transportCfg.maxInflight,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// runClientStartup starts the client and notifies registered domains on success.
|
||||
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
n.logger.WithField("account_id", accountID).Warn("netbird client start timed out, will retry on first request")
|
||||
} else {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Error("failed to start netbird client")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Mark client as started and collect domains to notify outside the lock.
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.started = true
|
||||
}
|
||||
var domainsToNotify []domainNotification
|
||||
if exists {
|
||||
for dom, info := range entry.domains {
|
||||
domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
|
||||
}
|
||||
createdAt: time.Now(),
|
||||
started: false,
|
||||
}
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
}
|
||||
for _, dn := range domainsToNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": dn.domain,
|
||||
}).WithError(err).Warn("failed to notify tunnel connection status")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": dn.domain,
|
||||
}).Info("notified management about tunnel connection")
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background, if this fails
|
||||
// then it is not ideal, but it isn't the end of the world because
|
||||
// we will try to start the client again before we use it.
|
||||
go func() {
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).Warn("netbird client start timed out, will retry on first request")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Error("failed to start netbird client")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Mark client as started and notify all registered domains
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.started = true
|
||||
}
|
||||
// Copy domain info while holding lock
|
||||
var domainsToNotify []struct {
|
||||
domain domain.Domain
|
||||
serviceID string
|
||||
}
|
||||
if exists {
|
||||
for dom, info := range entry.domains {
|
||||
domainsToNotify = append(domainsToNotify, struct {
|
||||
domain domain.Domain
|
||||
serviceID string
|
||||
}{domain: dom, serviceID: info.serviceID})
|
||||
}
|
||||
}
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// Notify all domains that they're connected
|
||||
if n.statusNotifier != nil {
|
||||
for _, domInfo := range domainsToNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(domInfo.domain), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": domInfo.domain,
|
||||
}).WithError(err).Warn("failed to notify tunnel connection status")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": domInfo.domain,
|
||||
}).Info("notified management about tunnel connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePeer unregisters a domain from an account. The client is only stopped
|
||||
@@ -410,12 +365,6 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
transport := entry.transport
|
||||
n.clientsMux.RUnlock()
|
||||
|
||||
release, ok := entry.acquireInflight(req.URL.Host)
|
||||
defer release()
|
||||
if !ok {
|
||||
return nil, ErrTooManyInflight
|
||||
}
|
||||
|
||||
// Attempt to start the client, if the client is already running then
|
||||
// it will return an error that we ignore, if this hits a timeout then
|
||||
// this request is unprocessable.
|
||||
@@ -552,7 +501,6 @@ func NewNetBird(mgmtAddr, proxyID, proxyAddr string, wgPort int, logger *log.Log
|
||||
clients: make(map[types.AccountID]*clientEntry),
|
||||
statusNotifier: notifier,
|
||||
mgmtClient: mgmtClient,
|
||||
transportCfg: loadTransportConfig(logger),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package roundtrip
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -21,31 +20,6 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
type mockStatusNotifier struct {
|
||||
mu sync.Mutex
|
||||
statuses []statusCall
|
||||
}
|
||||
|
||||
type statusCall struct {
|
||||
accountID string
|
||||
serviceID string
|
||||
domain string
|
||||
connected bool
|
||||
}
|
||||
|
||||
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStatusNotifier) calls() []statusCall {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return append([]statusCall{}, m.statuses...)
|
||||
}
|
||||
|
||||
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||
// It uses an invalid management URL to prevent real connections.
|
||||
func mockNetBird() *NetBird {
|
||||
@@ -279,50 +253,3 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer connection found for account")
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
notifier := &mockStatusNotifier{}
|
||||
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first domain — creates a new client entry.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually mark client as started to simulate background startup completing.
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID].started = true
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
// Add second domain — should notify immediately since client is already started.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
|
||||
require.NoError(t, err)
|
||||
|
||||
calls := notifier.calls()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, string(accountID), calls[0].accountID)
|
||||
assert.Equal(t, "svc-2", calls[0].serviceID)
|
||||
assert.Equal(t, "domain2.test", calls[0].domain)
|
||||
assert.True(t, calls[0].connected)
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
notifier := &mockStatusNotifier{}
|
||||
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove one domain — client stays, but disconnection notification fires.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
|
||||
calls := notifier.calls()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "domain1.test", calls[0].domain)
|
||||
assert.False(t, calls[0].connected)
|
||||
}
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Environment variable names for tuning the backend HTTP transport.
|
||||
const (
|
||||
EnvMaxIdleConns = "NB_PROXY_MAX_IDLE_CONNS"
|
||||
EnvMaxIdleConnsPerHost = "NB_PROXY_MAX_IDLE_CONNS_PER_HOST"
|
||||
EnvMaxConnsPerHost = "NB_PROXY_MAX_CONNS_PER_HOST"
|
||||
EnvIdleConnTimeout = "NB_PROXY_IDLE_CONN_TIMEOUT"
|
||||
EnvTLSHandshakeTimeout = "NB_PROXY_TLS_HANDSHAKE_TIMEOUT"
|
||||
EnvExpectContinueTimeout = "NB_PROXY_EXPECT_CONTINUE_TIMEOUT"
|
||||
EnvResponseHeaderTimeout = "NB_PROXY_RESPONSE_HEADER_TIMEOUT"
|
||||
EnvWriteBufferSize = "NB_PROXY_WRITE_BUFFER_SIZE"
|
||||
EnvReadBufferSize = "NB_PROXY_READ_BUFFER_SIZE"
|
||||
EnvDisableCompression = "NB_PROXY_DISABLE_COMPRESSION"
|
||||
EnvMaxInflight = "NB_PROXY_MAX_INFLIGHT"
|
||||
)
|
||||
|
||||
// transportConfig holds tunable parameters for the per-account HTTP transport.
|
||||
type transportConfig struct {
|
||||
maxIdleConns int
|
||||
maxIdleConnsPerHost int
|
||||
maxConnsPerHost int
|
||||
idleConnTimeout time.Duration
|
||||
tlsHandshakeTimeout time.Duration
|
||||
expectContinueTimeout time.Duration
|
||||
responseHeaderTimeout time.Duration
|
||||
writeBufferSize int
|
||||
readBufferSize int
|
||||
disableCompression bool
|
||||
// maxInflight limits per-backend concurrent requests. 0 means unlimited.
|
||||
maxInflight int
|
||||
}
|
||||
|
||||
func defaultTransportConfig() transportConfig {
|
||||
return transportConfig{
|
||||
maxIdleConns: 100,
|
||||
maxIdleConnsPerHost: 100,
|
||||
maxConnsPerHost: 0, // unlimited
|
||||
idleConnTimeout: 90 * time.Second,
|
||||
tlsHandshakeTimeout: 10 * time.Second,
|
||||
expectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func loadTransportConfig(logger *log.Logger) transportConfig {
|
||||
cfg := defaultTransportConfig()
|
||||
|
||||
if v, ok := envInt(EnvMaxIdleConns, logger); ok {
|
||||
cfg.maxIdleConns = v
|
||||
}
|
||||
if v, ok := envInt(EnvMaxIdleConnsPerHost, logger); ok {
|
||||
cfg.maxIdleConnsPerHost = v
|
||||
}
|
||||
if v, ok := envInt(EnvMaxConnsPerHost, logger); ok {
|
||||
cfg.maxConnsPerHost = v
|
||||
}
|
||||
if v, ok := envDuration(EnvIdleConnTimeout, logger); ok {
|
||||
cfg.idleConnTimeout = v
|
||||
}
|
||||
if v, ok := envDuration(EnvTLSHandshakeTimeout, logger); ok {
|
||||
cfg.tlsHandshakeTimeout = v
|
||||
}
|
||||
if v, ok := envDuration(EnvExpectContinueTimeout, logger); ok {
|
||||
cfg.expectContinueTimeout = v
|
||||
}
|
||||
if v, ok := envDuration(EnvResponseHeaderTimeout, logger); ok {
|
||||
cfg.responseHeaderTimeout = v
|
||||
}
|
||||
if v, ok := envInt(EnvWriteBufferSize, logger); ok {
|
||||
cfg.writeBufferSize = v
|
||||
}
|
||||
if v, ok := envInt(EnvReadBufferSize, logger); ok {
|
||||
cfg.readBufferSize = v
|
||||
}
|
||||
if v, ok := envBool(EnvDisableCompression, logger); ok {
|
||||
cfg.disableCompression = v
|
||||
}
|
||||
if v, ok := envInt(EnvMaxInflight, logger); ok {
|
||||
cfg.maxInflight = v
|
||||
}
|
||||
|
||||
logger.WithFields(log.Fields{
|
||||
"max_idle_conns": cfg.maxIdleConns,
|
||||
"max_idle_conns_per_host": cfg.maxIdleConnsPerHost,
|
||||
"max_conns_per_host": cfg.maxConnsPerHost,
|
||||
"idle_conn_timeout": cfg.idleConnTimeout,
|
||||
"tls_handshake_timeout": cfg.tlsHandshakeTimeout,
|
||||
"expect_continue_timeout": cfg.expectContinueTimeout,
|
||||
"response_header_timeout": cfg.responseHeaderTimeout,
|
||||
"write_buffer_size": cfg.writeBufferSize,
|
||||
"read_buffer_size": cfg.readBufferSize,
|
||||
"disable_compression": cfg.disableCompression,
|
||||
"max_inflight": cfg.maxInflight,
|
||||
}).Debug("backend transport configuration")
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func envInt(key string, logger *log.Logger) (int, bool) {
|
||||
s := os.Getenv(key)
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
v, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
logger.Warnf("failed to parse %s=%q as int: %v", key, s, err)
|
||||
return 0, false
|
||||
}
|
||||
if v < 0 {
|
||||
logger.Warnf("ignoring negative value for %s=%d", key, v)
|
||||
return 0, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
func envDuration(key string, logger *log.Logger) (time.Duration, bool) {
|
||||
s := os.Getenv(key)
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
v, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
logger.Warnf("failed to parse %s=%q as duration: %v", key, s, err)
|
||||
return 0, false
|
||||
}
|
||||
if v < 0 {
|
||||
logger.Warnf("ignoring negative value for %s=%s", key, v)
|
||||
return 0, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
func envBool(key string, logger *log.Logger) (bool, bool) {
|
||||
s := os.Getenv(key)
|
||||
if s == "" {
|
||||
return false, false
|
||||
}
|
||||
v, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
logger.Warnf("failed to parse %s=%q as bool: %v", key, s, err)
|
||||
return false, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
21
proxy/log.go
21
proxy/log.go
@@ -1,21 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
stdlog "log"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// HTTP server type identifiers for logging
|
||||
logtagFieldHTTPServer = "http-server"
|
||||
logtagValueHTTPS = "https"
|
||||
logtagValueACME = "acme"
|
||||
logtagValueDebug = "debug"
|
||||
)
|
||||
|
||||
// newHTTPServerLogger creates a standard library logger that writes to logrus
|
||||
// with the specified server type field.
|
||||
func newHTTPServerLogger(logger *log.Logger, serverType string) *stdlog.Logger {
|
||||
return stdlog.New(logger.WithField(logtagFieldHTTPServer, serverType).WriterLevel(log.WarnLevel), "", 0)
|
||||
}
|
||||
@@ -1,106 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
proxyproto "github.com/pires/go-proxyproto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWrapProxyProtocol_OverridesRemoteAddr(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")},
|
||||
ProxyProtocol: true,
|
||||
}
|
||||
|
||||
raw, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer raw.Close()
|
||||
|
||||
ln := srv.wrapProxyProtocol(raw)
|
||||
|
||||
realClientIP := "203.0.113.50"
|
||||
realClientPort := uint16(54321)
|
||||
|
||||
accepted := make(chan net.Conn, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
accepted <- conn
|
||||
}()
|
||||
|
||||
// Connect and send a PROXY v2 header.
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
header := &proxyproto.Header{
|
||||
Version: 2,
|
||||
Command: proxyproto.PROXY,
|
||||
TransportProtocol: proxyproto.TCPv4,
|
||||
SourceAddr: &net.TCPAddr{IP: net.ParseIP(realClientIP), Port: int(realClientPort)},
|
||||
DestinationAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443},
|
||||
}
|
||||
_, err = header.WriteTo(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case accepted := <-accepted:
|
||||
defer accepted.Close()
|
||||
host, _, err := net.SplitHostPort(accepted.RemoteAddr().String())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, realClientIP, host, "RemoteAddr should reflect the PROXY header source IP")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_TrustedRequires(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 1234},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.REQUIRE, policy, "trusted source should require PROXY header")
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_UntrustedIgnores(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.TCPAddr{IP: net.ParseIP("203.0.113.50"), Port: 1234},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.IGNORE, policy, "untrusted source should have PROXY header ignored")
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_InvalidIPRejects(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.REJECT, policy, "unparsable address should be rejected")
|
||||
}
|
||||
446
proxy/server.go
446
proxy/server.go
@@ -23,7 +23,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
proxyproto "github.com/pires/go-proxyproto"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -93,7 +92,7 @@ type Server struct {
|
||||
DebugEndpointEnabled bool
|
||||
// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
|
||||
DebugEndpointAddress string
|
||||
// HealthAddress is the address for the health probe endpoint.
|
||||
// HealthAddress is the address for the health probe endpoint (default: "localhost:8080").
|
||||
HealthAddress string
|
||||
// ProxyToken is the access token for authenticating with the management server.
|
||||
ProxyToken string
|
||||
@@ -108,10 +107,6 @@ type Server struct {
|
||||
// random OS-assigned port. A fixed port only works with single-account
|
||||
// deployments; multiple accounts will fail to bind the same port.
|
||||
WireguardPort int
|
||||
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
|
||||
// When enabled, the real client IP is extracted from the PROXY header
|
||||
// sent by upstream L4 proxies that support PROXY protocol.
|
||||
ProxyProtocol bool
|
||||
}
|
||||
|
||||
// NotifyStatus sends a status update to management about tunnel connectivity
|
||||
@@ -142,87 +137,6 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service
|
||||
}
|
||||
|
||||
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.initDefaults()
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
s.meter = metrics.New(reg)
|
||||
|
||||
mgmtConn, err := s.dialManagement()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := mgmtConn.Close(); err != nil {
|
||||
s.Logger.Debugf("management connection close: %v", err)
|
||||
}
|
||||
}()
|
||||
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
|
||||
go s.newManagementMappingWorker(ctx, s.mgmtClient)
|
||||
|
||||
// Initialize the netbird client, this is required to build peer connections
|
||||
// to proxy over.
|
||||
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient)
|
||||
|
||||
tlsConfig, err := s.configureTLS(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
|
||||
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
|
||||
|
||||
// Configure the authentication middleware with session validator for OIDC group checks.
|
||||
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient)
|
||||
|
||||
// Configure Access logs to management server.
|
||||
accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
|
||||
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
s.startDebugEndpoint()
|
||||
|
||||
if err := s.startHealthServer(reg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start the reverse proxy HTTPS server.
|
||||
s.https = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: s.meter.Middleware(accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy)))),
|
||||
TLSConfig: tlsConfig,
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
|
||||
}
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", addr, err)
|
||||
}
|
||||
if s.ProxyProtocol {
|
||||
ln = s.wrapProxyProtocol(ln)
|
||||
}
|
||||
|
||||
httpsErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debugf("starting reverse proxy server on %s", addr)
|
||||
httpsErr <- s.https.ServeTLS(ln, "", "")
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-httpsErr:
|
||||
s.shutdownServices()
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("https server: %w", err)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
s.gracefulShutdown()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// initDefaults sets fallback values for optional Server fields.
|
||||
func (s *Server) initDefaults() {
|
||||
s.startTime = time.Now()
|
||||
|
||||
// If no ID is set then one can be generated.
|
||||
@@ -238,36 +152,141 @@ func (s *Server) initDefaults() {
|
||||
if s.Logger == nil {
|
||||
s.Logger = log.StandardLogger()
|
||||
}
|
||||
}
|
||||
|
||||
// startDebugEndpoint launches the debug HTTP server if enabled.
|
||||
func (s *Server) startDebugEndpoint() {
|
||||
if !s.DebugEndpointEnabled {
|
||||
return
|
||||
// Start up metrics gathering
|
||||
reg := prometheus.NewRegistry()
|
||||
s.meter = metrics.New(reg)
|
||||
|
||||
// The very first thing to do should be to connect to the Management server.
|
||||
// Without this connection, the Proxy cannot do anything.
|
||||
mgmtURL, err := url.Parse(s.ManagementAddress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse management address: %w", err)
|
||||
}
|
||||
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
|
||||
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
|
||||
if s.acme != nil {
|
||||
debugHandler.SetCertStatus(s.acme)
|
||||
creds := insecure.NewCredentials()
|
||||
// Simple TLS check using management URL.
|
||||
// Assume management TLS is enabled for gRPC as well if using HTTPS for the API.
|
||||
if mgmtURL.Scheme == "https" {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
// Fall back to embedded CAs if no OS-provided ones are available.
|
||||
certPool = embeddedroots.Get()
|
||||
}
|
||||
|
||||
creds = credentials.NewTLS(&tls.Config{
|
||||
RootCAs: certPool,
|
||||
})
|
||||
}
|
||||
s.debug = &http.Server{
|
||||
Addr: debugAddr,
|
||||
Handler: debugHandler,
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug),
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"gRPC_address": mgmtURL.Host,
|
||||
"TLS_enabled": mgmtURL.Scheme == "https",
|
||||
}).Debug("starting management gRPC client")
|
||||
mgmtConn, err := grpc.NewClient(mgmtURL.Host,
|
||||
grpc.WithTransportCredentials(creds),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 20 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}),
|
||||
proxygrpc.WithProxyToken(s.ProxyToken),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create management connection: %w", err)
|
||||
}
|
||||
go func() {
|
||||
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
|
||||
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Errorf("debug endpoint error: %v", err)
|
||||
defer func() {
|
||||
if err := mgmtConn.Close(); err != nil {
|
||||
s.Logger.Debugf("management connection close: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
|
||||
go s.newManagementMappingWorker(ctx, s.mgmtClient)
|
||||
|
||||
// startHealthServer launches the health probe and metrics server.
|
||||
func (s *Server) startHealthServer(reg *prometheus.Registry) error {
|
||||
// Initialize the netbird client, this is required to build peer connections
|
||||
// to proxy over.
|
||||
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient)
|
||||
|
||||
// When generating ACME certificates, start a challenge server.
|
||||
tlsConfig := &tls.Config{}
|
||||
if s.GenerateACMECertificates {
|
||||
// Default to TLS-ALPN-01 challenge if not specified
|
||||
if s.ACMEChallengeType == "" {
|
||||
s.ACMEChallengeType = "tls-alpn-01"
|
||||
}
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"acme_server": s.ACMEDirectory,
|
||||
"challenge_type": s.ACMEChallengeType,
|
||||
}).Debug("ACME certificates enabled, configuring certificate manager")
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)
|
||||
|
||||
// Only start HTTP server for HTTP-01 challenge type
|
||||
if s.ACMEChallengeType == "http-01" {
|
||||
s.http = &http.Server{
|
||||
Addr: s.ACMEChallengeAddress,
|
||||
Handler: s.acme.HTTPHandler(nil),
|
||||
}
|
||||
go func() {
|
||||
if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
|
||||
}
|
||||
}()
|
||||
}
|
||||
tlsConfig = s.acme.TLSConfig()
|
||||
|
||||
// ServerName needs to be set to allow for ACME to work correctly
|
||||
// when using CNAME URLs to access the proxy.
|
||||
tlsConfig.ServerName = s.ProxyURL
|
||||
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"ServerName": s.ProxyURL,
|
||||
"challenge_type": s.ACMEChallengeType,
|
||||
}).Debug("ACME certificate manager configured")
|
||||
} else {
|
||||
s.Logger.Debug("ACME certificates disabled, using static certificates with file watching")
|
||||
certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile)
|
||||
keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile)
|
||||
|
||||
certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize certificate watcher: %w", err)
|
||||
}
|
||||
go certWatcher.Watch(ctx)
|
||||
|
||||
tlsConfig.GetCertificate = certWatcher.GetCertificate
|
||||
}
|
||||
|
||||
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
|
||||
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
|
||||
|
||||
// Configure the authentication middleware with session validator for OIDC group checks.
|
||||
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient)
|
||||
|
||||
// Configure Access logs to management server.
|
||||
accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
|
||||
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
if s.DebugEndpointEnabled {
|
||||
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
|
||||
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
|
||||
if s.acme != nil {
|
||||
debugHandler.SetCertStatus(s.acme)
|
||||
}
|
||||
s.debug = &http.Server{
|
||||
Addr: debugAddr,
|
||||
Handler: debugHandler,
|
||||
}
|
||||
go func() {
|
||||
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
|
||||
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Errorf("debug endpoint error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Start health probe server.
|
||||
healthAddr := s.HealthAddress
|
||||
if healthAddr == "" {
|
||||
healthAddr = defaultHealthAddr
|
||||
healthAddr = "localhost:8080"
|
||||
}
|
||||
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
||||
healthListener, err := net.Listen("tcp", healthAddr)
|
||||
@@ -279,57 +298,34 @@ func (s *Server) startHealthServer(reg *prometheus.Registry) error {
|
||||
s.Logger.Errorf("health probe server: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// wrapProxyProtocol wraps a listener with PROXY protocol support.
|
||||
// When TrustedProxies is configured, only those sources may send PROXY headers;
|
||||
// connections from untrusted sources have any PROXY header ignored.
|
||||
func (s *Server) wrapProxyProtocol(ln net.Listener) net.Listener {
|
||||
ppListener := &proxyproto.Listener{
|
||||
Listener: ln,
|
||||
ReadHeaderTimeout: proxyProtoHeaderTimeout,
|
||||
// Start the reverse proxy HTTPS server.
|
||||
s.https = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: s.meter.Middleware(accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy)))),
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
if len(s.TrustedProxies) > 0 {
|
||||
ppListener.ConnPolicy = s.proxyProtocolPolicy
|
||||
} else {
|
||||
s.Logger.Warn("PROXY protocol enabled without trusted proxies; any source may send PROXY headers")
|
||||
}
|
||||
s.Logger.Info("PROXY protocol enabled on listener")
|
||||
return ppListener
|
||||
}
|
||||
|
||||
// proxyProtocolPolicy returns whether to require, skip, or reject the PROXY
|
||||
// header based on whether the connection source is in TrustedProxies.
|
||||
func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) {
|
||||
// No logging on reject to prevent abuse
|
||||
tcpAddr, ok := opts.Upstream.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return proxyproto.REJECT, nil
|
||||
}
|
||||
addr, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||
if !ok {
|
||||
return proxyproto.REJECT, nil
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
httpsErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debugf("starting reverse proxy server on %s", addr)
|
||||
httpsErr <- s.https.ListenAndServeTLS("", "")
|
||||
}()
|
||||
|
||||
// called per accept
|
||||
for _, prefix := range s.TrustedProxies {
|
||||
if prefix.Contains(addr) {
|
||||
return proxyproto.REQUIRE, nil
|
||||
select {
|
||||
case err := <-httpsErr:
|
||||
s.shutdownServices()
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("https server: %w", err)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
s.gracefulShutdown()
|
||||
return nil
|
||||
}
|
||||
return proxyproto.IGNORE, nil
|
||||
}
|
||||
|
||||
const (
|
||||
defaultHealthAddr = "localhost:8080"
|
||||
defaultDebugAddr = "localhost:8444"
|
||||
|
||||
// proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol
|
||||
// header after accepting a connection.
|
||||
proxyProtoHeaderTimeout = 5 * time.Second
|
||||
|
||||
// shutdownPreStopDelay is the time to wait after receiving a shutdown signal
|
||||
// before draining connections. This allows the load balancer to propagate
|
||||
// the endpoint removal.
|
||||
@@ -344,92 +340,6 @@ const (
|
||||
shutdownServiceTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
func (s *Server) dialManagement() (*grpc.ClientConn, error) {
|
||||
mgmtURL, err := url.Parse(s.ManagementAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse management address: %w", err)
|
||||
}
|
||||
creds := insecure.NewCredentials()
|
||||
// Assume management TLS is enabled for gRPC as well if using HTTPS for the API.
|
||||
if mgmtURL.Scheme == "https" {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
// Fall back to embedded CAs if no OS-provided ones are available.
|
||||
certPool = embeddedroots.Get()
|
||||
}
|
||||
creds = credentials.NewTLS(&tls.Config{
|
||||
RootCAs: certPool,
|
||||
})
|
||||
}
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"gRPC_address": mgmtURL.Host,
|
||||
"TLS_enabled": mgmtURL.Scheme == "https",
|
||||
}).Debug("starting management gRPC client")
|
||||
conn, err := grpc.NewClient(mgmtURL.Host,
|
||||
grpc.WithTransportCredentials(creds),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 20 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}),
|
||||
proxygrpc.WithProxyToken(s.ProxyToken),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create management connection: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
|
||||
tlsConfig := &tls.Config{}
|
||||
if !s.GenerateACMECertificates {
|
||||
s.Logger.Debug("ACME certificates disabled, using static certificates with file watching")
|
||||
certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile)
|
||||
keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile)
|
||||
|
||||
certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initialize certificate watcher: %w", err)
|
||||
}
|
||||
go certWatcher.Watch(ctx)
|
||||
tlsConfig.GetCertificate = certWatcher.GetCertificate
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
if s.ACMEChallengeType == "" {
|
||||
s.ACMEChallengeType = "tls-alpn-01"
|
||||
}
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"acme_server": s.ACMEDirectory,
|
||||
"challenge_type": s.ACMEChallengeType,
|
||||
}).Debug("ACME certificates enabled, configuring certificate manager")
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)
|
||||
|
||||
if s.ACMEChallengeType == "http-01" {
|
||||
s.http = &http.Server{
|
||||
Addr: s.ACMEChallengeAddress,
|
||||
Handler: s.acme.HTTPHandler(nil),
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueACME),
|
||||
}
|
||||
go func() {
|
||||
if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
|
||||
}
|
||||
}()
|
||||
}
|
||||
tlsConfig = s.acme.TLSConfig()
|
||||
|
||||
// ServerName needs to be set to allow for ACME to work correctly
|
||||
// when using CNAME URLs to access the proxy.
|
||||
tlsConfig.ServerName = s.ProxyURL
|
||||
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"ServerName": s.ProxyURL,
|
||||
"challenge_type": s.ACMEChallengeType,
|
||||
}).Debug("ACME certificate manager configured")
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// gracefulShutdown performs a zero-downtime shutdown sequence. It marks the
|
||||
// readiness probe as failing, waits for load balancer propagation, drains
|
||||
// in-flight connections, and then stops all background services.
|
||||
@@ -582,7 +492,36 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
return fmt.Errorf("receive msg: %w", err)
|
||||
}
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
// Process msg updates sequentially to avoid conflict, so block
|
||||
// additional receiving until this processing is completed.
|
||||
for _, mapping := range msg.GetMapping() {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"type": mapping.GetType(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"path": mapping.GetPath(),
|
||||
"id": mapping.GetId(),
|
||||
}).Debug("Processing mapping update")
|
||||
switch mapping.GetType() {
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
|
||||
if err := s.addMapping(ctx, mapping); err != nil {
|
||||
// TODO: Retry this? Or maybe notify the management server that this mapping has failed?
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"error": err,
|
||||
}).Error("Error adding new mapping, ignoring this mapping and continuing processing")
|
||||
}
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
|
||||
if err := s.updateMapping(ctx, mapping); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
}).Errorf("failed to update mapping: %v", err)
|
||||
}
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED:
|
||||
s.removeMapping(ctx, mapping)
|
||||
}
|
||||
}
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
|
||||
if !*initialSyncDone && msg.GetInitialSyncComplete() {
|
||||
@@ -596,37 +535,6 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
for _, mapping := range mappings {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"type": mapping.GetType(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"path": mapping.GetPath(),
|
||||
"id": mapping.GetId(),
|
||||
}).Debug("Processing mapping update")
|
||||
switch mapping.GetType() {
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
|
||||
if err := s.addMapping(ctx, mapping); err != nil {
|
||||
// TODO: Retry this? Or maybe notify the management server that this mapping has failed?
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"error": err,
|
||||
}).Error("Error adding new mapping, ignoring this mapping and continuing processing")
|
||||
}
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
|
||||
if err := s.updateMapping(ctx, mapping); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
}).Errorf("failed to update mapping: %v", err)
|
||||
}
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED:
|
||||
s.removeMapping(ctx, mapping)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
||||
d := domain.Domain(mapping.GetDomain())
|
||||
accountID := types.AccountID(mapping.GetAccountId())
|
||||
@@ -725,7 +633,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
||||
// If addr is empty, it defaults to localhost:8444 for security.
|
||||
func debugEndpointAddr(addr string) string {
|
||||
if addr == "" {
|
||||
return defaultDebugAddr
|
||||
return "localhost:8444"
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
6
proxy/web/dist/assets/index.js
vendored
6
proxy/web/dist/assets/index.js
vendored
File diff suppressed because one or more lines are too long
2
proxy/web/dist/assets/style.css
vendored
2
proxy/web/dist/assets/style.css
vendored
File diff suppressed because one or more lines are too long
@@ -59,7 +59,7 @@ function App() {
|
||||
formData.append(methods.pin!, value);
|
||||
}
|
||||
|
||||
fetch(globalThis.location.href, {
|
||||
fetch(window.location.href, {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
redirect: "manual",
|
||||
@@ -67,7 +67,7 @@ function App() {
|
||||
.then((res) => {
|
||||
if (res.type === "opaqueredirect" || res.status === 0) {
|
||||
setSubmitting("redirect");
|
||||
globalThis.location.reload();
|
||||
window.location.reload();
|
||||
} else {
|
||||
handleAuthError(method, "Authentication failed. Please try again.");
|
||||
}
|
||||
@@ -92,7 +92,6 @@ function App() {
|
||||
|
||||
const hasCredentialAuth = methods.password || methods.pin;
|
||||
const hasBothCredentials = methods.password && methods.pin;
|
||||
const buttonLabel = activeTab === "password" ? "Sign in" : "Submit";
|
||||
|
||||
if (submitting === "redirect") {
|
||||
return (
|
||||
@@ -125,7 +124,7 @@ function App() {
|
||||
<Button
|
||||
variant="primary"
|
||||
className="w-full"
|
||||
onClick={() => { globalThis.location.href = methods.oidc!; }}
|
||||
onClick={() => (window.location.href = methods.oidc!)}
|
||||
>
|
||||
<LogIn size={16} />
|
||||
Sign in with SSO
|
||||
@@ -171,7 +170,7 @@ function App() {
|
||||
<div className="mb-4">
|
||||
{methods.password && (activeTab === "password" || !methods.pin) && (
|
||||
<>
|
||||
{!hasBothCredentials && <Label htmlFor="password">Password</Label>}
|
||||
{!hasBothCredentials && <Label>Password</Label>}
|
||||
<Input
|
||||
ref={passwordRef}
|
||||
type="password"
|
||||
@@ -187,7 +186,7 @@ function App() {
|
||||
)}
|
||||
{methods.pin && (activeTab === "pin" || !methods.password) && (
|
||||
<>
|
||||
{!hasBothCredentials && <Label htmlFor="pin-0">Enter PIN Code</Label>}
|
||||
{!hasBothCredentials && <Label>Enter PIN Code</Label>}
|
||||
<PinCodeInput
|
||||
ref={pinRef}
|
||||
value={pin}
|
||||
@@ -205,13 +204,13 @@ function App() {
|
||||
variant="secondary"
|
||||
className="w-full"
|
||||
>
|
||||
{submitting === null ? (
|
||||
buttonLabel
|
||||
) : (
|
||||
{submitting !== null ? (
|
||||
<>
|
||||
<Loader2 className="animate-spin" size={16} />
|
||||
Verifying...
|
||||
</>
|
||||
) : (
|
||||
activeTab === "password" ? "Sign in" : "Submit"
|
||||
)}
|
||||
</Button>
|
||||
</form>
|
||||
|
||||
@@ -7,7 +7,7 @@ import { PoweredByNetBird } from "@/components/PoweredByNetBird";
|
||||
import { StatusCard } from "@/components/StatusCard";
|
||||
import type { ErrorData } from "@/data";
|
||||
|
||||
export function ErrorPage({ code, title, message, proxy = true, destination = true, requestId, simple = false, retryUrl }: Readonly<ErrorData>) {
|
||||
export function ErrorPage({ code, title, message, proxy = true, destination = true, requestId, simple = false, retryUrl }: ErrorData) {
|
||||
useEffect(() => {
|
||||
document.title = `${title} - NetBird Service`;
|
||||
}, [title]);
|
||||
@@ -38,19 +38,13 @@ export function ErrorPage({ code, title, message, proxy = true, destination = tr
|
||||
|
||||
{/* Buttons */}
|
||||
<div className="flex gap-3 justify-center items-center mb-6 z-10 relative">
|
||||
<Button variant="primary" onClick={() => {
|
||||
if (retryUrl) {
|
||||
globalThis.location.href = retryUrl;
|
||||
} else {
|
||||
globalThis.location.reload();
|
||||
}
|
||||
}}>
|
||||
<Button variant="primary" onClick={() => retryUrl ? window.location.href = retryUrl : window.location.reload()}>
|
||||
<RotateCw size={16} />
|
||||
Refresh Page
|
||||
</Button>
|
||||
<Button
|
||||
variant="secondary"
|
||||
onClick={() => globalThis.open("https://docs.netbird.io", "_blank", "noopener,noreferrer")}
|
||||
onClick={() => window.open("https://docs.netbird.io", "_blank", "noopener,noreferrer")}
|
||||
>
|
||||
<BookText size={16} />
|
||||
Documentation
|
||||
|
||||
@@ -4,7 +4,7 @@ interface ConnectionLineProps {
|
||||
success?: boolean;
|
||||
}
|
||||
|
||||
export function ConnectionLine({ success = true }: Readonly<ConnectionLineProps>) {
|
||||
export function ConnectionLine({ success = true }: ConnectionLineProps) {
|
||||
if (success) {
|
||||
return (
|
||||
<div className="flex-1 flex items-center justify-center h-12 w-full px-5">
|
||||
|
||||
@@ -5,7 +5,7 @@ type Props = {
|
||||
className?: string;
|
||||
};
|
||||
|
||||
export function Description({ children, className }: Readonly<Props>) {
|
||||
export function Description({ children, className }: Props) {
|
||||
return (
|
||||
<div className={cn("text-sm text-nb-gray-300 font-light mt-2 block text-center z-10 relative", className)}>
|
||||
{children}
|
||||
|
||||
@@ -5,7 +5,7 @@ interface HelpTextProps {
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function HelpText({ children, className }: Readonly<HelpTextProps>) {
|
||||
export default function HelpText({ children, className }: HelpTextProps) {
|
||||
return (
|
||||
<span
|
||||
className={cn(
|
||||
|
||||
@@ -2,10 +2,9 @@ import { cn } from "@/utils/helpers";
|
||||
|
||||
type LabelProps = React.LabelHTMLAttributes<HTMLLabelElement>;
|
||||
|
||||
export function Label({ className, htmlFor, ...props }: Readonly<LabelProps>) {
|
||||
export function Label({ className, ...props }: LabelProps) {
|
||||
return (
|
||||
<label
|
||||
htmlFor={htmlFor}
|
||||
className={cn(
|
||||
"text-sm font-medium tracking-wider leading-none",
|
||||
"peer-disabled:cursor-not-allowed peer-disabled:opacity-70",
|
||||
|
||||
@@ -20,7 +20,7 @@ interface Props {
|
||||
autoFocus?: boolean;
|
||||
}
|
||||
|
||||
const PinCodeInput = forwardRef<PinCodeInputRef, Readonly<Props>>(function PinCodeInput(
|
||||
const PinCodeInput = forwardRef<PinCodeInputRef, Props>(function PinCodeInput(
|
||||
{ value, onChange, length = 6, disabled = false, className, autoFocus = false },
|
||||
ref,
|
||||
) {
|
||||
@@ -32,15 +32,14 @@ const PinCodeInput = forwardRef<PinCodeInputRef, Readonly<Props>>(function PinCo
|
||||
},
|
||||
}));
|
||||
|
||||
const digits = value.split("").concat(new Array(length).fill("")).slice(0, length);
|
||||
const slotIds = Array.from({ length }, (_, i) => `pin-${i}`);
|
||||
const digits = value.split("").concat(Array(length).fill("")).slice(0, length);
|
||||
|
||||
const handleChange = (index: number, digit: string) => {
|
||||
if (!/^\d*$/.test(digit)) return;
|
||||
|
||||
const newDigits = [...digits];
|
||||
newDigits[index] = digit.slice(-1);
|
||||
const newValue = newDigits.join("").replaceAll(/\s/g, "");
|
||||
const newValue = newDigits.join("").replace(/\s/g, "");
|
||||
onChange(newValue);
|
||||
|
||||
if (digit && index < length - 1) {
|
||||
@@ -62,7 +61,7 @@ const PinCodeInput = forwardRef<PinCodeInputRef, Readonly<Props>>(function PinCo
|
||||
|
||||
const handlePaste = (e: ClipboardEvent<HTMLInputElement>) => {
|
||||
e.preventDefault();
|
||||
const pastedData = e.clipboardData.getData("text").replaceAll(/\D/g, "").slice(0, length);
|
||||
const pastedData = e.clipboardData.getData("text").replace(/\D/g, "").slice(0, length);
|
||||
onChange(pastedData);
|
||||
|
||||
const nextIndex = Math.min(pastedData.length, length - 1);
|
||||
@@ -77,8 +76,7 @@ const PinCodeInput = forwardRef<PinCodeInputRef, Readonly<Props>>(function PinCo
|
||||
<div className={cn("flex gap-2 w-full min-w-0", className)}>
|
||||
{digits.map((digit, index) => (
|
||||
<input
|
||||
key={slotIds[index]}
|
||||
id={slotIds[index]}
|
||||
key={index}
|
||||
ref={(el) => {
|
||||
inputRefs.current[index] = el;
|
||||
}}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { cn } from "@/utils/helpers";
|
||||
import { useState, useMemo, useCallback } from "react";
|
||||
import { useState } from "react";
|
||||
import { TabContext, useTabContext } from "./TabContext";
|
||||
|
||||
type TabsProps = {
|
||||
@@ -11,24 +11,19 @@ type TabsProps = {
|
||||
| ((context: { value: string; onChange: (value: string) => void }) => React.ReactNode);
|
||||
};
|
||||
|
||||
function SegmentedTabs({ value, defaultValue, onChange, children }: Readonly<TabsProps>) {
|
||||
const [internalValue, setInternalValue] = useState(defaultValue ?? "");
|
||||
const currentValue = value ?? internalValue;
|
||||
function SegmentedTabs({ value, defaultValue, onChange, children }: TabsProps) {
|
||||
const [internalValue, setInternalValue] = useState(defaultValue || "");
|
||||
const currentValue = value !== undefined ? value : internalValue;
|
||||
|
||||
const handleChange = useCallback((newValue: string) => {
|
||||
const handleChange = (newValue: string) => {
|
||||
if (value === undefined) {
|
||||
setInternalValue(newValue);
|
||||
}
|
||||
onChange?.(newValue);
|
||||
}, [value, onChange]);
|
||||
|
||||
const contextValue = useMemo(
|
||||
() => ({ value: currentValue, onChange: handleChange }),
|
||||
[currentValue, handleChange],
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<TabContext.Provider value={contextValue}>
|
||||
<TabContext.Provider value={{ value: currentValue, onChange: handleChange }}>
|
||||
<div>
|
||||
{typeof children === "function"
|
||||
? children({ value: currentValue, onChange: handleChange })
|
||||
@@ -41,10 +36,10 @@ function SegmentedTabs({ value, defaultValue, onChange, children }: Readonly<Tab
|
||||
function List({
|
||||
children,
|
||||
className,
|
||||
}: Readonly<{
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
className?: string;
|
||||
}>) {
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
role="tablist"
|
||||
@@ -65,23 +60,16 @@ function Trigger({
|
||||
className,
|
||||
selected,
|
||||
onClick,
|
||||
}: Readonly<{
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
value: string;
|
||||
disabled?: boolean;
|
||||
className?: string;
|
||||
selected?: boolean;
|
||||
onClick?: () => void;
|
||||
}>) {
|
||||
}) {
|
||||
const context = useTabContext();
|
||||
const isSelected = selected ?? value === context.value;
|
||||
|
||||
let stateClassName = "";
|
||||
if (isSelected) {
|
||||
stateClassName = "bg-nb-gray-900 text-white";
|
||||
} else if (!disabled) {
|
||||
stateClassName = "text-nb-gray-400 hover:bg-nb-gray-900/50";
|
||||
}
|
||||
const isSelected = selected !== undefined ? selected : value === context.value;
|
||||
|
||||
const handleClick = () => {
|
||||
context.onChange(value);
|
||||
@@ -98,7 +86,11 @@ function Trigger({
|
||||
className={cn(
|
||||
"px-4 py-2 text-sm rounded-md w-full transition-all cursor-pointer",
|
||||
disabled && "opacity-30 cursor-not-allowed",
|
||||
stateClassName,
|
||||
isSelected
|
||||
? "bg-nb-gray-900 text-white"
|
||||
: disabled
|
||||
? ""
|
||||
: "text-nb-gray-400 hover:bg-nb-gray-900/50",
|
||||
className
|
||||
)}
|
||||
>
|
||||
@@ -114,14 +106,14 @@ function Content({
|
||||
value,
|
||||
className,
|
||||
visible,
|
||||
}: Readonly<{
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
value: string;
|
||||
className?: string;
|
||||
visible?: boolean;
|
||||
}>) {
|
||||
}) {
|
||||
const context = useTabContext();
|
||||
const isVisible = visible ?? value === context.value;
|
||||
const isVisible = visible !== undefined ? visible : value === context.value;
|
||||
|
||||
if (!isVisible) return null;
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ export function StatusCard({
|
||||
detail,
|
||||
success = true,
|
||||
line = true,
|
||||
}: Readonly<StatusCardProps>) {
|
||||
}: StatusCardProps) {
|
||||
return (
|
||||
<>
|
||||
{line && <ConnectionLine success={success} />}
|
||||
|
||||
@@ -5,7 +5,7 @@ type Props = {
|
||||
className?: string;
|
||||
};
|
||||
|
||||
export function Title({ children, className }: Readonly<Props>) {
|
||||
export function Title({ children, className }: Props) {
|
||||
return (
|
||||
<h1 className={cn("text-xl! text-center z-10 relative", className)}>
|
||||
{children}
|
||||
|
||||
@@ -24,16 +24,17 @@ export interface Data {
|
||||
}
|
||||
|
||||
declare global {
|
||||
// eslint-disable-next-line no-var
|
||||
var __DATA__: Data | undefined
|
||||
interface Window {
|
||||
__DATA__?: Data
|
||||
}
|
||||
}
|
||||
|
||||
export function getData(): Data {
|
||||
const data = globalThis.__DATA__ ?? {}
|
||||
const data = window.__DATA__ ?? {}
|
||||
|
||||
// Dev mode: allow ?page=error query param to preview error page
|
||||
if (import.meta.env.DEV) {
|
||||
const params = new URLSearchParams(globalThis.location.search)
|
||||
const params = new URLSearchParams(window.location.search)
|
||||
const page = params.get('page')
|
||||
if (page === 'error') {
|
||||
return {
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
@import "tailwindcss";
|
||||
|
||||
@custom-variant dark (&:where(.dark, .dark *));
|
||||
|
||||
@font-face {
|
||||
font-family: "Inter";
|
||||
font-style: normal;
|
||||
@@ -186,6 +184,9 @@
|
||||
|
||||
html {
|
||||
color-scheme: dark;
|
||||
}
|
||||
|
||||
html{
|
||||
@apply bg-nb-gray;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react'
|
||||
import tailwindcss from '@tailwindcss/vite'
|
||||
import path from 'node:path'
|
||||
import path from 'path'
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [react(), tailwindcss()],
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// BillingAPI APIs for billing and invoices
|
||||
type BillingAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// GetUsage retrieves current usage statistics for the account
|
||||
// See more: https://docs.netbird.io/api/resources/billing#get-current-usage
|
||||
func (a *BillingAPI) GetUsage(ctx context.Context) (*api.UsageStats, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/usage", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.UsageStats](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// GetSubscription retrieves the current subscription details
|
||||
// See more: https://docs.netbird.io/api/resources/billing#get-current-subscription
|
||||
func (a *BillingAPI) GetSubscription(ctx context.Context) (*api.Subscription, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/subscription", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Subscription](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// GetInvoices retrieves the account's paid invoices
|
||||
// See more: https://docs.netbird.io/api/resources/billing#list-all-invoices
|
||||
func (a *BillingAPI) GetInvoices(ctx context.Context) ([]api.InvoiceResponse, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/invoices", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.InvoiceResponse](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// GetInvoicePDF retrieves the invoice PDF URL
|
||||
// See more: https://docs.netbird.io/api/resources/billing#get-invoice-pdf
|
||||
func (a *BillingAPI) GetInvoicePDF(ctx context.Context, invoiceID string) (*api.InvoicePDFResponse, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/invoices/"+invoiceID+"/pdf", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.InvoicePDFResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// GetInvoiceCSV retrieves the invoice CSV content
|
||||
// See more: https://docs.netbird.io/api/resources/billing#get-invoice-csv
|
||||
func (a *BillingAPI) GetInvoiceCSV(ctx context.Context, invoiceID string) (string, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/invoices/"+invoiceID+"/csv", nil, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[string](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testUsageStats = api.UsageStats{
|
||||
ActiveUsers: 15,
|
||||
TotalUsers: 20,
|
||||
ActivePeers: 10,
|
||||
TotalPeers: 25,
|
||||
}
|
||||
|
||||
testSubscription = api.Subscription{
|
||||
Active: true,
|
||||
PlanTier: "basic",
|
||||
PriceId: "price_1HhxOp",
|
||||
Currency: "USD",
|
||||
Price: 1000,
|
||||
Provider: "stripe",
|
||||
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
testInvoice = api.InvoiceResponse{
|
||||
Id: "inv_123",
|
||||
PeriodStart: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
PeriodEnd: time.Date(2024, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
Type: "invoice",
|
||||
}
|
||||
|
||||
testInvoicePDF = api.InvoicePDFResponse{
|
||||
Url: "https://example.com/invoice.pdf",
|
||||
}
|
||||
)
|
||||
|
||||
func TestBilling_GetUsage_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/billing/usage", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal(testUsageStats)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Billing.GetUsage(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testUsageStats, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBilling_GetUsage_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/billing/usage", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Billing.GetUsage(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBilling_GetSubscription_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/billing/subscription", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal(testSubscription)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Billing.GetSubscription(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testSubscription, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBilling_GetSubscription_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/billing/subscription", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Billing.GetSubscription(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBilling_GetInvoices_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/billing/invoices", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal([]api.InvoiceResponse{testInvoice})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Billing.GetInvoices(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testInvoice, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestBilling_GetInvoices_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/billing/invoices", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Billing.GetInvoices(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBilling_GetInvoicePDF_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/billing/invoices/inv_123/pdf", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal(testInvoicePDF)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Billing.GetInvoicePDF(context.Background(), "inv_123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testInvoicePDF, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBilling_GetInvoicePDF_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/billing/invoices/inv_123/pdf", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Billing.GetInvoicePDF(context.Background(), "inv_123")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBilling_GetInvoiceCSV_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/billing/invoices/inv_123/csv", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal("col1,col2\nval1,val2")
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Billing.GetInvoiceCSV(context.Background(), "inv_123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "col1,col2\nval1,val2", ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBilling_GetInvoiceCSV_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/billing/invoices/inv_123/csv", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Billing.GetInvoiceCSV(context.Background(), "inv_123")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
@@ -73,38 +73,6 @@ type Client struct {
|
||||
// Events NetBird Events APIs
|
||||
// see more: https://docs.netbird.io/api/resources/events
|
||||
Events *EventsAPI
|
||||
|
||||
// Billing NetBird Billing APIs for subscriptions, plans, and invoices
|
||||
// see more: https://docs.netbird.io/api/resources/billing
|
||||
Billing *BillingAPI
|
||||
|
||||
// MSP NetBird MSP tenant management APIs
|
||||
// see more: https://docs.netbird.io/api/resources/msp
|
||||
MSP *MSPAPI
|
||||
|
||||
// EDR NetBird EDR integration APIs (Intune, SentinelOne, Falcon, Huntress)
|
||||
// see more: https://docs.netbird.io/api/resources/edr
|
||||
EDR *EDRAPI
|
||||
|
||||
// SCIM NetBird SCIM IDP integration APIs
|
||||
// see more: https://docs.netbird.io/api/resources/scim
|
||||
SCIM *SCIMAPI
|
||||
|
||||
// EventStreaming NetBird Event Streaming integration APIs
|
||||
// see more: https://docs.netbird.io/api/resources/event-streaming
|
||||
EventStreaming *EventStreamingAPI
|
||||
|
||||
// IdentityProviders NetBird Identity Providers APIs
|
||||
// see more: https://docs.netbird.io/api/resources/identity-providers
|
||||
IdentityProviders *IdentityProvidersAPI
|
||||
|
||||
// Ingress NetBird Ingress Peers APIs
|
||||
// see more: https://docs.netbird.io/api/resources/ingress-ports
|
||||
Ingress *IngressAPI
|
||||
|
||||
// Instance NetBird Instance API
|
||||
// see more: https://docs.netbird.io/api/resources/instance
|
||||
Instance *InstanceAPI
|
||||
}
|
||||
|
||||
// New initialize new Client instance using PAT token
|
||||
@@ -152,14 +120,6 @@ func (c *Client) initialize() {
|
||||
c.DNSZones = &DNSZonesAPI{c}
|
||||
c.GeoLocation = &GeoLocationAPI{c}
|
||||
c.Events = &EventsAPI{c}
|
||||
c.Billing = &BillingAPI{c}
|
||||
c.MSP = &MSPAPI{c}
|
||||
c.EDR = &EDRAPI{c}
|
||||
c.SCIM = &SCIMAPI{c}
|
||||
c.EventStreaming = &EventStreamingAPI{c}
|
||||
c.IdentityProviders = &IdentityProvidersAPI{c}
|
||||
c.Ingress = &IngressAPI{c}
|
||||
c.Instance = &InstanceAPI{c}
|
||||
}
|
||||
|
||||
// NewRequest creates and executes new management API request
|
||||
|
||||
@@ -1,307 +0,0 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// EDRAPI APIs for EDR integrations (Intune, SentinelOne, Falcon, Huntress)
|
||||
type EDRAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// GetIntuneIntegration retrieves the EDR Intune integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#get-intune-integration
|
||||
func (a *EDRAPI) GetIntuneIntegration(ctx context.Context) (*api.EDRIntuneResponse, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/intune", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.EDRIntuneResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// CreateIntuneIntegration creates a new EDR Intune integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#create-intune-integration
|
||||
func (a *EDRAPI) CreateIntuneIntegration(ctx context.Context, request api.EDRIntuneRequest) (*api.EDRIntuneResponse, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/intune", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.EDRIntuneResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// UpdateIntuneIntegration updates an existing EDR Intune integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#update-intune-integration
|
||||
func (a *EDRAPI) UpdateIntuneIntegration(ctx context.Context, request api.EDRIntuneRequest) (*api.EDRIntuneResponse, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/intune", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.EDRIntuneResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// DeleteIntuneIntegration deletes the EDR Intune integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#delete-intune-integration
|
||||
func (a *EDRAPI) DeleteIntuneIntegration(ctx context.Context) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/intune", nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSentinelOneIntegration retrieves the EDR SentinelOne integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#get-sentinelone-integration
|
||||
func (a *EDRAPI) GetSentinelOneIntegration(ctx context.Context) (*api.EDRSentinelOneResponse, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/sentinelone", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.EDRSentinelOneResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// CreateSentinelOneIntegration creates a new EDR SentinelOne integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#create-sentinelone-integration
|
||||
func (a *EDRAPI) CreateSentinelOneIntegration(ctx context.Context, request api.EDRSentinelOneRequest) (*api.EDRSentinelOneResponse, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/sentinelone", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.EDRSentinelOneResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// UpdateSentinelOneIntegration updates an existing EDR SentinelOne integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#update-sentinelone-integration
|
||||
func (a *EDRAPI) UpdateSentinelOneIntegration(ctx context.Context, request api.EDRSentinelOneRequest) (*api.EDRSentinelOneResponse, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/sentinelone", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.EDRSentinelOneResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// DeleteSentinelOneIntegration deletes the EDR SentinelOne integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#delete-sentinelone-integration
|
||||
func (a *EDRAPI) DeleteSentinelOneIntegration(ctx context.Context) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/sentinelone", nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFalconIntegration retrieves the EDR Falcon integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#get-falcon-integration
|
||||
func (a *EDRAPI) GetFalconIntegration(ctx context.Context) (*api.EDRFalconResponse, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/falcon", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.EDRFalconResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// CreateFalconIntegration creates a new EDR Falcon integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#create-falcon-integration
|
||||
func (a *EDRAPI) CreateFalconIntegration(ctx context.Context, request api.EDRFalconRequest) (*api.EDRFalconResponse, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/falcon", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.EDRFalconResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// UpdateFalconIntegration updates an existing EDR Falcon integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#update-falcon-integration
|
||||
func (a *EDRAPI) UpdateFalconIntegration(ctx context.Context, request api.EDRFalconRequest) (*api.EDRFalconResponse, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/falcon", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.EDRFalconResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// DeleteFalconIntegration deletes the EDR Falcon integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#delete-falcon-integration
|
||||
func (a *EDRAPI) DeleteFalconIntegration(ctx context.Context) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/falcon", nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetHuntressIntegration retrieves the EDR Huntress integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#get-huntress-integration
|
||||
func (a *EDRAPI) GetHuntressIntegration(ctx context.Context) (*api.EDRHuntressResponse, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/huntress", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.EDRHuntressResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// CreateHuntressIntegration creates a new EDR Huntress integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#create-huntress-integration
|
||||
func (a *EDRAPI) CreateHuntressIntegration(ctx context.Context, request api.EDRHuntressRequest) (*api.EDRHuntressResponse, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/huntress", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.EDRHuntressResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// UpdateHuntressIntegration updates an existing EDR Huntress integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#update-huntress-integration
|
||||
func (a *EDRAPI) UpdateHuntressIntegration(ctx context.Context, request api.EDRHuntressRequest) (*api.EDRHuntressResponse, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/huntress", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.EDRHuntressResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// DeleteHuntressIntegration deletes the EDR Huntress integration
|
||||
// See more: https://docs.netbird.io/api/resources/edr#delete-huntress-integration
|
||||
func (a *EDRAPI) DeleteHuntressIntegration(ctx context.Context) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/huntress", nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BypassPeerCompliance bypasses compliance for a non-compliant peer
|
||||
// See more: https://docs.netbird.io/api/resources/edr#bypass-peer-compliance
|
||||
func (a *EDRAPI) BypassPeerCompliance(ctx context.Context, peerID string) (*api.BypassResponse, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/peers/"+peerID+"/edr/bypass", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.BypassResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// RevokePeerBypass revokes the compliance bypass for a peer
|
||||
// See more: https://docs.netbird.io/api/resources/edr#revoke-peer-bypass
|
||||
func (a *EDRAPI) RevokePeerBypass(ctx context.Context, peerID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+peerID+"/edr/bypass", nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListBypassedPeers returns all peers that have compliance bypassed
|
||||
// See more: https://docs.netbird.io/api/resources/edr#list-all-bypassed-peers
|
||||
func (a *EDRAPI) ListBypassedPeers(ctx context.Context) ([]api.BypassResponse, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/edr/bypassed", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.BypassResponse](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -1,422 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testIntuneResponse = api.EDRIntuneResponse{
|
||||
AccountId: "acc-1",
|
||||
ClientId: "client-1",
|
||||
TenantId: "tenant-1",
|
||||
Enabled: true,
|
||||
Id: 1,
|
||||
Groups: []api.Group{},
|
||||
LastSyncedInterval: 24,
|
||||
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
CreatedBy: "user-1",
|
||||
}
|
||||
|
||||
testSentinelOneResponse = api.EDRSentinelOneResponse{
|
||||
AccountId: "acc-1",
|
||||
ApiUrl: "https://sentinelone.example.com",
|
||||
Enabled: true,
|
||||
Id: 2,
|
||||
Groups: []api.Group{},
|
||||
LastSyncedInterval: 24,
|
||||
MatchAttributes: api.SentinelOneMatchAttributes{},
|
||||
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
CreatedBy: "user-1",
|
||||
}
|
||||
|
||||
testFalconResponse = api.EDRFalconResponse{
|
||||
AccountId: "acc-1",
|
||||
CloudId: "us-1",
|
||||
Enabled: true,
|
||||
Id: 3,
|
||||
Groups: []api.Group{},
|
||||
ZtaScoreThreshold: 50,
|
||||
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
CreatedBy: "user-1",
|
||||
}
|
||||
|
||||
testHuntressResponse = api.EDRHuntressResponse{
|
||||
AccountId: "acc-1",
|
||||
Enabled: true,
|
||||
Id: 4,
|
||||
Groups: []api.Group{},
|
||||
LastSyncedInterval: 24,
|
||||
MatchAttributes: api.HuntressMatchAttributes{},
|
||||
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
CreatedBy: "user-1",
|
||||
}
|
||||
|
||||
testBypassResponse = api.BypassResponse{
|
||||
PeerId: "peer-1",
|
||||
}
|
||||
)
|
||||
|
||||
// Intune tests
|
||||
|
||||
func TestEDR_GetIntuneIntegration_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal(testIntuneResponse)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.GetIntuneIntegration(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testIntuneResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_GetIntuneIntegration_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.GetIntuneIntegration(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_CreateIntuneIntegration_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.EDRIntuneRequest
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "client-1", req.ClientId)
|
||||
retBytes, _ := json.Marshal(testIntuneResponse)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.CreateIntuneIntegration(context.Background(), api.EDRIntuneRequest{
|
||||
ClientId: "client-1",
|
||||
Secret: "secret",
|
||||
TenantId: "tenant-1",
|
||||
Groups: []string{"group-1"},
|
||||
LastSyncedInterval: 24,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testIntuneResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_CreateIntuneIntegration_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.CreateIntuneIntegration(context.Background(), api.EDRIntuneRequest{})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_UpdateIntuneIntegration_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
retBytes, _ := json.Marshal(testIntuneResponse)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.UpdateIntuneIntegration(context.Background(), api.EDRIntuneRequest{
|
||||
ClientId: "client-1",
|
||||
Secret: "new-secret",
|
||||
TenantId: "tenant-1",
|
||||
Groups: []string{"group-1"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testIntuneResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_DeleteIntuneIntegration_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.EDR.DeleteIntuneIntegration(context.Background())
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_DeleteIntuneIntegration_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.EDR.DeleteIntuneIntegration(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
// SentinelOne tests
|
||||
|
||||
func TestEDR_GetSentinelOneIntegration_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/sentinelone", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal(testSentinelOneResponse)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.GetSentinelOneIntegration(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testSentinelOneResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_CreateSentinelOneIntegration_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/sentinelone", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
retBytes, _ := json.Marshal(testSentinelOneResponse)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.CreateSentinelOneIntegration(context.Background(), api.EDRSentinelOneRequest{
|
||||
ApiToken: "token",
|
||||
ApiUrl: "https://sentinelone.example.com",
|
||||
Groups: []string{"group-1"},
|
||||
LastSyncedInterval: 24,
|
||||
MatchAttributes: api.SentinelOneMatchAttributes{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testSentinelOneResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_DeleteSentinelOneIntegration_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/sentinelone", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.EDR.DeleteSentinelOneIntegration(context.Background())
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// Falcon tests
|
||||
|
||||
func TestEDR_GetFalconIntegration_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/falcon", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal(testFalconResponse)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.GetFalconIntegration(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testFalconResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_CreateFalconIntegration_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/falcon", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
retBytes, _ := json.Marshal(testFalconResponse)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.CreateFalconIntegration(context.Background(), api.EDRFalconRequest{
|
||||
ClientId: "client-1",
|
||||
Secret: "secret",
|
||||
CloudId: "us-1",
|
||||
Groups: []string{"group-1"},
|
||||
ZtaScoreThreshold: 50,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testFalconResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_DeleteFalconIntegration_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/falcon", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.EDR.DeleteFalconIntegration(context.Background())
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// Huntress tests
|
||||
|
||||
func TestEDR_GetHuntressIntegration_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/huntress", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal(testHuntressResponse)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.GetHuntressIntegration(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testHuntressResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_CreateHuntressIntegration_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/huntress", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
retBytes, _ := json.Marshal(testHuntressResponse)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.CreateHuntressIntegration(context.Background(), api.EDRHuntressRequest{
|
||||
ApiKey: "key",
|
||||
ApiSecret: "secret",
|
||||
Groups: []string{"group-1"},
|
||||
LastSyncedInterval: 24,
|
||||
MatchAttributes: api.HuntressMatchAttributes{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testHuntressResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_DeleteHuntressIntegration_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/edr/huntress", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.EDR.DeleteHuntressIntegration(context.Background())
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// Peer bypass tests
|
||||
|
||||
func TestEDR_BypassPeerCompliance_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/peer-1/edr/bypass", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
retBytes, _ := json.Marshal(testBypassResponse)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.BypassPeerCompliance(context.Background(), "peer-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testBypassResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_BypassPeerCompliance_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/peer-1/edr/bypass", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Bad request", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.BypassPeerCompliance(context.Background(), "peer-1")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Bad request", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_RevokePeerBypass_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/peer-1/edr/bypass", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.EDR.RevokePeerBypass(context.Background(), "peer-1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_RevokePeerBypass_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/peer-1/edr/bypass", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.EDR.RevokePeerBypass(context.Background(), "peer-1")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_ListBypassedPeers_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/edr/bypassed", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal([]api.BypassResponse{testBypassResponse})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.ListBypassedPeers(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testBypassResponse, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestEDR_ListBypassedPeers_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/edr/bypassed", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EDR.ListBypassedPeers(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// EventStreamingAPI APIs for event streaming integrations
|
||||
type EventStreamingAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List retrieves all event streaming integrations
|
||||
// See more: https://docs.netbird.io/api/resources/event-streaming#list-all-event-streaming-integrations
|
||||
func (a *EventStreamingAPI) List(ctx context.Context) ([]api.IntegrationResponse, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/event-streaming", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.IntegrationResponse](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get retrieves a specific event streaming integration by ID
|
||||
// See more: https://docs.netbird.io/api/resources/event-streaming#retrieve-an-event-streaming-integration
|
||||
func (a *EventStreamingAPI) Get(ctx context.Context, integrationID int) (*api.IntegrationResponse, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/event-streaming/"+strconv.Itoa(integrationID), nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.IntegrationResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Create creates a new event streaming integration
|
||||
// See more: https://docs.netbird.io/api/resources/event-streaming#create-an-event-streaming-integration
|
||||
func (a *EventStreamingAPI) Create(ctx context.Context, request api.CreateIntegrationRequest) (*api.IntegrationResponse, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/event-streaming", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.IntegrationResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Update updates an existing event streaming integration
|
||||
// See more: https://docs.netbird.io/api/resources/event-streaming#update-an-event-streaming-integration
|
||||
func (a *EventStreamingAPI) Update(ctx context.Context, integrationID int, request api.CreateIntegrationRequest) (*api.IntegrationResponse, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/event-streaming/"+strconv.Itoa(integrationID), bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.IntegrationResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete deletes an event streaming integration
|
||||
// See more: https://docs.netbird.io/api/resources/event-streaming#delete-an-event-streaming-integration
|
||||
func (a *EventStreamingAPI) Delete(ctx context.Context, integrationID int) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/event-streaming/"+strconv.Itoa(integrationID), nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testIntegrationResponse = api.IntegrationResponse{
|
||||
Id: ptr[int64](1),
|
||||
AccountId: ptr("acc-1"),
|
||||
Platform: (*api.IntegrationResponsePlatform)(ptr("datadog")),
|
||||
Enabled: ptr(true),
|
||||
Config: &map[string]string{"api_key": "****"},
|
||||
CreatedAt: ptr(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)),
|
||||
UpdatedAt: ptr(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)),
|
||||
}
|
||||
)
|
||||
|
||||
func TestEventStreaming_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/event-streaming", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal([]api.IntegrationResponse{testIntegrationResponse})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EventStreaming.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testIntegrationResponse, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestEventStreaming_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/event-streaming", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EventStreaming.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEventStreaming_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal(testIntegrationResponse)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EventStreaming.Get(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testIntegrationResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEventStreaming_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EventStreaming.Get(context.Background(), 1)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEventStreaming_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/event-streaming", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.CreateIntegrationRequest
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, api.CreateIntegrationRequestPlatformDatadog, req.Platform)
|
||||
assert.Equal(t, true, req.Enabled)
|
||||
retBytes, _ := json.Marshal(testIntegrationResponse)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EventStreaming.Create(context.Background(), api.CreateIntegrationRequest{
|
||||
Platform: api.CreateIntegrationRequestPlatformDatadog,
|
||||
Enabled: true,
|
||||
Config: map[string]string{"api_key": "test-key"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testIntegrationResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEventStreaming_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/event-streaming", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EventStreaming.Create(context.Background(), api.CreateIntegrationRequest{})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEventStreaming_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.CreateIntegrationRequest
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, false, req.Enabled)
|
||||
retBytes, _ := json.Marshal(testIntegrationResponse)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EventStreaming.Update(context.Background(), 1, api.CreateIntegrationRequest{
|
||||
Platform: api.CreateIntegrationRequestPlatformDatadog,
|
||||
Enabled: false,
|
||||
Config: map[string]string{"api_key": "updated-key"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testIntegrationResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEventStreaming_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.EventStreaming.Update(context.Background(), 1, api.CreateIntegrationRequest{})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEventStreaming_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.EventStreaming.Delete(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEventStreaming_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.EventStreaming.Delete(context.Background(), 1)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
@@ -2,8 +2,6 @@ package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
@@ -13,79 +11,10 @@ type EventsAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// NetworkTrafficOption options for ListNetworkTrafficEvents API
|
||||
type NetworkTrafficOption func(query map[string]string)
|
||||
|
||||
func NetworkTrafficPage(page int) NetworkTrafficOption {
|
||||
return func(query map[string]string) {
|
||||
query["page"] = fmt.Sprintf("%d", page)
|
||||
}
|
||||
}
|
||||
|
||||
func NetworkTrafficPageSize(pageSize int) NetworkTrafficOption {
|
||||
return func(query map[string]string) {
|
||||
query["page_size"] = fmt.Sprintf("%d", pageSize)
|
||||
}
|
||||
}
|
||||
|
||||
func NetworkTrafficUserID(userID string) NetworkTrafficOption {
|
||||
return func(query map[string]string) {
|
||||
query["user_id"] = userID
|
||||
}
|
||||
}
|
||||
|
||||
func NetworkTrafficReporterID(reporterID string) NetworkTrafficOption {
|
||||
return func(query map[string]string) {
|
||||
query["reporter_id"] = reporterID
|
||||
}
|
||||
}
|
||||
|
||||
func NetworkTrafficProtocol(protocol int) NetworkTrafficOption {
|
||||
return func(query map[string]string) {
|
||||
query["protocol"] = fmt.Sprintf("%d", protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func NetworkTrafficType(t api.GetApiEventsNetworkTrafficParamsType) NetworkTrafficOption {
|
||||
return func(query map[string]string) {
|
||||
query["type"] = string(t)
|
||||
}
|
||||
}
|
||||
|
||||
func NetworkTrafficConnectionType(ct api.GetApiEventsNetworkTrafficParamsConnectionType) NetworkTrafficOption {
|
||||
return func(query map[string]string) {
|
||||
query["connection_type"] = string(ct)
|
||||
}
|
||||
}
|
||||
|
||||
func NetworkTrafficDirection(d api.GetApiEventsNetworkTrafficParamsDirection) NetworkTrafficOption {
|
||||
return func(query map[string]string) {
|
||||
query["direction"] = string(d)
|
||||
}
|
||||
}
|
||||
|
||||
func NetworkTrafficSearch(search string) NetworkTrafficOption {
|
||||
return func(query map[string]string) {
|
||||
query["search"] = search
|
||||
}
|
||||
}
|
||||
|
||||
func NetworkTrafficStartDate(t time.Time) NetworkTrafficOption {
|
||||
return func(query map[string]string) {
|
||||
query["start_date"] = t.Format(time.RFC3339)
|
||||
}
|
||||
}
|
||||
|
||||
func NetworkTrafficEndDate(t time.Time) NetworkTrafficOption {
|
||||
return func(query map[string]string) {
|
||||
query["end_date"] = t.Format(time.RFC3339)
|
||||
}
|
||||
}
|
||||
|
||||
// ListAuditEvents list all audit events
|
||||
// See more: https://docs.netbird.io/api/resources/events#list-all-audit-events
|
||||
func (a *EventsAPI) ListAuditEvents(ctx context.Context) ([]api.Event, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/events/audit", nil, nil)
|
||||
// List list all events
|
||||
// See more: https://docs.netbird.io/api/resources/events#list-all-events
|
||||
func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -95,21 +24,3 @@ func (a *EventsAPI) ListAuditEvents(ctx context.Context) ([]api.Event, error) {
|
||||
ret, err := parseResponse[[]api.Event](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// ListNetworkTrafficEvents list network traffic events
|
||||
// See more: https://docs.netbird.io/api/resources/events#list-network-traffic-events
|
||||
func (a *EventsAPI) ListNetworkTrafficEvents(ctx context.Context, opts ...NetworkTrafficOption) (*api.NetworkTrafficEventsResponse, error) {
|
||||
query := make(map[string]string)
|
||||
for _, o := range opts {
|
||||
o(query)
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/events/network-traffic", nil, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NetworkTrafficEventsResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
@@ -21,76 +21,37 @@ var (
|
||||
Activity: "AccountCreate",
|
||||
ActivityCode: api.EventActivityCodeAccountCreate,
|
||||
}
|
||||
|
||||
testNetworkTrafficResponse = api.NetworkTrafficEventsResponse{
|
||||
Data: []api.NetworkTrafficEvent{},
|
||||
Page: 1,
|
||||
PageSize: 50,
|
||||
}
|
||||
)
|
||||
|
||||
func TestEvents_ListAuditEvents_200(t *testing.T) {
|
||||
func TestEvents_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/events/audit", func(w http.ResponseWriter, r *http.Request) {
|
||||
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.Event{testEvent})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Events.ListAuditEvents(context.Background())
|
||||
ret, err := c.Events.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testEvent, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvents_ListAuditEvents_Err(t *testing.T) {
|
||||
func TestEvents_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/events/audit", func(w http.ResponseWriter, r *http.Request) {
|
||||
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Events.ListAuditEvents(context.Background())
|
||||
ret, err := c.Events.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvents_ListNetworkTrafficEvents_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/events/network-traffic", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "1", r.URL.Query().Get("page"))
|
||||
assert.Equal(t, "50", r.URL.Query().Get("page_size"))
|
||||
retBytes, _ := json.Marshal(testNetworkTrafficResponse)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Events.ListNetworkTrafficEvents(context.Background(),
|
||||
rest.NetworkTrafficPage(1),
|
||||
rest.NetworkTrafficPageSize(50),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testNetworkTrafficResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvents_ListNetworkTrafficEvents_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/events/network-traffic", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Events.ListNetworkTrafficEvents(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvents_Integration(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
// Do something that would trigger any event
|
||||
@@ -101,7 +62,7 @@ func TestEvents_Integration(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
events, err := c.Events.ListAuditEvents(context.Background())
|
||||
events, err := c.Events.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, events)
|
||||
})
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// IdentityProvidersAPI APIs for Identity Providers, do not use directly
|
||||
type IdentityProvidersAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List all identity providers
|
||||
// See more: https://docs.netbird.io/api/resources/identity-providers#list-all-identity-providers
|
||||
func (a *IdentityProvidersAPI) List(ctx context.Context) ([]api.IdentityProvider, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/identity-providers", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.IdentityProvider](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get identity provider info
|
||||
// See more: https://docs.netbird.io/api/resources/identity-providers#retrieve-an-identity-provider
|
||||
func (a *IdentityProvidersAPI) Get(ctx context.Context, idpID string) (*api.IdentityProvider, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/identity-providers/"+idpID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.IdentityProvider](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Create new identity provider
|
||||
// See more: https://docs.netbird.io/api/resources/identity-providers#create-an-identity-provider
|
||||
func (a *IdentityProvidersAPI) Create(ctx context.Context, request api.PostApiIdentityProvidersJSONRequestBody) (*api.IdentityProvider, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/identity-providers", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.IdentityProvider](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Update update identity provider
|
||||
// See more: https://docs.netbird.io/api/resources/identity-providers#update-an-identity-provider
|
||||
func (a *IdentityProvidersAPI) Update(ctx context.Context, idpID string, request api.PutApiIdentityProvidersIdpIdJSONRequestBody) (*api.IdentityProvider, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/identity-providers/"+idpID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.IdentityProvider](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete identity provider
|
||||
// See more: https://docs.netbird.io/api/resources/identity-providers#delete-an-identity-provider
|
||||
func (a *IdentityProvidersAPI) Delete(ctx context.Context, idpID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/identity-providers/"+idpID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,183 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
var testIdentityProvider = api.IdentityProvider{
|
||||
ClientId: "test-client-id",
|
||||
Id: ptr("Test"),
|
||||
}
|
||||
|
||||
func TestIdentityProviders_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/identity-providers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.IdentityProvider{testIdentityProvider})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.IdentityProviders.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testIdentityProvider, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestIdentityProviders_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/identity-providers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.IdentityProviders.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIdentityProviders_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testIdentityProvider)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.IdentityProviders.Get(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testIdentityProvider, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIdentityProviders_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.IdentityProviders.Get(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIdentityProviders_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/identity-providers", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiIdentityProvidersJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "new-client-id", req.ClientId)
|
||||
retBytes, _ := json.Marshal(testIdentityProvider)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.IdentityProviders.Create(context.Background(), api.PostApiIdentityProvidersJSONRequestBody{
|
||||
ClientId: "new-client-id",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testIdentityProvider, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIdentityProviders_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/identity-providers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.IdentityProviders.Create(context.Background(), api.PostApiIdentityProvidersJSONRequestBody{
|
||||
ClientId: "new-client-id",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIdentityProviders_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiIdentityProvidersIdpIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updated-client-id", req.ClientId)
|
||||
retBytes, _ := json.Marshal(testIdentityProvider)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.IdentityProviders.Update(context.Background(), "Test", api.PutApiIdentityProvidersIdpIdJSONRequestBody{
|
||||
ClientId: "updated-client-id",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testIdentityProvider, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIdentityProviders_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.IdentityProviders.Update(context.Background(), "Test", api.PutApiIdentityProvidersIdpIdJSONRequestBody{
|
||||
ClientId: "updated-client-id",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIdentityProviders_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.IdentityProviders.Delete(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIdentityProviders_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.IdentityProviders.Delete(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// IngressAPI APIs for Ingress Peers, do not use directly
|
||||
type IngressAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List all ingress peers
|
||||
// See more: https://docs.netbird.io/api/resources/ingress#list-all-ingress-peers
|
||||
func (a *IngressAPI) List(ctx context.Context) ([]api.IngressPeer, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/ingress/peers", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.IngressPeer](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get ingress peer info
|
||||
// See more: https://docs.netbird.io/api/resources/ingress#retrieve-an-ingress-peer
|
||||
func (a *IngressAPI) Get(ctx context.Context, ingressPeerID string) (*api.IngressPeer, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/ingress/peers/"+ingressPeerID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.IngressPeer](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Create new ingress peer
|
||||
// See more: https://docs.netbird.io/api/resources/ingress#create-an-ingress-peer
|
||||
func (a *IngressAPI) Create(ctx context.Context, request api.PostApiIngressPeersJSONRequestBody) (*api.IngressPeer, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/ingress/peers", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.IngressPeer](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Update update ingress peer
|
||||
// See more: https://docs.netbird.io/api/resources/ingress#update-an-ingress-peer
|
||||
func (a *IngressAPI) Update(ctx context.Context, ingressPeerID string, request api.PutApiIngressPeersIngressPeerIdJSONRequestBody) (*api.IngressPeer, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/ingress/peers/"+ingressPeerID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.IngressPeer](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete ingress peer
|
||||
// See more: https://docs.netbird.io/api/resources/ingress#delete-an-ingress-peer
|
||||
func (a *IngressAPI) Delete(ctx context.Context, ingressPeerID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/ingress/peers/"+ingressPeerID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,184 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
var testIngressPeer = api.IngressPeer{
|
||||
Connected: true,
|
||||
Enabled: true,
|
||||
Id: "Test",
|
||||
}
|
||||
|
||||
func TestIngress_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/ingress/peers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.IngressPeer{testIngressPeer})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Ingress.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testIngressPeer, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestIngress_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/ingress/peers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Ingress.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIngress_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testIngressPeer)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Ingress.Get(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testIngressPeer, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIngress_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Ingress.Get(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIngress_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/ingress/peers", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiIngressPeersJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "peer-id", req.PeerId)
|
||||
retBytes, _ := json.Marshal(testIngressPeer)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Ingress.Create(context.Background(), api.PostApiIngressPeersJSONRequestBody{
|
||||
PeerId: "peer-id",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testIngressPeer, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIngress_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/ingress/peers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Ingress.Create(context.Background(), api.PostApiIngressPeersJSONRequestBody{
|
||||
PeerId: "peer-id",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIngress_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiIngressPeersIngressPeerIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, true, req.Enabled)
|
||||
retBytes, _ := json.Marshal(testIngressPeer)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Ingress.Update(context.Background(), "Test", api.PutApiIngressPeersIngressPeerIdJSONRequestBody{
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testIngressPeer, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIngress_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Ingress.Update(context.Background(), "Test", api.PutApiIngressPeersIngressPeerIdJSONRequestBody{
|
||||
Enabled: true,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIngress_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.Ingress.Delete(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIngress_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.Ingress.Delete(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// InstanceAPI APIs for Instance status and version, do not use directly
|
||||
type InstanceAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// GetStatus get instance status
|
||||
// See more: https://docs.netbird.io/api/resources/instance#get-instance-status
|
||||
func (a *InstanceAPI) GetStatus(ctx context.Context) (*api.InstanceStatus, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/instance", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.InstanceStatus](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Setup perform initial instance setup
|
||||
// See more: https://docs.netbird.io/api/resources/instance#setup-instance
|
||||
func (a *InstanceAPI) Setup(ctx context.Context, request api.PostApiSetupJSONRequestBody) (*api.SetupResponse, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/setup", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.SetupResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testInstanceStatus = api.InstanceStatus{
|
||||
SetupRequired: true,
|
||||
}
|
||||
|
||||
testSetupResponse = api.SetupResponse{
|
||||
Email: "admin@example.com",
|
||||
UserId: "user-123",
|
||||
}
|
||||
)
|
||||
|
||||
func TestInstance_GetStatus_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/instance", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testInstanceStatus)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Instance.GetStatus(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testInstanceStatus, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInstance_GetStatus_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/instance", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Instance.GetStatus(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInstance_Setup_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/setup", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiSetupJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "admin@example.com", req.Email)
|
||||
retBytes, _ := json.Marshal(testSetupResponse)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Instance.Setup(context.Background(), api.PostApiSetupJSONRequestBody{
|
||||
Email: "admin@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testSetupResponse, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInstance_Setup_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/setup", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Instance.Setup(context.Background(), api.PostApiSetupJSONRequestBody{
|
||||
Email: "admin@example.com",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
@@ -1,122 +0,0 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// MSPAPI APIs for MSP tenant management
|
||||
type MSPAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// ListTenants retrieves all MSP tenants
|
||||
// See more: https://docs.netbird.io/api/resources/msp#list-all-tenants
|
||||
func (a *MSPAPI) ListTenants(ctx context.Context) (*api.GetTenantsResponse, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/msp/tenants", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.GetTenantsResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// CreateTenant creates a new MSP tenant
|
||||
// See more: https://docs.netbird.io/api/resources/msp#create-a-tenant
|
||||
func (a *MSPAPI) CreateTenant(ctx context.Context, request api.CreateTenantRequest) (*api.TenantResponse, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/msp/tenants", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.TenantResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// UpdateTenant updates an existing MSP tenant
|
||||
// See more: https://docs.netbird.io/api/resources/msp#update-a-tenant
|
||||
func (a *MSPAPI) UpdateTenant(ctx context.Context, tenantID string, request api.UpdateTenantRequest) (*api.TenantResponse, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/msp/tenants/"+tenantID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.TenantResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// DeleteTenant deletes an MSP tenant
|
||||
// See more: https://docs.netbird.io/api/resources/msp#delete-a-tenant
|
||||
func (a *MSPAPI) DeleteTenant(ctx context.Context, tenantID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/msp/tenants/"+tenantID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnlinkTenant unlinks a tenant from the MSP account
|
||||
// See more: https://docs.netbird.io/api/resources/msp#unlink-a-tenant
|
||||
func (a *MSPAPI) UnlinkTenant(ctx context.Context, tenantID, owner string) error {
|
||||
params := map[string]string{"owner": owner}
|
||||
requestBytes, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/msp/tenants/"+tenantID+"/unlink", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyTenantDNS verifies a tenant domain DNS challenge
|
||||
// See more: https://docs.netbird.io/api/resources/msp#verify-tenant-dns
|
||||
func (a *MSPAPI) VerifyTenantDNS(ctx context.Context, tenantID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/msp/tenants/"+tenantID+"/dns", nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteTenant invites an existing account as a tenant to the MSP account
|
||||
// See more: https://docs.netbird.io/api/resources/msp#invite-a-tenant
|
||||
func (a *MSPAPI) InviteTenant(ctx context.Context, tenantID string) (*api.TenantResponse, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/msp/tenants/"+tenantID+"/invite", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.TenantResponse](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -1,251 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testTenant = api.TenantResponse{
|
||||
Id: "tenant-1",
|
||||
Name: "Test Tenant",
|
||||
Domain: "test.example.com",
|
||||
DnsChallenge: "challenge-123",
|
||||
Status: "active",
|
||||
Groups: []api.TenantGroupResponse{},
|
||||
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
)
|
||||
|
||||
func TestMSP_ListTenants_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal([]api.TenantResponse{testTenant})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.MSP.ListTenants(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, *ret, 1)
|
||||
assert.Equal(t, testTenant, (*ret)[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestMSP_ListTenants_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.MSP.ListTenants(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMSP_CreateTenant_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.CreateTenantRequest
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test Tenant", req.Name)
|
||||
assert.Equal(t, "test.example.com", req.Domain)
|
||||
retBytes, _ := json.Marshal(testTenant)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.MSP.CreateTenant(context.Background(), api.CreateTenantRequest{
|
||||
Name: "Test Tenant",
|
||||
Domain: "test.example.com",
|
||||
Groups: []api.TenantGroupResponse{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testTenant, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMSP_CreateTenant_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.MSP.CreateTenant(context.Background(), api.CreateTenantRequest{
|
||||
Name: "Test Tenant",
|
||||
Domain: "test.example.com",
|
||||
Groups: []api.TenantGroupResponse{},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMSP_UpdateTenant_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.UpdateTenantRequest
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Tenant", req.Name)
|
||||
retBytes, _ := json.Marshal(testTenant)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.MSP.UpdateTenant(context.Background(), "tenant-1", api.UpdateTenantRequest{
|
||||
Name: "Updated Tenant",
|
||||
Groups: []api.TenantGroupResponse{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testTenant, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMSP_UpdateTenant_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.MSP.UpdateTenant(context.Background(), "tenant-1", api.UpdateTenantRequest{
|
||||
Name: "Updated Tenant",
|
||||
Groups: []api.TenantGroupResponse{},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMSP_DeleteTenant_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.MSP.DeleteTenant(context.Background(), "tenant-1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMSP_DeleteTenant_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.MSP.DeleteTenant(context.Background(), "tenant-1")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestMSP_UnlinkTenant_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/unlink", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.MSP.UnlinkTenant(context.Background(), "tenant-1", "owner-1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMSP_UnlinkTenant_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/unlink", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.MSP.UnlinkTenant(context.Background(), "tenant-1", "owner-1")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestMSP_VerifyTenantDNS_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/dns", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.MSP.VerifyTenantDNS(context.Background(), "tenant-1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMSP_VerifyTenantDNS_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/dns", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Failed", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.MSP.VerifyTenantDNS(context.Background(), "tenant-1")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Failed", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestMSP_InviteTenant_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/invite", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
retBytes, _ := json.Marshal(testTenant)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.MSP.InviteTenant(context.Background(), "tenant-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testTenant, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMSP_InviteTenant_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/invite", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.MSP.InviteTenant(context.Background(), "tenant-1")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
@@ -91,20 +91,6 @@ func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListAllRouters list all routers across all networks
|
||||
// See more: https://docs.netbird.io/api/resources/networks#list-all-network-routers
|
||||
func (a *NetworksAPI) ListAllRouters(ctx context.Context) ([]api.NetworkRouter, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/routers", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.NetworkRouter](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// NetworkResourcesAPI APIs for Network Resources, do not use directly
|
||||
type NetworkResourcesAPI struct {
|
||||
c *Client
|
||||
|
||||
@@ -219,35 +219,6 @@ func TestNetworks_Integration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworks_ListAllRouters_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/routers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.NetworkRouter{testNetworkRouter})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.ListAllRouters(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testNetworkRouter, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworks_ListAllRouters_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/routers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.ListAllRouters(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkResources_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user