mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-04 06:59:54 +00:00
Compare commits
247 Commits
test/conne
...
dn-reverse
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76fb153d76 | ||
|
|
eee4d75932 | ||
|
|
62b8875f67 | ||
|
|
47a5478964 | ||
|
|
9922d6f953 | ||
|
|
f9bab22f61 | ||
|
|
3d8fdb7a89 | ||
|
|
fb10153ab8 | ||
|
|
57d3ee5aac | ||
|
|
cfdfdecc14 | ||
|
|
b00babb8b1 | ||
|
|
ac995bae6d | ||
|
|
41a5509ce0 | ||
|
|
db5e26db94 | ||
|
|
fe975fb834 | ||
|
|
e368d2995b | ||
|
|
a3241d8376 | ||
|
|
6dfc5772ba | ||
|
|
f70925178c | ||
|
|
9554934b92 | ||
|
|
7fdb824a37 | ||
|
|
412407adc0 | ||
|
|
e0874d7de7 | ||
|
|
8df1536cbb | ||
|
|
fcbacc62ec | ||
|
|
ee2ae45653 | ||
|
|
3bc8cbb13f | ||
|
|
bf7bdf6c4f | ||
|
|
6f2f0f9ae4 | ||
|
|
c37ebc6fb3 | ||
|
|
23abb5743c | ||
|
|
0a895ffc22 | ||
|
|
b87aa0bc15 | ||
|
|
f1a65d732d | ||
|
|
a3c0ea3e71 | ||
|
|
abaf061c2a | ||
|
|
e531fb54b1 | ||
|
|
5fcfed5b16 | ||
|
|
b81837a364 | ||
|
|
5f43449f67 | ||
|
|
6796601aa6 | ||
|
|
1fc25c301b | ||
|
|
08ae281b2d | ||
|
|
bd47f44c63 | ||
|
|
381260911b | ||
|
|
38db42e7d6 | ||
|
|
5d606d909d | ||
|
|
d689718b50 | ||
|
|
54a73c6649 | ||
|
|
418377842e | ||
|
|
15ef56e03d | ||
|
|
917035f8e8 | ||
|
|
963e3f5457 | ||
|
|
e20b969188 | ||
|
|
1c7059ee67 | ||
|
|
22a3365658 | ||
|
|
08ab1e3478 | ||
|
|
ebb1f4007d | ||
|
|
acb53ece93 | ||
|
|
e020950cfd | ||
|
|
9dba262a20 | ||
|
|
5bcdf36377 | ||
|
|
1ffe8deb10 | ||
|
|
d069145bd1 | ||
|
|
f3493ee042 | ||
|
|
b782ac6f56 | ||
|
|
bf48044e5c | ||
|
|
fb4cc37a4a | ||
|
|
55b8d89a79 | ||
|
|
6968a32a5a | ||
|
|
cfe6753349 | ||
|
|
5ae15b3af3 | ||
|
|
b79adb706c | ||
|
|
f22497d5da | ||
|
|
95d672c9df | ||
|
|
7d08a609e6 | ||
|
|
eea6120cd0 | ||
|
|
0cb02bd906 | ||
|
|
08d3867f41 | ||
|
|
b16d63643c | ||
|
|
940d01bdea | ||
|
|
ba9158d159 | ||
|
|
ca9a7e11ef | ||
|
|
a803f47685 | ||
|
|
79fed32f01 | ||
|
|
6b00bb0a66 | ||
|
|
e2adef1eea | ||
|
|
9e5fa11792 | ||
|
|
1ff75acb31 | ||
|
|
1754160686 | ||
|
|
423f6266fb | ||
|
|
16d1b4a14a | ||
|
|
7c14056faf | ||
|
|
62e37dc2e2 | ||
|
|
6a08695ee8 | ||
|
|
9a67a8e427 | ||
|
|
73aa0785ba | ||
|
|
53c1016a8e | ||
|
|
fd442138e6 | ||
|
|
be5f30225a | ||
|
|
7467e9fb8c | ||
|
|
2390c2e46e | ||
|
|
778c223176 | ||
|
|
36cd0dd85c | ||
|
|
09a1d5a02d | ||
|
|
7c996ac9b5 | ||
|
|
cf9fd5d960 | ||
|
|
1c5ab7cb8f | ||
|
|
aaad3b25a7 | ||
|
|
9904235a2f | ||
|
|
780e9f57a5 | ||
|
|
a8db73285b | ||
|
|
3b43c00d12 | ||
|
|
2f390e1794 | ||
|
|
3630ebb3ae | ||
|
|
260c46df04 | ||
|
|
7f11e3205d | ||
|
|
1c8f92a96f | ||
|
|
7b6294b624 | ||
|
|
156d0b1fef | ||
|
|
2cf00dba58 | ||
|
|
d2a7f3ae36 | ||
|
|
6a64d4e4dd | ||
|
|
51e63c246b | ||
|
|
99e6b1eda4 | ||
|
|
dc26a5a436 | ||
|
|
3883b2fb41 | ||
|
|
ed58659a01 | ||
|
|
5190923c70 | ||
|
|
7c647dd160 | ||
|
|
07e59b2708 | ||
|
|
0a3a9f977d | ||
|
|
2f263bf7e6 | ||
|
|
f65f4fc280 | ||
|
|
adbd7ab4c3 | ||
|
|
0419834482 | ||
|
|
f797d2d9cb | ||
|
|
5ae7efe8f7 | ||
|
|
d6e35bd0fe | ||
|
|
0e00f1c8f7 | ||
|
|
4433f44a12 | ||
|
|
7504e718d7 | ||
|
|
9b0387e7ee | ||
|
|
5ccce1ab3f | ||
|
|
e366fe340e | ||
|
|
b01809f8e3 | ||
|
|
790ef39187 | ||
|
|
3af16cf333 | ||
|
|
d09c69f303 | ||
|
|
096d4ac529 | ||
|
|
8fafde614a | ||
|
|
694ae13418 | ||
|
|
b5b7dd4f53 | ||
|
|
476785b122 | ||
|
|
907677f835 | ||
|
|
7d844b9410 | ||
|
|
eeabc64a73 | ||
|
|
5da2b0fdcc | ||
|
|
a0005a604e | ||
|
|
a89bb807a6 | ||
|
|
28f3354ffa | ||
|
|
562923c600 | ||
|
|
0dd0c67b3b | ||
|
|
ca33849f31 | ||
|
|
18cd0f1480 | ||
|
|
b02982f6b1 | ||
|
|
4d89ae27ef | ||
|
|
733ea77c5c | ||
|
|
92f72bfce6 | ||
|
|
bffb25bea7 | ||
|
|
3af4543e80 | ||
|
|
146774860b | ||
|
|
5243481316 | ||
|
|
76a39c1dcb | ||
|
|
02ce918114 | ||
|
|
30cfc22cb6 | ||
|
|
3168afbfcb | ||
|
|
a73ee47557 | ||
|
|
fa6ff005f2 | ||
|
|
095379fa60 | ||
|
|
30572fe1b8 | ||
|
|
3a6f364b03 | ||
|
|
5345d716ee | ||
|
|
f882c36e0a | ||
|
|
e95cfa1a00 | ||
|
|
0d480071b6 | ||
|
|
8e0b7b6c25 | ||
|
|
f204da0d68 | ||
|
|
7d74904d62 | ||
|
|
760ac5e07d | ||
|
|
4352228797 | ||
|
|
74c770609c | ||
|
|
f4ca36ed7e | ||
|
|
c86da92fc6 | ||
|
|
3f0c577456 | ||
|
|
717da8c7b7 | ||
|
|
a0a61d4f47 | ||
|
|
5b1fced872 | ||
|
|
c98dcf5ef9 | ||
|
|
57cb6bfccb | ||
|
|
95bf97dc3c | ||
|
|
3d116c9d33 | ||
|
|
a9ce9f8d5a | ||
|
|
10b981a855 | ||
|
|
7700b4333d | ||
|
|
7d0131111e | ||
|
|
1daea35e4b | ||
|
|
f97544af0d | ||
|
|
231e80cc15 | ||
|
|
a4c1362bff | ||
|
|
b611d4a751 | ||
|
|
2c9decfa55 | ||
|
|
3c5ac17e2f | ||
|
|
ae42bbb898 | ||
|
|
b86722394b | ||
|
|
a103f69767 | ||
|
|
73fbb3fc62 | ||
|
|
7b3523e25e | ||
|
|
6e4e1386e7 | ||
|
|
671e9af6eb | ||
|
|
50f42caf94 | ||
|
|
b7eeefc102 | ||
|
|
8dd22f3a4f | ||
|
|
4b89427447 | ||
|
|
b71e2860cf | ||
|
|
160b27bc60 | ||
|
|
c084386b88 | ||
|
|
6889047350 | ||
|
|
245bbb4acf | ||
|
|
2b2fc02d83 | ||
|
|
703ef29199 | ||
|
|
b0b60b938a | ||
|
|
e3a026bf1c | ||
|
|
94503465ee | ||
|
|
8d959b0abc | ||
|
|
1d8390b935 | ||
|
|
2851e38a1f | ||
|
|
51261fe7a9 | ||
|
|
304321d019 | ||
|
|
f8c3295645 | ||
|
|
183619d1e1 | ||
|
|
3b832d1f21 | ||
|
|
fcb849698f | ||
|
|
7527e0ebdb | ||
|
|
ed5f98da5b | ||
|
|
12b38e25da | ||
|
|
626e892e3b |
@@ -39,7 +39,7 @@ jobs:
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
fi
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort)
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name ".git*" | sort)
|
||||
|
||||
echo ""
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
|
||||
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,5 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
||||
|
||||
|
||||
14
.github/workflows/golang-test-linux.yml
vendored
14
.github/workflows/golang-test-linux.yml
vendored
@@ -97,16 +97,6 @@ jobs:
|
||||
working-directory: relay
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||
|
||||
- name: Build combined
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: combined
|
||||
run: CGO_ENABLED=1 go build .
|
||||
|
||||
- name: Build combined 386
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: combined
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
|
||||
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -154,7 +144,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
@@ -214,7 +204,7 @@ jobs:
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
|
||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' })" >> $env:GITHUB_ENV
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
16
.github/workflows/release.yml
vendored
16
.github/workflows/release.yml
vendored
@@ -160,7 +160,7 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Log in to the GitHub container registry
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
@@ -176,7 +176,6 @@ jobs:
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
@@ -186,19 +185,6 @@ jobs:
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
- name: Tag and push PR images (amd64 only)
|
||||
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
|
||||
run: |
|
||||
PR_TAG="pr-${{ github.event.pull_request.number }}"
|
||||
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||
grep '^ghcr.io/' | while read -r SRC; do
|
||||
IMG_NAME="${SRC%%:*}"
|
||||
DST="${IMG_NAME}:${PR_TAG}"
|
||||
echo "Tagging ${SRC} -> ${DST}"
|
||||
docker tag "$SRC" "$DST"
|
||||
docker push "$DST"
|
||||
done
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
|
||||
181
.goreleaser.yaml
181
.goreleaser.yaml
@@ -106,26 +106,6 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-server
|
||||
dir: combined
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
binary: netbird-server
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-upload
|
||||
dir: upload-server
|
||||
env: [CGO_ENABLED=0]
|
||||
@@ -140,20 +120,6 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-proxy
|
||||
dir: proxy/cmd/proxy
|
||||
env: [CGO_ENABLED=0]
|
||||
binary: netbird-proxy
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
@@ -554,104 +520,6 @@ dockers:
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
docker_manifests:
|
||||
- name_template: netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
@@ -730,18 +598,6 @@ docker_manifests:
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird-server:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird-server:latest
|
||||
image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
@@ -819,43 +675,6 @@ docker_manifests:
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird-server:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
brews:
|
||||
- ids:
|
||||
- default
|
||||
|
||||
2
LICENSE
2
LICENSE
@@ -1,4 +1,4 @@
|
||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/.
|
||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||
|
||||
BSD 3-Clause License
|
||||
|
||||
@@ -1,19 +1,10 @@
|
||||
package android
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||
|
||||
var (
|
||||
// EnvKeyNBForceRelay Exported for Android java client to force relay connections
|
||||
// EnvKeyNBForceRelay Exported for Android java client
|
||||
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
|
||||
|
||||
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
|
||||
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
|
||||
|
||||
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
|
||||
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
|
||||
)
|
||||
|
||||
// EnvList wraps a Go map for export to Java
|
||||
|
||||
@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
||||
} else {
|
||||
// Fallback for other lengths
|
||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||
|
||||
@@ -589,101 +589,6 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func Test_UserSpaceAddAllowedIPs(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+5)
|
||||
wgIP := "10.99.99.21/30"
|
||||
wgPort := 33105
|
||||
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = iface.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := iface.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = iface.Up()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
keepAlive := 15 * time.Second
|
||||
initialAllowedIP := netip.MustParsePrefix("10.99.99.22/32")
|
||||
endpoint, err := net.ResolveUDPAddr("udp", "127.0.0.1:9905")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add peer with initial endpoint and first allowed IP
|
||||
err = iface.UpdatePeer(peerPubKey, []netip.Prefix{initialAllowedIP}, keepAlive, endpoint, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Phase 1: generate 500 allowed IPs into a list
|
||||
const extraIPs = 500
|
||||
addedPrefixes := make([]netip.Prefix, 0, extraIPs)
|
||||
for i := 0; i < extraIPs; i++ {
|
||||
// Use 172.16.x.y/32 range: i encoded as two octets
|
||||
prefix := netip.MustParsePrefix(fmt.Sprintf("172.16.%d.%d/32", i/256, i%256))
|
||||
addedPrefixes = append(addedPrefixes, prefix)
|
||||
}
|
||||
|
||||
// Phase 2: iterate over the list and add each allowed IP to the peer
|
||||
phase2Start := time.Now()
|
||||
for _, prefix := range addedPrefixes {
|
||||
if addErr := iface.AddAllowedIP(peerPubKey, prefix); addErr != nil {
|
||||
t.Fatalf("failed to add allowed IP %s: %v", prefix, addErr)
|
||||
}
|
||||
}
|
||||
t.Logf("Phase 2 (add %d IPs to peer): %s", extraIPs, time.Since(phase2Start))
|
||||
|
||||
// Verify the peer has all 101 allowed IPs (1 initial + 100 added)
|
||||
peer, err := getPeer(ifaceName, peerPubKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if peer.Endpoint.String() != endpoint.String() {
|
||||
t.Fatalf("expected endpoint %s, got %s", endpoint, peer.Endpoint)
|
||||
}
|
||||
|
||||
allExpected := append([]netip.Prefix{initialAllowedIP}, addedPrefixes...)
|
||||
if len(peer.AllowedIPs) != len(allExpected) {
|
||||
t.Fatalf("expected %d allowed IPs, got %d", len(allExpected), len(peer.AllowedIPs))
|
||||
}
|
||||
|
||||
allowedIPSet := make(map[string]struct{}, len(peer.AllowedIPs))
|
||||
for _, aip := range peer.AllowedIPs {
|
||||
allowedIPSet[aip.String()] = struct{}{}
|
||||
}
|
||||
for _, expected := range allExpected {
|
||||
if _, found := allowedIPSet[expected.String()]; !found {
|
||||
t.Errorf("expected allowed IP %s not found in peer config", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
|
||||
@@ -290,10 +290,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if relayClient.IsDisableRelay() {
|
||||
relayURLs = []string{}
|
||||
}
|
||||
|
||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
if len(relayURLs) > 0 {
|
||||
@@ -314,8 +310,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
c.engineMutex.Lock()
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
|
||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
engine.SetReadyChan(runningChan)
|
||||
runningChan = nil
|
||||
c.engine = engine
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
@@ -336,6 +330,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil {
|
||||
close(runningChan)
|
||||
runningChan = nil
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
|
||||
@@ -42,8 +42,6 @@ const (
|
||||
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
||||
dnsPolicyConfigConfigOptionsValue = 0x8
|
||||
|
||||
nrptMaxDomainsPerRule = 50
|
||||
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
@@ -200,11 +198,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||
// Update count even on error to ensure cleanup covers partially created rules
|
||||
r.nrptEntryCount = count
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns match policy: %w", err)
|
||||
}
|
||||
r.nrptEntryCount = count
|
||||
} else {
|
||||
r.nrptEntryCount = 0
|
||||
}
|
||||
@@ -242,33 +239,23 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
|
||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||
for i, domain := range domains {
|
||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||
|
||||
// We need to batch domains into chunks and create one NRPT rule per batch.
|
||||
ruleIndex := 0
|
||||
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
|
||||
end := i + nrptMaxDomainsPerRule
|
||||
if end > len(domains) {
|
||||
end = len(domains)
|
||||
singleDomain := []string{domain}
|
||||
|
||||
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
||||
}
|
||||
batchDomains := domains[i:end]
|
||||
|
||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
|
||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
|
||||
|
||||
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
|
||||
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
|
||||
}
|
||||
|
||||
// Increment immediately so the caller's cleanup path knows about this rule
|
||||
ruleIndex++
|
||||
|
||||
if r.gpo {
|
||||
if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
|
||||
return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex-1, err)
|
||||
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("added NRPT rule %d with %d domains", ruleIndex-1, len(batchDomains))
|
||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
||||
}
|
||||
|
||||
if r.gpo {
|
||||
@@ -277,8 +264,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains)
|
||||
return ruleIndex, nil
|
||||
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
|
||||
return len(domains), nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||
// when the number of match domains decreases between configuration changes.
|
||||
// With batching enabled (50 domains per rule), we need enough domains to create multiple rules.
|
||||
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping registry integration test in short mode")
|
||||
@@ -38,60 +37,51 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
|
||||
domains125 := make([]DomainConfig, 125)
|
||||
for i := 0; i < 125; i++ {
|
||||
domains125[i] = DomainConfig{
|
||||
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||
MatchOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
config125 := HostDNSConfig{
|
||||
config5 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: domains125,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
{Domain: "domain3.com", MatchOnly: true},
|
||||
{Domain: "domain4.com", MatchOnly: true},
|
||||
{Domain: "domain5.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config125, nil)
|
||||
err = cfg.applyDNSConfig(config5, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify 3 NRPT rules exist
|
||||
assert.Equal(t, 3, cfg.nrptEntryCount, "Should create 3 NRPT rules for 125 domains")
|
||||
for i := 0; i < 3; i++ {
|
||||
// Verify all 5 entries exist
|
||||
for i := 0; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "NRPT rule %d should exist after first config", i)
|
||||
assert.True(t, exists, "Entry %d should exist after first config", i)
|
||||
}
|
||||
|
||||
// Reduce to 75 domains which will result in 2 NRPT rules (50+25)
|
||||
domains75 := make([]DomainConfig, 75)
|
||||
for i := 0; i < 75; i++ {
|
||||
domains75[i] = DomainConfig{
|
||||
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||
MatchOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
config75 := HostDNSConfig{
|
||||
config2 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: domains75,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config75, nil)
|
||||
err = cfg.applyDNSConfig(config2, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify first 2 NRPT rules exist
|
||||
assert.Equal(t, 2, cfg.nrptEntryCount, "Should create 2 NRPT rules for 75 domains")
|
||||
// Verify first 2 entries exist
|
||||
for i := 0; i < 2; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "NRPT rule %d should exist after second config", i)
|
||||
assert.True(t, exists, "Entry %d should exist after second config", i)
|
||||
}
|
||||
|
||||
// Verify rule 2 is cleaned up
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, 2))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "NRPT rule 2 should NOT exist after reducing to 75 domains")
|
||||
// Verify entries 2-4 are cleaned up
|
||||
for i := 2; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
|
||||
}
|
||||
}
|
||||
|
||||
func registryKeyExists(path string) (bool, error) {
|
||||
@@ -107,106 +97,6 @@ func registryKeyExists(path string) (bool, error) {
|
||||
}
|
||||
|
||||
func cleanupRegistryKeys(*testing.T) {
|
||||
// Clean up more entries to account for batching tests with many domains
|
||||
cfg := ®istryConfigurator{nrptEntryCount: 20}
|
||||
cfg := ®istryConfigurator{nrptEntryCount: 10}
|
||||
_ = cfg.removeDNSMatchPolicies()
|
||||
}
|
||||
|
||||
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules.
|
||||
func TestNRPTDomainBatching(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping registry integration test in short mode")
|
||||
}
|
||||
|
||||
defer cleanupRegistryKeys(t)
|
||||
cleanupRegistryKeys(t)
|
||||
|
||||
testIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
// Create a test interface registry key so updateSearchDomains doesn't fail
|
||||
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
|
||||
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
|
||||
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
|
||||
require.NoError(t, err, "Should create test interface registry key")
|
||||
testKey.Close()
|
||||
defer func() {
|
||||
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
domainCount int
|
||||
expectedRuleCount int
|
||||
}{
|
||||
{
|
||||
name: "Less than 50 domains (single rule)",
|
||||
domainCount: 30,
|
||||
expectedRuleCount: 1,
|
||||
},
|
||||
{
|
||||
name: "Exactly 50 domains (single rule)",
|
||||
domainCount: 50,
|
||||
expectedRuleCount: 1,
|
||||
},
|
||||
{
|
||||
name: "51 domains (two rules)",
|
||||
domainCount: 51,
|
||||
expectedRuleCount: 2,
|
||||
},
|
||||
{
|
||||
name: "100 domains (two rules)",
|
||||
domainCount: 100,
|
||||
expectedRuleCount: 2,
|
||||
},
|
||||
{
|
||||
name: "125 domains (three rules: 50+50+25)",
|
||||
domainCount: 125,
|
||||
expectedRuleCount: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Clean up before each subtest
|
||||
cleanupRegistryKeys(t)
|
||||
|
||||
// Generate domains
|
||||
domains := make([]DomainConfig, tc.domainCount)
|
||||
for i := 0; i < tc.domainCount; i++ {
|
||||
domains[i] = DomainConfig{
|
||||
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||
MatchOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: domains,
|
||||
}
|
||||
|
||||
err := cfg.applyDNSConfig(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify that exactly expectedRuleCount rules were created
|
||||
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
|
||||
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
|
||||
|
||||
// Verify all expected rules exist
|
||||
for i := 0; i < tc.expectedRuleCount; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "NRPT rule %d should exist", i)
|
||||
}
|
||||
|
||||
// Verify no extra rules were created
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,18 +84,3 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeginBatch mock implementation of BeginBatch from Server interface
|
||||
func (m *MockServer) BeginBatch() {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// EndBatch mock implementation of EndBatch from Server interface
|
||||
func (m *MockServer) EndBatch() {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// CancelBatch mock implementation of CancelBatch from Server interface
|
||||
func (m *MockServer) CancelBatch() {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
@@ -45,9 +45,6 @@ type IosDnsManager interface {
|
||||
type Server interface {
|
||||
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
||||
DeregisterHandler(domains domain.List, priority int)
|
||||
BeginBatch()
|
||||
EndBatch()
|
||||
CancelBatch()
|
||||
Initialize() error
|
||||
Stop()
|
||||
DnsIP() netip.Addr
|
||||
@@ -90,7 +87,6 @@ type DefaultServer struct {
|
||||
currentConfigHash uint64
|
||||
handlerChain *HandlerChain
|
||||
extraDomains map[domain.Domain]int
|
||||
batchMode bool
|
||||
|
||||
mgmtCacheResolver *mgmt.Resolver
|
||||
|
||||
@@ -238,9 +234,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
|
||||
// convert to zone with simple ref counter
|
||||
s.extraDomains[toZone(domain)]++
|
||||
}
|
||||
if !s.batchMode {
|
||||
s.applyHostConfig()
|
||||
}
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||
@@ -269,41 +263,9 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
||||
delete(s.extraDomains, zone)
|
||||
}
|
||||
}
|
||||
if !s.batchMode {
|
||||
s.applyHostConfig()
|
||||
}
|
||||
}
|
||||
|
||||
// BeginBatch starts batch mode for DNS handler registration/deregistration.
|
||||
// In batch mode, applyHostConfig() is not called after each handler operation,
|
||||
// allowing multiple handlers to be registered/deregistered efficiently.
|
||||
// Must be followed by EndBatch() to apply the accumulated changes.
|
||||
func (s *DefaultServer) BeginBatch() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log.Debugf("DNS batch mode enabled")
|
||||
s.batchMode = true
|
||||
}
|
||||
|
||||
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
|
||||
func (s *DefaultServer) EndBatch() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log.Debugf("DNS batch mode disabled, applying accumulated changes")
|
||||
s.batchMode = false
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
// CancelBatch cancels batch mode without applying accumulated changes.
|
||||
// This is useful when operations fail partway through and you want to
|
||||
// discard partial state rather than applying it.
|
||||
func (s *DefaultServer) CancelBatch() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log.Debugf("DNS batch mode cancelled, discarding accumulated changes")
|
||||
s.batchMode = false
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
log.Debugf("deregistering handler with priority %d for %v", priority, domains)
|
||||
|
||||
@@ -561,7 +523,6 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
s.currentConfig.RouteAll = false
|
||||
}
|
||||
|
||||
// Always apply host config for management updates, regardless of batch mode
|
||||
s.applyHostConfig()
|
||||
|
||||
s.shutdownWg.Add(1)
|
||||
@@ -926,7 +887,6 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
}
|
||||
}
|
||||
|
||||
// Always apply host config when nameserver goes down, regardless of batch mode
|
||||
s.applyHostConfig()
|
||||
|
||||
go func() {
|
||||
@@ -962,7 +922,6 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||
}
|
||||
|
||||
// Always apply host config when nameserver reactivates, regardless of batch mode
|
||||
s.applyHostConfig()
|
||||
|
||||
s.updateNSState(nsGroup, nil, true)
|
||||
|
||||
@@ -18,12 +18,7 @@ func TestGetServerDns(t *testing.T) {
|
||||
t.Errorf("invalid dns server instance: %s", err)
|
||||
}
|
||||
|
||||
mockSrvB, ok := srvB.(*MockServer)
|
||||
if !ok {
|
||||
t.Errorf("returned server is not a MockServer")
|
||||
}
|
||||
|
||||
if mockSrvB != srv {
|
||||
if srvB != srv {
|
||||
t.Errorf("mismatch dns instances")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,8 +28,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
@@ -217,10 +217,6 @@ type Engine struct {
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
|
||||
// readyChan is closed when the first sync message is received from management
|
||||
readyChan chan struct{}
|
||||
readyChanOnce sync.Once
|
||||
|
||||
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
|
||||
shutdownWg sync.WaitGroup
|
||||
|
||||
@@ -279,10 +275,6 @@ func NewEngine(
|
||||
return engine
|
||||
}
|
||||
|
||||
func (e *Engine) SetReadyChan(ch chan struct{}) {
|
||||
e.readyChan = ch
|
||||
}
|
||||
|
||||
func (e *Engine) Stop() error {
|
||||
if e == nil {
|
||||
// this seems to be a very odd case but there was the possibility if the netbird down command comes before the engine is fully started
|
||||
@@ -842,13 +834,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
defer func() {
|
||||
log.Infof("sync finished in %s", time.Since(started))
|
||||
}()
|
||||
|
||||
e.readyChanOnce.Do(func() {
|
||||
if e.readyChan != nil {
|
||||
close(e.readyChan)
|
||||
}
|
||||
})
|
||||
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
@@ -895,11 +880,9 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
uCheckTime := time.Now()
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("update check finished in %s", time.Since(uCheckTime))
|
||||
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
@@ -942,9 +925,7 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||
return fmt.Errorf("update relay token: %w", err)
|
||||
}
|
||||
|
||||
if !relayClient.IsDisableRelay() {
|
||||
e.relayManager.UpdateServerURLs(update.Urls)
|
||||
}
|
||||
e.relayManager.UpdateServerURLs(update.Urls)
|
||||
|
||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||
|
||||
@@ -410,7 +410,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
}
|
||||
|
||||
func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
|
||||
func (conn *Conn) onICEStateDisconnected() {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
@@ -430,18 +430,14 @@ func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
|
||||
if conn.isReadyToUpgrade() {
|
||||
conn.Log.Infof("ICE disconnected, set Relay to active connection")
|
||||
conn.dumpState.SwitchToRelay()
|
||||
if sessionChanged {
|
||||
conn.resetEndpoint()
|
||||
}
|
||||
|
||||
// todo consider to move after the ConfigureWGEndpoint
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
||||
if err := conn.endpointUpdater.SwitchWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
|
||||
conn.wgProxyRelay.Work()
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
} else {
|
||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||
@@ -503,22 +499,20 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
controller := isController(conn.config)
|
||||
wgProxy.Work()
|
||||
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||
|
||||
if controller {
|
||||
wgProxy.Work()
|
||||
}
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), conn.presharedKey(rci.rosenpassPubKey)); err != nil {
|
||||
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
if !controller {
|
||||
wgProxy.Work()
|
||||
}
|
||||
|
||||
wgConfigWorkaround()
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
conn.statusRelay.SetConnected()
|
||||
@@ -763,17 +757,6 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
return wgProxy, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) resetEndpoint() {
|
||||
if !isController(conn.config) {
|
||||
return
|
||||
}
|
||||
conn.Log.Infof("reset wg endpoint")
|
||||
conn.wgWatcher.Reset()
|
||||
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
|
||||
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) isReadyToUpgrade() bool {
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
||||
}
|
||||
@@ -879,3 +862,9 @@ func isController(config ConnConfig) bool {
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
return remoteRosenpassPubKey != nil
|
||||
}
|
||||
|
||||
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
|
||||
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
|
||||
func wgConfigWorkaround() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -34,27 +34,28 @@ func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *E
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
|
||||
// The initiator immediately configures the endpoint, while the non-initiator
|
||||
// waits for a fallback period before configuring to avoid handshake congestion.
|
||||
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if e.initiator {
|
||||
e.log.Debugf("configure up WireGuard as initiator")
|
||||
return e.configureAsInitiator(addr, presharedKey)
|
||||
e.log.Debugf("configure up WireGuard as initiatr")
|
||||
return e.updateWireGuardPeer(addr, presharedKey)
|
||||
}
|
||||
|
||||
e.log.Debugf("configure up WireGuard as responder")
|
||||
return e.configureAsResponder(addr, presharedKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) SwitchWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
// prevent to run new update while cancel the previous update
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
|
||||
return e.updateWireGuardPeer(addr, presharedKey)
|
||||
var ctx context.Context
|
||||
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
||||
e.updateWg.Add(1)
|
||||
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
||||
|
||||
e.log.Debugf("configure up WireGuard and wait for handshake")
|
||||
return e.updateWireGuardPeer(nil, presharedKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||
@@ -65,38 +66,6 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) RemoveEndpointAddress() error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) configureAsInitiator(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) configureAsResponder(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
// prevent to run new update while cancel the previous update
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
|
||||
e.log.Debugf("configure up WireGuard and wait for handshake")
|
||||
var ctx context.Context
|
||||
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
||||
e.updateWg.Add(1)
|
||||
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
||||
|
||||
if err := e.updateWireGuardPeer(nil, presharedKey); err != nil {
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
||||
if e.cancelFunc == nil {
|
||||
return
|
||||
@@ -132,9 +101,3 @@ func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKe
|
||||
presharedKey,
|
||||
)
|
||||
}
|
||||
|
||||
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
|
||||
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
|
||||
func wgConfigWorkaround() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -126,10 +125,6 @@ func GenerateICECredentials() (string, string, error) {
|
||||
}
|
||||
|
||||
func CandidateTypes() []ice.CandidateType {
|
||||
if relayClient.IsDisableRelay() {
|
||||
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
|
||||
}
|
||||
|
||||
if hasICEForceRelayConn() {
|
||||
return []ice.CandidateType{ice.CandidateTypeRelay}
|
||||
}
|
||||
|
||||
@@ -32,8 +32,6 @@ type WGWatcher struct {
|
||||
|
||||
enabled bool
|
||||
muEnabled sync.RWMutex
|
||||
|
||||
resetCh chan struct{}
|
||||
}
|
||||
|
||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||
@@ -42,7 +40,6 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
wgIfaceStater: wgIfaceStater,
|
||||
peerKey: peerKey,
|
||||
stateDump: stateDump,
|
||||
resetCh: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,15 +76,6 @@ func (w *WGWatcher) IsEnabled() bool {
|
||||
return w.enabled
|
||||
}
|
||||
|
||||
// Reset signals the watcher that the WireGuard peer has been reset and a new
|
||||
// handshake is expected. This restarts the handshake timeout from scratch.
|
||||
func (w *WGWatcher) Reset() {
|
||||
select {
|
||||
case w.resetCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
||||
w.log.Infof("WireGuard watcher started")
|
||||
@@ -117,12 +105,6 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
|
||||
w.stateDump.WGcheckSuccess()
|
||||
|
||||
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
||||
case <-w.resetCh:
|
||||
w.log.Infof("WireGuard watcher received peer reset, restarting handshake timeout")
|
||||
lastHandshake = time.Time{}
|
||||
enabledTime = time.Now()
|
||||
timer.Stop()
|
||||
timer.Reset(wgHandshakeOvertime)
|
||||
case <-ctx.Done():
|
||||
w.log.Infof("WireGuard watcher stopped")
|
||||
return
|
||||
|
||||
@@ -52,9 +52,8 @@ type WorkerICE struct {
|
||||
// increase by one when disconnecting the agent
|
||||
// with it the remote peer can discard the already deprecated offer/answer
|
||||
// Without it the remote peer may recreate a workable ICE connection
|
||||
sessionID ICESessionID
|
||||
remoteSessionChanged bool
|
||||
muxAgent sync.Mutex
|
||||
sessionID ICESessionID
|
||||
muxAgent sync.Mutex
|
||||
|
||||
localUfrag string
|
||||
localPwd string
|
||||
@@ -107,7 +106,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
return
|
||||
}
|
||||
w.log.Debugf("agent already exists, recreate the connection")
|
||||
w.remoteSessionChanged = true
|
||||
w.agentDialerCancel()
|
||||
if w.agent != nil {
|
||||
if err := w.agent.Close(); err != nil {
|
||||
@@ -308,17 +306,13 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
||||
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
|
||||
}
|
||||
|
||||
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) bool {
|
||||
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) {
|
||||
cancel()
|
||||
if err := agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
|
||||
w.muxAgent.Lock()
|
||||
defer w.muxAgent.Unlock()
|
||||
|
||||
sessionChanged := w.remoteSessionChanged
|
||||
w.remoteSessionChanged = false
|
||||
|
||||
if w.agent == agent {
|
||||
// consider to remove from here and move to the OnNewOffer
|
||||
@@ -331,7 +325,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
||||
w.agentConnecting = false
|
||||
w.remoteSessionID = ""
|
||||
}
|
||||
return sessionChanged
|
||||
w.muxAgent.Unlock()
|
||||
}
|
||||
|
||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||
@@ -432,11 +426,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||
|
||||
sessionChanged := w.closeAgent(agent, dialerCancel)
|
||||
w.closeAgent(agent, dialerCancel)
|
||||
|
||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.onICEStateDisconnected(sessionChanged)
|
||||
w.conn.onICEStateDisconnected()
|
||||
}
|
||||
default:
|
||||
return
|
||||
|
||||
@@ -346,23 +346,6 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
// Begin batch mode to avoid calling applyHostConfig() after each DNS handler operation
|
||||
batchStarted := false
|
||||
if m.dnsServer != nil {
|
||||
m.dnsServer.BeginBatch()
|
||||
batchStarted = true
|
||||
defer func() {
|
||||
if merr != nil {
|
||||
// On error, cancel batch to discard partial DNS state
|
||||
m.dnsServer.CancelBatch()
|
||||
} else {
|
||||
// On success, apply accumulated DNS changes
|
||||
m.dnsServer.EndBatch()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for id, handler := range toRemove {
|
||||
if err := handler.RemoveRoute(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
|
||||
@@ -393,7 +376,6 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
||||
m.activeRoutes[id] = handler
|
||||
}
|
||||
|
||||
_ = batchStarted // Mark as used
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,10 +2,7 @@
|
||||
|
||||
package NetBirdSDK
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||
|
||||
// EnvList is an exported struct to be bound by gomobile
|
||||
type EnvList struct {
|
||||
@@ -35,13 +32,3 @@ func (el *EnvList) AllItems() map[string]string {
|
||||
func GetEnvKeyNBForceRelay() string {
|
||||
return peer.EnvKeyNBForceRelay
|
||||
}
|
||||
|
||||
// GetEnvKeyNBLazyConn Exports the environment variable for the iOS client
|
||||
func GetEnvKeyNBLazyConn() string {
|
||||
return lazyconn.EnvEnableLazyConn
|
||||
}
|
||||
|
||||
// GetEnvKeyNBInactivityThreshold Exports the environment variable for the iOS client
|
||||
func GetEnvKeyNBInactivityThreshold() string {
|
||||
return lazyconn.EnvInactivityThreshold
|
||||
}
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
FROM ubuntu:24.04
|
||||
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||
ENTRYPOINT [ "/go/bin/netbird-server" ]
|
||||
CMD ["--config", "/etc/netbird/config.yaml"]
|
||||
COPY netbird-server /go/bin/netbird-server
|
||||
@@ -1,25 +0,0 @@
|
||||
FROM golang:1.25-bookworm AS builder
|
||||
WORKDIR /app
|
||||
|
||||
# Install build dependencies
|
||||
RUN apt-get update && apt-get install -y gcc libc6-dev git && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
# Build with version info from git (matching goreleaser ldflags)
|
||||
RUN CGO_ENABLED=1 GOOS=linux go build \
|
||||
-ldflags="-s -w \
|
||||
-X github.com/netbirdio/netbird/version.version=$(git describe --tags --always --dirty 2>/dev/null || echo 'dev') \
|
||||
-X main.commit=$(git rev-parse --short HEAD 2>/dev/null || echo 'unknown') \
|
||||
-X main.date=$(date -u +%Y-%m-%dT%H:%M:%SZ) \
|
||||
-X main.builtBy=docker" \
|
||||
-o netbird-server ./combined
|
||||
|
||||
FROM ubuntu:24.04
|
||||
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||
ENTRYPOINT [ "/go/bin/netbird-server" ]
|
||||
CMD ["--config", "/etc/netbird/config.yaml"]
|
||||
COPY --from=builder /app/netbird-server /go/bin/netbird-server
|
||||
661
combined/LICENSE
661
combined/LICENSE
@@ -1,661 +0,0 @@
|
||||
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||
Version 3, 19 November 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU Affero General Public License is a free, copyleft license for
|
||||
software and other kinds of works, specifically designed to ensure
|
||||
cooperation with the community in the case of network server software.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
our General Public Licenses are intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
Developers that use our General Public Licenses protect your rights
|
||||
with two steps: (1) assert copyright on the software, and (2) offer
|
||||
you this License which gives you legal permission to copy, distribute
|
||||
and/or modify the software.
|
||||
|
||||
A secondary benefit of defending all users' freedom is that
|
||||
improvements made in alternate versions of the program, if they
|
||||
receive widespread use, become available for other developers to
|
||||
incorporate. Many developers of free software are heartened and
|
||||
encouraged by the resulting cooperation. However, in the case of
|
||||
software used on network servers, this result may fail to come about.
|
||||
The GNU General Public License permits making a modified version and
|
||||
letting the public access it on a server without ever releasing its
|
||||
source code to the public.
|
||||
|
||||
The GNU Affero General Public License is designed specifically to
|
||||
ensure that, in such cases, the modified source code becomes available
|
||||
to the community. It requires the operator of a network server to
|
||||
provide the source code of the modified version running there to the
|
||||
users of that server. Therefore, public use of a modified version, on
|
||||
a publicly accessible server, gives the public access to the source
|
||||
code of the modified version.
|
||||
|
||||
An older license, called the Affero General Public License and
|
||||
published by Affero, was designed to accomplish similar goals. This is
|
||||
a different license, not a version of the Affero GPL, but Affero has
|
||||
released a new version of the Affero GPL which permits relicensing under
|
||||
this license.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, if you modify the
|
||||
Program, your modified version must prominently offer all users
|
||||
interacting with it remotely through a computer network (if your version
|
||||
supports such interaction) an opportunity to receive the Corresponding
|
||||
Source of your version by providing access to the Corresponding Source
|
||||
from a network server at no charge, through some standard or customary
|
||||
means of facilitating copying of software. This Corresponding Source
|
||||
shall include the Corresponding Source for any work covered by version 3
|
||||
of the GNU General Public License that is incorporated pursuant to the
|
||||
following paragraph.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the work with which it is combined will remain governed by version
|
||||
3 of the GNU General Public License.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU Affero General Public License from time to time. Such new versions
|
||||
will be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU Affero General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU Affero General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU Affero General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If your software can interact with users remotely through a computer
|
||||
network, you should also make sure that it provides a way for users to
|
||||
get its source. For example, if your program is a web application, its
|
||||
interface could display a "Source" link that leads users to an archive
|
||||
of the code. There are many ways you could offer source, and different
|
||||
solutions will be better for different programs; see section 13 for the
|
||||
specific requirements.
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
@@ -1,723 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
)
|
||||
|
||||
// CombinedConfig is the root configuration for the combined server.
|
||||
// The combined server is primarily a Management server with optional embedded
|
||||
// Signal, Relay, and STUN services.
|
||||
//
|
||||
// Architecture:
|
||||
// - Management: Always runs locally (this IS the management server)
|
||||
// - Signal: Runs locally by default; disabled if server.signalUri is set
|
||||
// - Relay: Runs locally by default; disabled if server.relays is set
|
||||
// - STUN: Runs locally on port 3478 by default; disabled if server.stuns is set
|
||||
//
|
||||
// All user-facing settings are under "server". The relay/signal/management
|
||||
// fields are internal and populated automatically from server settings.
|
||||
type CombinedConfig struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
|
||||
// Internal configs - populated from Server settings, not user-configurable
|
||||
Relay RelayConfig `yaml:"-"`
|
||||
Signal SignalConfig `yaml:"-"`
|
||||
Management ManagementConfig `yaml:"-"`
|
||||
}
|
||||
|
||||
// ServerConfig contains server-wide settings
|
||||
// In simplified mode, this contains all configuration
|
||||
type ServerConfig struct {
|
||||
ListenAddress string `yaml:"listenAddress"`
|
||||
MetricsPort int `yaml:"metricsPort"`
|
||||
HealthcheckAddress string `yaml:"healthcheckAddress"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
LogFile string `yaml:"logFile"`
|
||||
TLS TLSConfig `yaml:"tls"`
|
||||
|
||||
// Simplified config fields (used when relay/signal/management sections are omitted)
|
||||
ExposedAddress string `yaml:"exposedAddress"` // Public address with protocol (e.g., "https://example.com:443")
|
||||
StunPorts []int `yaml:"stunPorts"` // STUN ports (empty to disable local STUN)
|
||||
AuthSecret string `yaml:"authSecret"` // Shared secret for relay authentication
|
||||
DataDir string `yaml:"dataDir"` // Data directory for all services
|
||||
|
||||
// External service overrides (simplified mode)
|
||||
// When these are set, the corresponding local service is NOT started
|
||||
// and these values are used for client configuration instead
|
||||
Stuns []HostConfig `yaml:"stuns"` // External STUN servers (disables local STUN)
|
||||
Relays RelaysConfig `yaml:"relays"` // External relay servers (disables local relay)
|
||||
SignalURI string `yaml:"signalUri"` // External signal server (disables local signal)
|
||||
|
||||
// Management settings (simplified mode)
|
||||
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
|
||||
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
Store StoreConfig `yaml:"store"`
|
||||
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
|
||||
}
|
||||
|
||||
// TLSConfig contains TLS/HTTPS settings
|
||||
type TLSConfig struct {
|
||||
CertFile string `yaml:"certFile"`
|
||||
KeyFile string `yaml:"keyFile"`
|
||||
LetsEncrypt LetsEncryptConfig `yaml:"letsencrypt"`
|
||||
}
|
||||
|
||||
// LetsEncryptConfig contains Let's Encrypt settings
|
||||
type LetsEncryptConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
DataDir string `yaml:"dataDir"`
|
||||
Domains []string `yaml:"domains"`
|
||||
Email string `yaml:"email"`
|
||||
AWSRoute53 bool `yaml:"awsRoute53"`
|
||||
}
|
||||
|
||||
// RelayConfig contains relay service settings
|
||||
type RelayConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
ExposedAddress string `yaml:"exposedAddress"`
|
||||
AuthSecret string `yaml:"authSecret"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
Stun StunConfig `yaml:"stun"`
|
||||
}
|
||||
|
||||
// StunConfig contains embedded STUN service settings
|
||||
type StunConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Ports []int `yaml:"ports"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
}
|
||||
|
||||
// SignalConfig contains signal service settings
|
||||
type SignalConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
}
|
||||
|
||||
// ManagementConfig contains management service settings
|
||||
type ManagementConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
DataDir string `yaml:"dataDir"`
|
||||
DnsDomain string `yaml:"dnsDomain"`
|
||||
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
|
||||
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
|
||||
DisableDefaultPolicy bool `yaml:"disableDefaultPolicy"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
Stuns []HostConfig `yaml:"stuns"`
|
||||
Relays RelaysConfig `yaml:"relays"`
|
||||
SignalURI string `yaml:"signalUri"`
|
||||
Store StoreConfig `yaml:"store"`
|
||||
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
|
||||
}
|
||||
|
||||
// AuthConfig contains authentication/identity provider settings
|
||||
type AuthConfig struct {
|
||||
Issuer string `yaml:"issuer"`
|
||||
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
|
||||
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
|
||||
Storage AuthStorageConfig `yaml:"storage"`
|
||||
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
|
||||
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
|
||||
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
|
||||
}
|
||||
|
||||
// AuthStorageConfig contains auth storage settings
|
||||
type AuthStorageConfig struct {
|
||||
Type string `yaml:"type"`
|
||||
File string `yaml:"file"`
|
||||
}
|
||||
|
||||
// AuthOwnerConfig contains initial admin user settings
|
||||
type AuthOwnerConfig struct {
|
||||
Email string `yaml:"email"`
|
||||
Password string `yaml:"password"`
|
||||
}
|
||||
|
||||
// HostConfig represents a STUN/TURN/Signal host
|
||||
type HostConfig struct {
|
||||
URI string `yaml:"uri"`
|
||||
Proto string `yaml:"proto,omitempty"` // udp, dtls, tcp, http, https - defaults based on URI scheme
|
||||
Username string `yaml:"username,omitempty"`
|
||||
Password string `yaml:"password,omitempty"`
|
||||
}
|
||||
|
||||
// RelaysConfig contains external relay server settings for clients
|
||||
type RelaysConfig struct {
|
||||
Addresses []string `yaml:"addresses"`
|
||||
CredentialsTTL string `yaml:"credentialsTTL"`
|
||||
Secret string `yaml:"secret"`
|
||||
}
|
||||
|
||||
// StoreConfig contains database settings
|
||||
type StoreConfig struct {
|
||||
Engine string `yaml:"engine"`
|
||||
EncryptionKey string `yaml:"encryptionKey"`
|
||||
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
|
||||
}
|
||||
|
||||
// ReverseProxyConfig contains reverse proxy settings
|
||||
type ReverseProxyConfig struct {
|
||||
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
|
||||
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
|
||||
TrustedPeers []string `yaml:"trustedPeers"`
|
||||
}
|
||||
|
||||
// DefaultConfig returns a CombinedConfig with default values
|
||||
func DefaultConfig() *CombinedConfig {
|
||||
return &CombinedConfig{
|
||||
Server: ServerConfig{
|
||||
ListenAddress: ":443",
|
||||
MetricsPort: 9090,
|
||||
HealthcheckAddress: ":9000",
|
||||
LogLevel: "info",
|
||||
LogFile: "console",
|
||||
StunPorts: []int{3478},
|
||||
DataDir: "/var/lib/netbird/",
|
||||
Auth: AuthConfig{
|
||||
Storage: AuthStorageConfig{
|
||||
Type: "sqlite3",
|
||||
},
|
||||
},
|
||||
Store: StoreConfig{
|
||||
Engine: "sqlite",
|
||||
},
|
||||
},
|
||||
Relay: RelayConfig{
|
||||
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
|
||||
Stun: StunConfig{
|
||||
Enabled: false,
|
||||
Ports: []int{3478},
|
||||
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
|
||||
},
|
||||
},
|
||||
Signal: SignalConfig{
|
||||
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
|
||||
},
|
||||
Management: ManagementConfig{
|
||||
DataDir: "/var/lib/netbird/",
|
||||
Auth: AuthConfig{
|
||||
Storage: AuthStorageConfig{
|
||||
Type: "sqlite3",
|
||||
},
|
||||
},
|
||||
Relays: RelaysConfig{
|
||||
CredentialsTTL: "12h",
|
||||
},
|
||||
Store: StoreConfig{
|
||||
Engine: "sqlite",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// hasRequiredSettings returns true if the configuration has the required server settings
|
||||
func (c *CombinedConfig) hasRequiredSettings() bool {
|
||||
return c.Server.ExposedAddress != ""
|
||||
}
|
||||
|
||||
// parseExposedAddress extracts protocol, host, and host:port from the exposed address
|
||||
// Input format: "https://example.com:443" or "http://example.com:8080" or "example.com:443"
|
||||
// Returns: protocol ("https" or "http"), hostname only, and host:port
|
||||
func parseExposedAddress(exposedAddress string) (protocol, hostname, hostPort string) {
|
||||
// Default to https if no protocol specified
|
||||
protocol = "https"
|
||||
hostPort = exposedAddress
|
||||
|
||||
// Check for protocol prefix
|
||||
if strings.HasPrefix(exposedAddress, "https://") {
|
||||
protocol = "https"
|
||||
hostPort = strings.TrimPrefix(exposedAddress, "https://")
|
||||
} else if strings.HasPrefix(exposedAddress, "http://") {
|
||||
protocol = "http"
|
||||
hostPort = strings.TrimPrefix(exposedAddress, "http://")
|
||||
}
|
||||
|
||||
// Extract hostname (without port)
|
||||
hostname = hostPort
|
||||
if host, _, err := net.SplitHostPort(hostPort); err == nil {
|
||||
hostname = host
|
||||
}
|
||||
|
||||
return protocol, hostname, hostPort
|
||||
}
|
||||
|
||||
// ApplySimplifiedDefaults populates internal relay/signal/management configs from server settings.
|
||||
// Management is always enabled. Signal, Relay, and STUN are enabled unless external
|
||||
// overrides are configured (server.signalUri, server.relays, server.stuns).
|
||||
func (c *CombinedConfig) ApplySimplifiedDefaults() {
|
||||
if !c.hasRequiredSettings() {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse exposed address to extract protocol and hostname
|
||||
exposedProto, exposedHost, exposedHostPort := parseExposedAddress(c.Server.ExposedAddress)
|
||||
|
||||
// Check for external service overrides
|
||||
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
|
||||
hasExternalSignal := c.Server.SignalURI != ""
|
||||
hasExternalStuns := len(c.Server.Stuns) > 0
|
||||
|
||||
// Default stunPorts to [3478] if not specified and no external STUN
|
||||
if len(c.Server.StunPorts) == 0 && !hasExternalStuns {
|
||||
c.Server.StunPorts = []int{3478}
|
||||
}
|
||||
|
||||
c.applyRelayDefaults(exposedProto, exposedHostPort, hasExternalRelay, hasExternalStuns)
|
||||
c.applySignalDefaults(hasExternalSignal)
|
||||
c.applyManagementDefaults(exposedHost)
|
||||
|
||||
// Auto-configure client settings (stuns, relays, signalUri)
|
||||
c.autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort, hasExternalStuns, hasExternalRelay, hasExternalSignal)
|
||||
}
|
||||
|
||||
// applyRelayDefaults configures the relay service if no external relay is configured.
|
||||
func (c *CombinedConfig) applyRelayDefaults(exposedProto, exposedHostPort string, hasExternalRelay, hasExternalStuns bool) {
|
||||
if hasExternalRelay {
|
||||
return
|
||||
}
|
||||
|
||||
c.Relay.Enabled = true
|
||||
relayProto := "rel"
|
||||
if exposedProto == "https" {
|
||||
relayProto = "rels"
|
||||
}
|
||||
c.Relay.ExposedAddress = fmt.Sprintf("%s://%s", relayProto, exposedHostPort)
|
||||
c.Relay.AuthSecret = c.Server.AuthSecret
|
||||
if c.Relay.LogLevel == "" {
|
||||
c.Relay.LogLevel = c.Server.LogLevel
|
||||
}
|
||||
|
||||
// Enable local STUN only if no external STUN servers and stunPorts are configured
|
||||
if !hasExternalStuns && len(c.Server.StunPorts) > 0 {
|
||||
c.Relay.Stun.Enabled = true
|
||||
c.Relay.Stun.Ports = c.Server.StunPorts
|
||||
if c.Relay.Stun.LogLevel == "" {
|
||||
c.Relay.Stun.LogLevel = c.Server.LogLevel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// applySignalDefaults configures the signal service if no external signal is configured.
|
||||
func (c *CombinedConfig) applySignalDefaults(hasExternalSignal bool) {
|
||||
if hasExternalSignal {
|
||||
return
|
||||
}
|
||||
|
||||
c.Signal.Enabled = true
|
||||
if c.Signal.LogLevel == "" {
|
||||
c.Signal.LogLevel = c.Server.LogLevel
|
||||
}
|
||||
}
|
||||
|
||||
// applyManagementDefaults configures the management service (always enabled).
|
||||
func (c *CombinedConfig) applyManagementDefaults(exposedHost string) {
|
||||
c.Management.Enabled = true
|
||||
if c.Management.LogLevel == "" {
|
||||
c.Management.LogLevel = c.Server.LogLevel
|
||||
}
|
||||
if c.Management.DataDir == "" || c.Management.DataDir == "/var/lib/netbird/" {
|
||||
c.Management.DataDir = c.Server.DataDir
|
||||
}
|
||||
c.Management.DnsDomain = exposedHost
|
||||
c.Management.DisableAnonymousMetrics = c.Server.DisableAnonymousMetrics
|
||||
c.Management.DisableGeoliteUpdate = c.Server.DisableGeoliteUpdate
|
||||
// Copy auth config from server if management auth issuer is not set
|
||||
if c.Management.Auth.Issuer == "" && c.Server.Auth.Issuer != "" {
|
||||
c.Management.Auth = c.Server.Auth
|
||||
}
|
||||
|
||||
// Copy store config from server if not set
|
||||
if c.Management.Store.Engine == "" || c.Management.Store.Engine == "sqlite" {
|
||||
if c.Server.Store.Engine != "" {
|
||||
c.Management.Store = c.Server.Store
|
||||
}
|
||||
}
|
||||
|
||||
// Copy reverse proxy config from server
|
||||
if len(c.Server.ReverseProxy.TrustedHTTPProxies) > 0 || c.Server.ReverseProxy.TrustedHTTPProxiesCount > 0 || len(c.Server.ReverseProxy.TrustedPeers) > 0 {
|
||||
c.Management.ReverseProxy = c.Server.ReverseProxy
|
||||
}
|
||||
}
|
||||
|
||||
// autoConfigureClientSettings sets up STUN/relay/signal URIs for clients
|
||||
// External overrides from server config take precedence over auto-generated values
|
||||
func (c *CombinedConfig) autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort string, hasExternalStuns, hasExternalRelay, hasExternalSignal bool) {
|
||||
// Determine relay protocol from exposed protocol
|
||||
relayProto := "rel"
|
||||
if exposedProto == "https" {
|
||||
relayProto = "rels"
|
||||
}
|
||||
|
||||
// Configure STUN servers for clients
|
||||
if hasExternalStuns {
|
||||
// Use external STUN servers from server config
|
||||
c.Management.Stuns = c.Server.Stuns
|
||||
} else if len(c.Server.StunPorts) > 0 && len(c.Management.Stuns) == 0 {
|
||||
// Auto-configure local STUN servers for all ports
|
||||
for _, port := range c.Server.StunPorts {
|
||||
c.Management.Stuns = append(c.Management.Stuns, HostConfig{
|
||||
URI: fmt.Sprintf("stun:%s:%d", exposedHost, port),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Configure relay for clients
|
||||
if hasExternalRelay {
|
||||
// Use external relay config from server
|
||||
c.Management.Relays = c.Server.Relays
|
||||
} else if len(c.Management.Relays.Addresses) == 0 {
|
||||
// Auto-configure local relay
|
||||
c.Management.Relays.Addresses = []string{
|
||||
fmt.Sprintf("%s://%s", relayProto, exposedHostPort),
|
||||
}
|
||||
}
|
||||
if c.Management.Relays.Secret == "" {
|
||||
c.Management.Relays.Secret = c.Server.AuthSecret
|
||||
}
|
||||
if c.Management.Relays.CredentialsTTL == "" {
|
||||
c.Management.Relays.CredentialsTTL = "12h"
|
||||
}
|
||||
|
||||
// Configure signal for clients
|
||||
if hasExternalSignal {
|
||||
// Use external signal URI from server config
|
||||
c.Management.SignalURI = c.Server.SignalURI
|
||||
} else if c.Management.SignalURI == "" {
|
||||
// Auto-configure local signal
|
||||
c.Management.SignalURI = fmt.Sprintf("%s://%s", exposedProto, exposedHostPort)
|
||||
}
|
||||
}
|
||||
|
||||
// LoadConfig loads configuration from a YAML file
|
||||
func LoadConfig(configPath string) (*CombinedConfig, error) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if configPath == "" {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
if err := yaml.Unmarshal(data, cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
// Populate internal configs from server settings
|
||||
cfg.ApplySimplifiedDefaults()
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// Validate validates the configuration
|
||||
func (c *CombinedConfig) Validate() error {
|
||||
if c.Server.ExposedAddress == "" {
|
||||
return fmt.Errorf("server.exposedAddress is required")
|
||||
}
|
||||
if c.Server.DataDir == "" {
|
||||
return fmt.Errorf("server.dataDir is required")
|
||||
}
|
||||
|
||||
// Validate STUN ports
|
||||
seen := make(map[int]bool)
|
||||
for _, port := range c.Server.StunPorts {
|
||||
if port <= 0 || port > 65535 {
|
||||
return fmt.Errorf("invalid server.stunPorts value %d: must be between 1 and 65535", port)
|
||||
}
|
||||
if seen[port] {
|
||||
return fmt.Errorf("duplicate STUN port %d in server.stunPorts", port)
|
||||
}
|
||||
seen[port] = true
|
||||
}
|
||||
|
||||
// authSecret is required only if running local relay (no external relay configured)
|
||||
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
|
||||
if !hasExternalRelay && c.Server.AuthSecret == "" {
|
||||
return fmt.Errorf("server.authSecret is required when running local relay")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasTLSCert returns true if TLS certificate files are configured
|
||||
func (c *CombinedConfig) HasTLSCert() bool {
|
||||
return c.Server.TLS.CertFile != "" && c.Server.TLS.KeyFile != ""
|
||||
}
|
||||
|
||||
// HasLetsEncrypt returns true if Let's Encrypt is configured
|
||||
func (c *CombinedConfig) HasLetsEncrypt() bool {
|
||||
return c.Server.TLS.LetsEncrypt.Enabled &&
|
||||
c.Server.TLS.LetsEncrypt.DataDir != "" &&
|
||||
len(c.Server.TLS.LetsEncrypt.Domains) > 0
|
||||
}
|
||||
|
||||
// parseExplicitProtocol parses an explicit protocol string to nbconfig.Protocol
|
||||
func parseExplicitProtocol(proto string) (nbconfig.Protocol, bool) {
|
||||
switch strings.ToLower(proto) {
|
||||
case "udp":
|
||||
return nbconfig.UDP, true
|
||||
case "dtls":
|
||||
return nbconfig.DTLS, true
|
||||
case "tcp":
|
||||
return nbconfig.TCP, true
|
||||
case "http":
|
||||
return nbconfig.HTTP, true
|
||||
case "https":
|
||||
return nbconfig.HTTPS, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// parseStunProtocol determines protocol for STUN/TURN servers.
|
||||
// stun: → UDP, stuns: → DTLS, turn: → UDP, turns: → DTLS
|
||||
// Explicit proto overrides URI scheme. Defaults to UDP.
|
||||
func parseStunProtocol(uri, proto string) nbconfig.Protocol {
|
||||
if proto != "" {
|
||||
if p, ok := parseExplicitProtocol(proto); ok {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
uri = strings.ToLower(uri)
|
||||
switch {
|
||||
case strings.HasPrefix(uri, "stuns:"):
|
||||
return nbconfig.DTLS
|
||||
case strings.HasPrefix(uri, "turns:"):
|
||||
return nbconfig.DTLS
|
||||
default:
|
||||
// stun:, turn:, or no scheme - default to UDP
|
||||
return nbconfig.UDP
|
||||
}
|
||||
}
|
||||
|
||||
// parseSignalProtocol determines protocol for Signal servers.
|
||||
// https:// → HTTPS, http:// → HTTP. Defaults to HTTPS.
|
||||
func parseSignalProtocol(uri string) nbconfig.Protocol {
|
||||
uri = strings.ToLower(uri)
|
||||
switch {
|
||||
case strings.HasPrefix(uri, "http://"):
|
||||
return nbconfig.HTTP
|
||||
default:
|
||||
// https:// or no scheme - default to HTTPS
|
||||
return nbconfig.HTTPS
|
||||
}
|
||||
}
|
||||
|
||||
// stripSignalProtocol removes the protocol prefix from a signal URI.
|
||||
// Returns just the host:port (e.g., "selfhosted2.demo.netbird.io:443").
|
||||
func stripSignalProtocol(uri string) string {
|
||||
uri = strings.TrimPrefix(uri, "https://")
|
||||
uri = strings.TrimPrefix(uri, "http://")
|
||||
return uri
|
||||
}
|
||||
|
||||
// ToManagementConfig converts CombinedConfig to management server config
|
||||
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
||||
mgmt := c.Management
|
||||
|
||||
// Build STUN hosts
|
||||
var stuns []*nbconfig.Host
|
||||
for _, s := range mgmt.Stuns {
|
||||
stuns = append(stuns, &nbconfig.Host{
|
||||
URI: s.URI,
|
||||
Proto: parseStunProtocol(s.URI, s.Proto),
|
||||
Username: s.Username,
|
||||
Password: s.Password,
|
||||
})
|
||||
}
|
||||
|
||||
// Build relay config
|
||||
var relayConfig *nbconfig.Relay
|
||||
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
|
||||
var ttl time.Duration
|
||||
if mgmt.Relays.CredentialsTTL != "" {
|
||||
var err error
|
||||
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err)
|
||||
}
|
||||
}
|
||||
relayConfig = &nbconfig.Relay{
|
||||
Addresses: mgmt.Relays.Addresses,
|
||||
CredentialsTTL: util.Duration{Duration: ttl},
|
||||
Secret: mgmt.Relays.Secret,
|
||||
}
|
||||
}
|
||||
|
||||
// Build signal config
|
||||
var signalConfig *nbconfig.Host
|
||||
if mgmt.SignalURI != "" {
|
||||
signalConfig = &nbconfig.Host{
|
||||
URI: stripSignalProtocol(mgmt.SignalURI),
|
||||
Proto: parseSignalProtocol(mgmt.SignalURI),
|
||||
}
|
||||
}
|
||||
|
||||
// Build store config
|
||||
storeConfig := nbconfig.StoreConfig{
|
||||
Engine: types.Engine(mgmt.Store.Engine),
|
||||
}
|
||||
|
||||
// Build reverse proxy config
|
||||
reverseProxy := nbconfig.ReverseProxy{
|
||||
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
|
||||
}
|
||||
for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies {
|
||||
if prefix, err := netip.ParsePrefix(p); err == nil {
|
||||
reverseProxy.TrustedHTTPProxies = append(reverseProxy.TrustedHTTPProxies, prefix)
|
||||
}
|
||||
}
|
||||
for _, p := range mgmt.ReverseProxy.TrustedPeers {
|
||||
if prefix, err := netip.ParsePrefix(p); err == nil {
|
||||
reverseProxy.TrustedPeers = append(reverseProxy.TrustedPeers, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
// Build HTTP config (required, even if empty)
|
||||
httpConfig := &nbconfig.HttpServerConfig{}
|
||||
|
||||
// Build embedded IDP config (always enabled in combined server)
|
||||
storageFile := mgmt.Auth.Storage.File
|
||||
if storageFile == "" {
|
||||
storageFile = path.Join(mgmt.DataDir, "idp.db")
|
||||
}
|
||||
|
||||
embeddedIdP := &idp.EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: mgmt.Auth.Issuer,
|
||||
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
|
||||
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
|
||||
Storage: idp.EmbeddedStorageConfig{
|
||||
Type: mgmt.Auth.Storage.Type,
|
||||
Config: idp.EmbeddedStorageTypeConfig{
|
||||
File: storageFile,
|
||||
},
|
||||
},
|
||||
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
|
||||
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
|
||||
}
|
||||
|
||||
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
|
||||
embeddedIdP.Owner = &idp.OwnerConfig{
|
||||
Email: mgmt.Auth.Owner.Email,
|
||||
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
|
||||
}
|
||||
}
|
||||
|
||||
// Set HTTP config fields for embedded IDP
|
||||
httpConfig.AuthIssuer = mgmt.Auth.Issuer
|
||||
httpConfig.AuthAudience = "netbird-dashboard"
|
||||
httpConfig.AuthClientID = httpConfig.AuthAudience
|
||||
httpConfig.CLIAuthAudience = "netbird-cli"
|
||||
httpConfig.AuthUserIDClaim = "sub"
|
||||
httpConfig.AuthKeysLocation = mgmt.Auth.Issuer + "/keys"
|
||||
httpConfig.OIDCConfigEndpoint = mgmt.Auth.Issuer + "/.well-known/openid-configuration"
|
||||
httpConfig.IdpSignKeyRefreshEnabled = mgmt.Auth.SignKeyRefreshEnabled
|
||||
callbackURL := strings.TrimSuffix(httpConfig.AuthIssuer, "/oauth2")
|
||||
httpConfig.AuthCallbackURL = callbackURL + types.ProxyCallbackEndpointFull
|
||||
|
||||
return &nbconfig.Config{
|
||||
Stuns: stuns,
|
||||
Relay: relayConfig,
|
||||
Signal: signalConfig,
|
||||
Datadir: mgmt.DataDir,
|
||||
DataStoreEncryptionKey: mgmt.Store.EncryptionKey,
|
||||
HttpConfig: httpConfig,
|
||||
StoreConfig: storeConfig,
|
||||
ReverseProxy: reverseProxy,
|
||||
DisableDefaultPolicy: mgmt.DisableDefaultPolicy,
|
||||
EmbeddedIdP: embeddedIdP,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ApplyEmbeddedIdPConfig applies embedded IdP configuration to the management config.
|
||||
// This mirrors the logic in management/cmd/management.go ApplyEmbeddedIdPConfig.
|
||||
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config, mgmtPort int, disableSingleAccMode bool) error {
|
||||
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Embedded IdP requires single account mode
|
||||
if disableSingleAccMode {
|
||||
return fmt.Errorf("embedded IdP requires single account mode; multiple account mode is not supported with embedded IdP")
|
||||
}
|
||||
|
||||
// Set LocalAddress for embedded IdP, used for internal JWT validation
|
||||
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
|
||||
|
||||
// Set storage defaults based on Datadir
|
||||
if cfg.EmbeddedIdP.Storage.Type == "" {
|
||||
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
|
||||
}
|
||||
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
|
||||
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
|
||||
}
|
||||
|
||||
issuer := cfg.EmbeddedIdP.Issuer
|
||||
|
||||
// Ensure HttpConfig exists
|
||||
if cfg.HttpConfig == nil {
|
||||
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
|
||||
}
|
||||
|
||||
// Set HttpConfig values from EmbeddedIdP
|
||||
cfg.HttpConfig.AuthIssuer = issuer
|
||||
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
||||
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnsureEncryptionKey generates an encryption key if not set.
|
||||
// Unlike management server, we don't write back to the config file.
|
||||
func EnsureEncryptionKey(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
if cfg.DataStoreEncryptionKey != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("DataStoreEncryptionKey is not set, generating a new key")
|
||||
key, err := crypt.GenerateKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate datastore encryption key: %v", err)
|
||||
}
|
||||
cfg.DataStoreEncryptionKey = key
|
||||
keyPreview := key[:8] + "..."
|
||||
log.WithContext(ctx).Warnf("DataStoreEncryptionKey generated (%s); add it to your config file under 'server.store.encryptionKey' to persist across restarts", keyPreview)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogConfigInfo logs informational messages about the loaded configuration
|
||||
func LogConfigInfo(cfg *nbconfig.Config) {
|
||||
if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
|
||||
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
|
||||
}
|
||||
if cfg.Relay != nil {
|
||||
log.Infof("Relay addresses: %v", cfg.Relay.Addresses)
|
||||
}
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
//go:build pprof
|
||||
// +build pprof
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func init() {
|
||||
addr := pprofAddr()
|
||||
go pprof(addr)
|
||||
}
|
||||
|
||||
func pprofAddr() string {
|
||||
listenAddr := os.Getenv("NB_PPROF_ADDR")
|
||||
if listenAddr == "" {
|
||||
return "localhost:6969"
|
||||
}
|
||||
|
||||
return listenAddr
|
||||
}
|
||||
|
||||
func pprof(listenAddr string) {
|
||||
log.Infof("listening pprof on: %s\n", listenAddr)
|
||||
if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||
log.Fatalf("Failed to start pprof: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1,715 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
mgmtServer "github.com/netbirdio/netbird/management/internals/server"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||
relayServer "github.com/netbirdio/netbird/relay/server"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
|
||||
"github.com/netbirdio/netbird/shared/relay/auth"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/stun"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
|
||||
)
|
||||
|
||||
var (
|
||||
configPath string
|
||||
config *CombinedConfig
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "combined",
|
||||
Short: "Combined Netbird server (Management + Signal + Relay + STUN)",
|
||||
Long: `Combined Netbird server for self-hosted deployments.
|
||||
|
||||
All services (Management, Signal, Relay) are multiplexed on a single port.
|
||||
Optional STUN server runs on separate UDP ports.
|
||||
|
||||
Configuration is loaded from a YAML file specified with --config.`,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
RunE: execute,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
|
||||
_ = rootCmd.MarkPersistentFlagRequired("config")
|
||||
|
||||
rootCmd.AddCommand(newTokenCommands())
|
||||
}
|
||||
|
||||
func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
func waitForExitSignal() {
|
||||
osSigs := make(chan os.Signal, 1)
|
||||
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-osSigs
|
||||
}
|
||||
|
||||
func execute(cmd *cobra.Command, _ []string) error {
|
||||
if err := initializeConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Management is required as the base server when signal or relay are enabled
|
||||
if (config.Signal.Enabled || config.Relay.Enabled) && !config.Management.Enabled {
|
||||
return fmt.Errorf("management must be enabled when signal or relay are enabled (provides the base HTTP server)")
|
||||
}
|
||||
|
||||
servers, err := createAllServers(cmd.Context(), config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register services with management's gRPC server using AfterInit hook
|
||||
setupServerHooks(servers, config)
|
||||
|
||||
// Start management server (this also starts the HTTP listener)
|
||||
if servers.mgmtSrv != nil {
|
||||
if err := servers.mgmtSrv.Start(cmd.Context()); err != nil {
|
||||
cleanupSTUNListeners(servers.stunListeners)
|
||||
return fmt.Errorf("failed to start management server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start all other servers
|
||||
wg := sync.WaitGroup{}
|
||||
startServers(&wg, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.metricsServer)
|
||||
|
||||
waitForExitSignal()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = shutdownServers(ctx, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.mgmtSrv, servers.metricsServer)
|
||||
wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
// initializeConfig loads and validates the configuration, then initializes logging.
|
||||
func initializeConfig() error {
|
||||
var err error
|
||||
config, err = LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
if err := config.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
if err := util.InitLog(config.Server.LogLevel, config.Server.LogFile); err != nil {
|
||||
return fmt.Errorf("failed to initialize log: %w", err)
|
||||
}
|
||||
|
||||
if dsn := config.Server.Store.DSN; dsn != "" {
|
||||
switch strings.ToLower(config.Server.Store.Engine) {
|
||||
case "postgres":
|
||||
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||
case "mysql":
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("Starting combined NetBird server")
|
||||
logConfig(config)
|
||||
logEnvVars()
|
||||
return nil
|
||||
}
|
||||
|
||||
// serverInstances holds all server instances created during startup.
|
||||
type serverInstances struct {
|
||||
relaySrv *relayServer.Server
|
||||
mgmtSrv *mgmtServer.BaseServer
|
||||
signalSrv *signalServer.Server
|
||||
healthcheck *healthcheck.Server
|
||||
stunServer *stun.Server
|
||||
stunListeners []*net.UDPConn
|
||||
metricsServer *sharedMetrics.Metrics
|
||||
}
|
||||
|
||||
// createAllServers creates all server instances based on configuration.
|
||||
func createAllServers(ctx context.Context, cfg *CombinedConfig) (*serverInstances, error) {
|
||||
metricsServer, err := sharedMetrics.NewServer(cfg.Server.MetricsPort, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create metrics server: %w", err)
|
||||
}
|
||||
servers := &serverInstances{
|
||||
metricsServer: metricsServer,
|
||||
}
|
||||
|
||||
_, tlsSupport, err := handleTLSConfig(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to setup TLS config: %w", err)
|
||||
}
|
||||
|
||||
if err := servers.createRelayServer(cfg, tlsSupport); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := servers.createManagementServer(ctx, cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := servers.createSignalServer(ctx, cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := servers.createHealthcheckServer(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
func (s *serverInstances) createRelayServer(cfg *CombinedConfig, tlsSupport bool) error {
|
||||
if !cfg.Relay.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.stunListeners, err = createSTUNListeners(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hashedSecret := sha256.Sum256([]byte(cfg.Relay.AuthSecret))
|
||||
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
|
||||
|
||||
relayCfg := relayServer.Config{
|
||||
Meter: s.metricsServer.Meter,
|
||||
ExposedAddress: cfg.Relay.ExposedAddress,
|
||||
AuthValidator: authenticator,
|
||||
TLSSupport: tlsSupport,
|
||||
}
|
||||
|
||||
s.relaySrv, err = createRelayServer(relayCfg, s.stunListeners)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Relay server created")
|
||||
|
||||
if len(s.stunListeners) > 0 {
|
||||
s.stunServer = stun.NewServer(s.stunListeners, cfg.Relay.Stun.LogLevel)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serverInstances) createManagementServer(ctx context.Context, cfg *CombinedConfig) error {
|
||||
if !cfg.Management.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
mgmtConfig, err := cfg.ToManagementConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create management config: %w", err)
|
||||
}
|
||||
|
||||
_, portStr, portErr := net.SplitHostPort(cfg.Server.ListenAddress)
|
||||
if portErr != nil {
|
||||
portStr = "443"
|
||||
}
|
||||
mgmtPort, _ := strconv.Atoi(portStr)
|
||||
|
||||
if err := ApplyEmbeddedIdPConfig(ctx, mgmtConfig, mgmtPort, false); err != nil {
|
||||
cleanupSTUNListeners(s.stunListeners)
|
||||
return fmt.Errorf("failed to apply embedded IdP config: %w", err)
|
||||
}
|
||||
|
||||
if err := EnsureEncryptionKey(ctx, mgmtConfig); err != nil {
|
||||
cleanupSTUNListeners(s.stunListeners)
|
||||
return fmt.Errorf("failed to ensure encryption key: %w", err)
|
||||
}
|
||||
|
||||
LogConfigInfo(mgmtConfig)
|
||||
|
||||
s.mgmtSrv, err = createManagementServer(cfg, mgmtConfig)
|
||||
if err != nil {
|
||||
cleanupSTUNListeners(s.stunListeners)
|
||||
return fmt.Errorf("failed to create management server: %w", err)
|
||||
}
|
||||
|
||||
// Inject externally-managed AppMetrics so management uses the shared metrics server
|
||||
appMetrics, err := telemetry.NewAppMetricsWithMeter(ctx, s.metricsServer.Meter)
|
||||
if err != nil {
|
||||
cleanupSTUNListeners(s.stunListeners)
|
||||
return fmt.Errorf("failed to create management app metrics: %w", err)
|
||||
}
|
||||
mgmtServer.Inject[telemetry.AppMetrics](s.mgmtSrv, appMetrics)
|
||||
|
||||
log.Infof("Management server created")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serverInstances) createSignalServer(ctx context.Context, cfg *CombinedConfig) error {
|
||||
if !cfg.Signal.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.signalSrv, err = signalServer.NewServer(ctx, s.metricsServer.Meter, "signal_")
|
||||
if err != nil {
|
||||
cleanupSTUNListeners(s.stunListeners)
|
||||
return fmt.Errorf("failed to create signal server: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("Signal server created")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serverInstances) createHealthcheckServer(cfg *CombinedConfig) error {
|
||||
hCfg := healthcheck.Config{
|
||||
ListenAddress: cfg.Server.HealthcheckAddress,
|
||||
ServiceChecker: s.relaySrv,
|
||||
}
|
||||
|
||||
var err error
|
||||
s.healthcheck, err = createHealthCheck(hCfg, s.stunListeners)
|
||||
return err
|
||||
}
|
||||
|
||||
// setupServerHooks registers services with management's gRPC server.
|
||||
func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
|
||||
if servers.mgmtSrv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) {
|
||||
grpcSrv := s.GRPCServer()
|
||||
|
||||
if servers.signalSrv != nil {
|
||||
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
|
||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||
}
|
||||
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
if servers.relaySrv != nil {
|
||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) {
|
||||
if srv != nil {
|
||||
instanceURL := srv.InstanceURL()
|
||||
log.Infof("Relay server instance URL: %s", instanceURL.String())
|
||||
log.Infof("Relay WebSocket multiplexed on management port (no separate relay listener)")
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
||||
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("failed to start metrics server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("failed to start healthcheck server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if stunServer != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := stunServer.Listen(); err != nil {
|
||||
if errors.Is(err, stun.ErrServerClosed) {
|
||||
return
|
||||
}
|
||||
log.Errorf("STUN server error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error {
|
||||
var errs error
|
||||
|
||||
if err := httpHealthcheck.Shutdown(ctx); err != nil {
|
||||
errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
|
||||
}
|
||||
|
||||
if stunServer != nil {
|
||||
if err := stunServer.Shutdown(); err != nil {
|
||||
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if srv != nil {
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if mgmtSrv != nil {
|
||||
log.Infof("shutting down management and signal servers")
|
||||
if err := mgmtSrv.Stop(); err != nil {
|
||||
errs = multierror.Append(errs, fmt.Errorf("failed to close management server: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if metricsServer != nil {
|
||||
log.Infof("shutting down metrics server")
|
||||
if err := metricsServer.Shutdown(ctx); err != nil {
|
||||
errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
func createHealthCheck(hCfg healthcheck.Config, stunListeners []*net.UDPConn) (*healthcheck.Server, error) {
|
||||
httpHealthcheck, err := healthcheck.NewServer(hCfg)
|
||||
if err != nil {
|
||||
cleanupSTUNListeners(stunListeners)
|
||||
return nil, fmt.Errorf("failed to create healthcheck server: %w", err)
|
||||
}
|
||||
return httpHealthcheck, nil
|
||||
}
|
||||
|
||||
func createRelayServer(cfg relayServer.Config, stunListeners []*net.UDPConn) (*relayServer.Server, error) {
|
||||
srv, err := relayServer.NewServer(cfg)
|
||||
if err != nil {
|
||||
cleanupSTUNListeners(stunListeners)
|
||||
return nil, fmt.Errorf("failed to create relay server: %w", err)
|
||||
}
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
|
||||
for _, l := range stunListeners {
|
||||
_ = l.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func createSTUNListeners(cfg *CombinedConfig) ([]*net.UDPConn, error) {
|
||||
var stunListeners []*net.UDPConn
|
||||
if cfg.Relay.Stun.Enabled {
|
||||
for _, port := range cfg.Relay.Stun.Ports {
|
||||
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
|
||||
if err != nil {
|
||||
cleanupSTUNListeners(stunListeners)
|
||||
return nil, fmt.Errorf("failed to create STUN listener on port %d: %w", port, err)
|
||||
}
|
||||
stunListeners = append(stunListeners, listener)
|
||||
log.Infof("STUN server listening on UDP port %d", port)
|
||||
}
|
||||
}
|
||||
return stunListeners, nil
|
||||
}
|
||||
|
||||
func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
|
||||
tlsCfg := cfg.Server.TLS
|
||||
|
||||
if tlsCfg.LetsEncrypt.AWSRoute53 {
|
||||
log.Debugf("using Let's Encrypt DNS resolver with Route 53 support")
|
||||
r53 := encryption.Route53TLS{
|
||||
DataDir: tlsCfg.LetsEncrypt.DataDir,
|
||||
Email: tlsCfg.LetsEncrypt.Email,
|
||||
Domains: tlsCfg.LetsEncrypt.Domains,
|
||||
}
|
||||
tc, err := r53.GetCertificate()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return tc, true, nil
|
||||
}
|
||||
|
||||
if cfg.HasLetsEncrypt() {
|
||||
log.Infof("setting up TLS with Let's Encrypt")
|
||||
certManager, err := encryption.CreateCertManager(tlsCfg.LetsEncrypt.DataDir, tlsCfg.LetsEncrypt.Domains...)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed creating LetsEncrypt cert manager: %w", err)
|
||||
}
|
||||
return certManager.TLSConfig(), true, nil
|
||||
}
|
||||
|
||||
if cfg.HasTLSCert() {
|
||||
log.Debugf("using file based TLS config")
|
||||
tc, err := encryption.LoadTLSConfig(tlsCfg.CertFile, tlsCfg.KeyFile)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return tc, true, nil
|
||||
}
|
||||
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
|
||||
mgmt := cfg.Management
|
||||
|
||||
dnsDomain := mgmt.DnsDomain
|
||||
singleAccModeDomain := dnsDomain
|
||||
|
||||
// Extract port from listen address
|
||||
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
|
||||
if err != nil {
|
||||
// If no port specified, assume default
|
||||
portStr = "443"
|
||||
}
|
||||
mgmtPort, _ := strconv.Atoi(portStr)
|
||||
|
||||
mgmtSrv := mgmtServer.NewServer(
|
||||
&mgmtServer.Config{
|
||||
NbConfig: mgmtConfig,
|
||||
DNSDomain: dnsDomain,
|
||||
MgmtSingleAccModeDomain: singleAccModeDomain,
|
||||
MgmtPort: mgmtPort,
|
||||
MgmtMetricsPort: cfg.Server.MetricsPort,
|
||||
DisableMetrics: mgmt.DisableAnonymousMetrics,
|
||||
DisableGeoliteUpdate: mgmt.DisableGeoliteUpdate,
|
||||
// Always enable user deletion from IDP in combined server (embedded IdP is always enabled)
|
||||
UserDeleteFromIDPEnabled: true,
|
||||
},
|
||||
)
|
||||
|
||||
return mgmtSrv, nil
|
||||
}
|
||||
|
||||
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
|
||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
||||
|
||||
var relayAcceptFn func(conn net.Conn)
|
||||
if relaySrv != nil {
|
||||
relayAcceptFn = relaySrv.RelayAccept()
|
||||
}
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
// Native gRPC traffic (HTTP/2 with gRPC content-type)
|
||||
case r.ProtoMajor == 2 && (strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") ||
|
||||
strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc+proto")):
|
||||
grpcServer.ServeHTTP(w, r)
|
||||
|
||||
// WebSocket proxy for Management gRPC
|
||||
case r.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
|
||||
wsProxy.Handler().ServeHTTP(w, r)
|
||||
|
||||
// WebSocket proxy for Signal gRPC
|
||||
case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent:
|
||||
if cfg.Signal.Enabled {
|
||||
wsProxy.Handler().ServeHTTP(w, r)
|
||||
} else {
|
||||
http.Error(w, "Signal service not enabled", http.StatusNotFound)
|
||||
}
|
||||
|
||||
// Relay WebSocket
|
||||
case r.URL.Path == "/relay":
|
||||
if relayAcceptFn != nil {
|
||||
handleRelayWebSocket(w, r, relayAcceptFn, cfg)
|
||||
} else {
|
||||
http.Error(w, "Relay service not enabled", http.StatusNotFound)
|
||||
}
|
||||
|
||||
// Management HTTP API (default)
|
||||
default:
|
||||
httpHandler.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// handleRelayWebSocket handles incoming WebSocket connections for the relay service
|
||||
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) {
|
||||
acceptOptions := &websocket.AcceptOptions{
|
||||
OriginPatterns: []string{"*"},
|
||||
}
|
||||
|
||||
wsConn, err := websocket.Accept(w, r, acceptOptions)
|
||||
if err != nil {
|
||||
log.Errorf("failed to accept relay ws connection: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
connRemoteAddr := r.RemoteAddr
|
||||
if r.Header.Get("X-Real-Ip") != "" && r.Header.Get("X-Real-Port") != "" {
|
||||
connRemoteAddr = net.JoinHostPort(r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port"))
|
||||
}
|
||||
|
||||
rAddr, err := net.ResolveTCPAddr("tcp", connRemoteAddr)
|
||||
if err != nil {
|
||||
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
|
||||
if err != nil {
|
||||
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Relay WS client connected from: %s", rAddr)
|
||||
|
||||
conn := ws.NewConn(wsConn, lAddr, rAddr)
|
||||
acceptFn(conn)
|
||||
}
|
||||
|
||||
// logConfig prints all configuration parameters for debugging
|
||||
func logConfig(cfg *CombinedConfig) {
|
||||
log.Info("=== Configuration ===")
|
||||
logServerConfig(cfg)
|
||||
logComponentsConfig(cfg)
|
||||
logRelayConfig(cfg)
|
||||
logManagementConfig(cfg)
|
||||
log.Info("=== End Configuration ===")
|
||||
}
|
||||
|
||||
func logServerConfig(cfg *CombinedConfig) {
|
||||
log.Info("--- Server ---")
|
||||
log.Infof(" Listen address: %s", cfg.Server.ListenAddress)
|
||||
log.Infof(" Exposed address: %s", cfg.Server.ExposedAddress)
|
||||
log.Infof(" Healthcheck address: %s", cfg.Server.HealthcheckAddress)
|
||||
log.Infof(" Metrics port: %d", cfg.Server.MetricsPort)
|
||||
log.Infof(" Log level: %s", cfg.Server.LogLevel)
|
||||
log.Infof(" Data dir: %s", cfg.Server.DataDir)
|
||||
|
||||
switch {
|
||||
case cfg.HasTLSCert():
|
||||
log.Infof(" TLS: cert=%s, key=%s", cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile)
|
||||
case cfg.HasLetsEncrypt():
|
||||
log.Infof(" TLS: Let's Encrypt (domains=%v)", cfg.Server.TLS.LetsEncrypt.Domains)
|
||||
default:
|
||||
log.Info(" TLS: disabled (using reverse proxy)")
|
||||
}
|
||||
}
|
||||
|
||||
func logComponentsConfig(cfg *CombinedConfig) {
|
||||
log.Info("--- Components ---")
|
||||
log.Infof(" Management: %v (log level: %s)", cfg.Management.Enabled, cfg.Management.LogLevel)
|
||||
log.Infof(" Signal: %v (log level: %s)", cfg.Signal.Enabled, cfg.Signal.LogLevel)
|
||||
log.Infof(" Relay: %v (log level: %s)", cfg.Relay.Enabled, cfg.Relay.LogLevel)
|
||||
}
|
||||
|
||||
func logRelayConfig(cfg *CombinedConfig) {
|
||||
if !cfg.Relay.Enabled {
|
||||
return
|
||||
}
|
||||
log.Info("--- Relay ---")
|
||||
log.Infof(" Exposed address: %s", cfg.Relay.ExposedAddress)
|
||||
log.Infof(" Auth secret: %s...", maskSecret(cfg.Relay.AuthSecret))
|
||||
if cfg.Relay.Stun.Enabled {
|
||||
log.Infof(" STUN ports: %v (log level: %s)", cfg.Relay.Stun.Ports, cfg.Relay.Stun.LogLevel)
|
||||
} else {
|
||||
log.Info(" STUN: disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func logManagementConfig(cfg *CombinedConfig) {
|
||||
if !cfg.Management.Enabled {
|
||||
return
|
||||
}
|
||||
log.Info("--- Management ---")
|
||||
log.Infof(" Data dir: %s", cfg.Management.DataDir)
|
||||
log.Infof(" DNS domain: %s", cfg.Management.DnsDomain)
|
||||
log.Infof(" Store engine: %s", cfg.Management.Store.Engine)
|
||||
if cfg.Server.Store.DSN != "" {
|
||||
log.Infof(" Store DSN: %s", maskDSNPassword(cfg.Server.Store.DSN))
|
||||
}
|
||||
|
||||
log.Info(" Auth (embedded IdP):")
|
||||
log.Infof(" Issuer: %s", cfg.Management.Auth.Issuer)
|
||||
log.Infof(" Dashboard redirect URIs: %v", cfg.Management.Auth.DashboardRedirectURIs)
|
||||
log.Infof(" CLI redirect URIs: %v", cfg.Management.Auth.CLIRedirectURIs)
|
||||
|
||||
log.Info(" Client settings:")
|
||||
log.Infof(" Signal URI: %s", cfg.Management.SignalURI)
|
||||
for _, s := range cfg.Management.Stuns {
|
||||
log.Infof(" STUN: %s", s.URI)
|
||||
}
|
||||
if len(cfg.Management.Relays.Addresses) > 0 {
|
||||
log.Infof(" Relay addresses: %v", cfg.Management.Relays.Addresses)
|
||||
log.Infof(" Relay credentials TTL: %s", cfg.Management.Relays.CredentialsTTL)
|
||||
}
|
||||
}
|
||||
|
||||
// logEnvVars logs all NB_ environment variables that are currently set
|
||||
func logEnvVars() {
|
||||
log.Info("=== Environment Variables ===")
|
||||
found := false
|
||||
for _, env := range os.Environ() {
|
||||
if strings.HasPrefix(env, "NB_") {
|
||||
key, _, _ := strings.Cut(env, "=")
|
||||
value := os.Getenv(key)
|
||||
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") {
|
||||
value = maskSecret(value)
|
||||
}
|
||||
log.Infof(" %s=%s", key, value)
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
log.Info(" (none set)")
|
||||
}
|
||||
log.Info("=== End Environment Variables ===")
|
||||
}
|
||||
|
||||
// maskDSNPassword masks the password in a DSN string.
|
||||
// Handles both key=value format ("password=secret") and URI format ("user:secret@host").
|
||||
func maskDSNPassword(dsn string) string {
|
||||
// Key=value format: "host=localhost user=nb password=secret dbname=nb"
|
||||
if strings.Contains(dsn, "password=") {
|
||||
parts := strings.Fields(dsn)
|
||||
for i, p := range parts {
|
||||
if strings.HasPrefix(p, "password=") {
|
||||
parts[i] = "password=****"
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// URI format: "user:password@host..."
|
||||
if atIdx := strings.Index(dsn, "@"); atIdx != -1 {
|
||||
prefix := dsn[:atIdx]
|
||||
if colonIdx := strings.Index(prefix, ":"); colonIdx != -1 {
|
||||
return prefix[:colonIdx+1] + "****" + dsn[atIdx:]
|
||||
}
|
||||
}
|
||||
|
||||
return dsn
|
||||
}
|
||||
|
||||
// maskSecret returns first 4 chars of secret followed by "..."
|
||||
func maskSecret(secret string) string {
|
||||
if len(secret) <= 4 {
|
||||
return "****"
|
||||
}
|
||||
return secret[:4] + "..."
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// newTokenCommands creates the token command tree with combined-specific store opener.
|
||||
func newTokenCommands() *cobra.Command {
|
||||
return tokencmd.NewCommands(withTokenStore)
|
||||
}
|
||||
|
||||
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
|
||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
if dsn := cfg.Server.Store.DSN; dsn != "" {
|
||||
switch strings.ToLower(cfg.Server.Store.Engine) {
|
||||
case "postgres":
|
||||
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||
case "mysql":
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
datadir := cfg.Management.DataDir
|
||||
engine := types.Engine(cfg.Management.Store.Engine)
|
||||
|
||||
s, err := store.NewStore(ctx, engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, s)
|
||||
}
|
||||
@@ -1,111 +0,0 @@
|
||||
# NetBird Combined Server Configuration
|
||||
# Copy this file to config.yaml and customize for your deployment
|
||||
#
|
||||
# This is a Management server with optional embedded Signal, Relay, and STUN services.
|
||||
# By default, all services run locally. You can use external services instead by
|
||||
# setting the corresponding override fields.
|
||||
#
|
||||
# Architecture:
|
||||
# - Management: Always runs locally (this IS the management server)
|
||||
# - Signal: Local by default; set 'signalUri' to use external (disables local)
|
||||
# - Relay: Local by default; set 'relays' to use external (disables local)
|
||||
# - STUN: Local on port 3478 by default; set 'stuns' to use external instead
|
||||
|
||||
server:
|
||||
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
|
||||
listenAddress: ":443"
|
||||
|
||||
# Public address that peers will use to connect to this server
|
||||
# Used for relay connections and management DNS domain
|
||||
# Format: protocol://hostname:port (e.g., https://server.mycompany.com:443)
|
||||
exposedAddress: "https://server.mycompany.com:443"
|
||||
|
||||
# STUN server ports (defaults to [3478] if not specified; set 'stuns' to use external)
|
||||
# stunPorts:
|
||||
# - 3478
|
||||
|
||||
# Metrics endpoint port
|
||||
metricsPort: 9090
|
||||
|
||||
# Healthcheck endpoint address
|
||||
healthcheckAddress: ":9000"
|
||||
|
||||
# Logging configuration
|
||||
logLevel: "info" # Default log level for all components: panic, fatal, error, warn, info, debug, trace
|
||||
logFile: "console" # "console" or path to log file
|
||||
|
||||
# TLS configuration (optional)
|
||||
tls:
|
||||
certFile: ""
|
||||
keyFile: ""
|
||||
letsencrypt:
|
||||
enabled: false
|
||||
dataDir: ""
|
||||
domains: []
|
||||
email: ""
|
||||
awsRoute53: false
|
||||
|
||||
# Shared secret for relay authentication (required when running local relay)
|
||||
authSecret: "your-secret-key-here"
|
||||
|
||||
# Data directory for all services
|
||||
dataDir: "/var/lib/netbird/"
|
||||
|
||||
# ============================================================================
|
||||
# External Service Overrides (optional)
|
||||
# Use these to point to external Signal, Relay, or STUN servers instead of
|
||||
# running them locally. When set, the corresponding local service is disabled.
|
||||
# ============================================================================
|
||||
|
||||
# External STUN servers - disables local STUN server
|
||||
# stuns:
|
||||
# - uri: "stun:stun.example.com:3478"
|
||||
# - uri: "stun:stun.example.com:3479"
|
||||
|
||||
# External relay servers - disables local relay server
|
||||
# relays:
|
||||
# addresses:
|
||||
# - "rels://relay.example.com:443"
|
||||
# credentialsTTL: "12h"
|
||||
# secret: "relay-shared-secret"
|
||||
|
||||
# External signal server - disables local signal server
|
||||
# signalUri: "https://signal.example.com:443"
|
||||
|
||||
# ============================================================================
|
||||
# Management Settings
|
||||
# ============================================================================
|
||||
|
||||
# Metrics and updates
|
||||
disableAnonymousMetrics: false
|
||||
disableGeoliteUpdate: false
|
||||
|
||||
# Embedded authentication/identity provider (Dex) configuration (always enabled)
|
||||
auth:
|
||||
# OIDC issuer URL - must be publicly accessible
|
||||
issuer: "https://server.mycompany.com/oauth2"
|
||||
localAuthDisabled: false
|
||||
signKeyRefreshEnabled: false
|
||||
# OAuth2 redirect URIs for dashboard
|
||||
dashboardRedirectURIs:
|
||||
- "https://app.netbird.io/nb-auth"
|
||||
- "https://app.netbird.io/nb-silent-auth"
|
||||
# OAuth2 redirect URIs for CLI
|
||||
cliRedirectURIs:
|
||||
- "http://localhost:53000/"
|
||||
# Optional initial admin user
|
||||
# owner:
|
||||
# email: "admin@example.com"
|
||||
# password: "initial-password"
|
||||
|
||||
# Store configuration
|
||||
store:
|
||||
engine: "sqlite" # sqlite, postgres, or mysql
|
||||
dsn: "" # Connection string for postgres or mysql
|
||||
encryptionKey: ""
|
||||
|
||||
# Reverse proxy settings (optional)
|
||||
# reverseProxy:
|
||||
# trustedHTTPProxies: []
|
||||
# trustedHTTPProxiesCount: 0
|
||||
# trustedPeers: []
|
||||
@@ -1,115 +0,0 @@
|
||||
# Simplified Combined NetBird Server Configuration
|
||||
# Copy this file to config.yaml and customize for your deployment
|
||||
|
||||
# Server-wide settings
|
||||
server:
|
||||
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
|
||||
listenAddress: ":443"
|
||||
|
||||
# Metrics endpoint port
|
||||
metricsPort: 9090
|
||||
|
||||
# Healthcheck endpoint address
|
||||
healthcheckAddress: ":9000"
|
||||
|
||||
# Logging configuration
|
||||
logLevel: "info" # panic, fatal, error, warn, info, debug, trace
|
||||
logFile: "console" # "console" or path to log file
|
||||
|
||||
# TLS configuration (optional)
|
||||
tls:
|
||||
certFile: ""
|
||||
keyFile: ""
|
||||
letsencrypt:
|
||||
enabled: false
|
||||
dataDir: ""
|
||||
domains: []
|
||||
email: ""
|
||||
awsRoute53: false
|
||||
|
||||
# Relay service configuration
|
||||
relay:
|
||||
# Enable/disable the relay service
|
||||
enabled: true
|
||||
|
||||
# Public address that peers will use to connect to this relay
|
||||
# Format: hostname:port or ip:port
|
||||
exposedAddress: "relay.example.com:443"
|
||||
|
||||
# Shared secret for relay authentication (required when enabled)
|
||||
authSecret: "your-secret-key-here"
|
||||
|
||||
# Log level for relay (reserved for future use, currently uses global log level)
|
||||
logLevel: "info"
|
||||
|
||||
# Embedded STUN server (optional)
|
||||
stun:
|
||||
enabled: false
|
||||
ports: [3478]
|
||||
logLevel: "info"
|
||||
|
||||
# Signal service configuration
|
||||
signal:
|
||||
# Enable/disable the signal service
|
||||
enabled: true
|
||||
|
||||
# Log level for signal (reserved for future use, currently uses global log level)
|
||||
logLevel: "info"
|
||||
|
||||
# Management service configuration
|
||||
management:
|
||||
# Enable/disable the management service
|
||||
enabled: true
|
||||
|
||||
# Data directory for management service
|
||||
dataDir: "/var/lib/netbird/"
|
||||
|
||||
# DNS domain for the management server
|
||||
dnsDomain: ""
|
||||
|
||||
# Metrics and updates
|
||||
disableAnonymousMetrics: false
|
||||
disableGeoliteUpdate: false
|
||||
|
||||
auth:
|
||||
# OIDC issuer URL - must be publicly accessible
|
||||
issuer: "https://management.example.com/oauth2"
|
||||
localAuthDisabled: false
|
||||
signKeyRefreshEnabled: false
|
||||
# OAuth2 redirect URIs for dashboard
|
||||
dashboardRedirectURIs:
|
||||
- "https://app.example.com/nb-auth"
|
||||
- "https://app.example.com/nb-silent-auth"
|
||||
# OAuth2 redirect URIs for CLI
|
||||
cliRedirectURIs:
|
||||
- "http://localhost:53000/"
|
||||
# Optional initial admin user
|
||||
# owner:
|
||||
# email: "admin@example.com"
|
||||
# password: "initial-password"
|
||||
|
||||
# External STUN servers (for client config)
|
||||
stuns: []
|
||||
# - uri: "stun:stun.example.com:3478"
|
||||
|
||||
# External relay servers (for client config)
|
||||
relays:
|
||||
addresses: []
|
||||
# - "rels://relay.example.com:443"
|
||||
credentialsTTL: "12h"
|
||||
secret: ""
|
||||
|
||||
# External signal server URI (for client config)
|
||||
signalUri: ""
|
||||
|
||||
# Store configuration
|
||||
store:
|
||||
engine: "sqlite" # sqlite, postgres, or mysql
|
||||
dsn: "" # Connection string for postgres or mysql
|
||||
encryptionKey: ""
|
||||
|
||||
# Reverse proxy settings
|
||||
reverseProxy:
|
||||
trustedHTTPProxies: []
|
||||
trustedHTTPProxiesCount: 0
|
||||
trustedPeers: []
|
||||
@@ -1,13 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/combined/cmd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := cmd.Execute(); err != nil {
|
||||
log.Fatalf("failed to execute command: %v", err)
|
||||
}
|
||||
}
|
||||
3
go.mod
3
go.mod
@@ -40,7 +40,7 @@ require (
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/caddyserver/certmagic v0.21.3
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
github.com/coder/websocket v1.8.14
|
||||
github.com/coder/websocket v1.8.13
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/coreos/go-oidc/v3 v3.14.1
|
||||
github.com/creack/pty v1.1.24
|
||||
@@ -83,7 +83,6 @@ require (
|
||||
github.com/pion/stun/v3 v3.1.0
|
||||
github.com/pion/transport/v3 v3.1.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/pires/go-proxyproto v0.11.0
|
||||
github.com/pkg/sftp v1.13.9
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/quic-go/quic-go v0.55.0
|
||||
|
||||
6
go.sum
6
go.sum
@@ -107,8 +107,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
|
||||
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
|
||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
|
||||
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=
|
||||
github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
@@ -474,8 +474,6 @@ github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
|
||||
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
||||
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
||||
github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8=
|
||||
github.com/pires/go-proxyproto v0.11.0 h1:gUQpS85X/VJMdUsYyEgyn59uLJvGqPhJV5YvG68wXH4=
|
||||
github.com/pires/go-proxyproto v0.11.0/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
||||
@@ -99,16 +99,15 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
|
||||
// Build Dex server config - use Dex's types directly
|
||||
dexConfig := server.Config{
|
||||
Issuer: issuer,
|
||||
Storage: stor,
|
||||
SkipApprovalScreen: true,
|
||||
SupportedResponseTypes: []string{"code"},
|
||||
ContinueOnConnectorFailure: true,
|
||||
Logger: logger,
|
||||
PrometheusRegistry: prometheus.NewRegistry(),
|
||||
RotateKeysAfter: 6 * time.Hour,
|
||||
IDTokensValidFor: 24 * time.Hour,
|
||||
RefreshTokenPolicy: refreshPolicy,
|
||||
Issuer: issuer,
|
||||
Storage: stor,
|
||||
SkipApprovalScreen: true,
|
||||
SupportedResponseTypes: []string{"code"},
|
||||
Logger: logger,
|
||||
PrometheusRegistry: prometheus.NewRegistry(),
|
||||
RotateKeysAfter: 6 * time.Hour,
|
||||
IDTokensValidFor: 24 * time.Hour,
|
||||
RefreshTokenPolicy: refreshPolicy,
|
||||
Web: server.WebConfig{
|
||||
Issuer: "NetBird",
|
||||
},
|
||||
@@ -261,7 +260,6 @@ func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.L
|
||||
if len(cfg.SupportedResponseTypes) == 0 {
|
||||
cfg.SupportedResponseTypes = []string{"code"}
|
||||
}
|
||||
cfg.ContinueOnConnectorFailure = true
|
||||
return cfg
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package dex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -196,64 +195,3 @@ enablePasswordDB: true
|
||||
|
||||
t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID)
|
||||
}
|
||||
|
||||
func TestNewProvider_ContinueOnConnectorFailure(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "dex-connector-failure-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
config := &Config{
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Port: 5556,
|
||||
DataDir: tmpDir,
|
||||
}
|
||||
|
||||
provider, err := NewProvider(ctx, config)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = provider.Stop(ctx) }()
|
||||
|
||||
// The provider should have started successfully even though
|
||||
// ContinueOnConnectorFailure is an internal Dex config field.
|
||||
// We verify the provider is functional by performing a basic operation.
|
||||
assert.NotNil(t, provider.dexServer)
|
||||
assert.NotNil(t, provider.storage)
|
||||
}
|
||||
|
||||
func TestBuildDexConfig_ContinueOnConnectorFailure(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "dex-build-config-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
yamlContent := `
|
||||
issuer: http://localhost:5556/dex
|
||||
storage:
|
||||
type: sqlite3
|
||||
config:
|
||||
file: ` + filepath.Join(tmpDir, "dex.db") + `
|
||||
web:
|
||||
http: 127.0.0.1:5556
|
||||
enablePasswordDB: true
|
||||
`
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
yamlConfig, err := LoadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
stor, err := yamlConfig.Storage.OpenStorage(slog.New(slog.NewTextHandler(os.Stderr, nil)))
|
||||
require.NoError(t, err)
|
||||
defer stor.Close()
|
||||
|
||||
err = initializeStorage(ctx, stor, yamlConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
cfg := buildDexConfig(yamlConfig, stor, logger)
|
||||
|
||||
assert.True(t, cfg.ContinueOnConnectorFailure,
|
||||
"buildDexConfig must set ContinueOnConnectorFailure to true so management starts even if an external IdP is down")
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -29,11 +29,11 @@ import (
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
var newServer = func(cfg *server.Config) server.Server {
|
||||
return server.NewServer(cfg)
|
||||
var newServer = func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server {
|
||||
return server.NewServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
|
||||
}
|
||||
|
||||
func SetNewServer(fn func(*server.Config) server.Server) {
|
||||
func SetNewServer(fn func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server) {
|
||||
newServer = fn
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ var (
|
||||
// detect whether user specified a port
|
||||
userPort := cmd.Flag("port").Changed
|
||||
|
||||
config, err = LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
config, err = loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading provided config file: %s: %v", nbconfig.MgmtConfigPath, err)
|
||||
}
|
||||
@@ -110,17 +110,7 @@ var (
|
||||
mgmtSingleAccModeDomain = ""
|
||||
}
|
||||
|
||||
srv := newServer(&server.Config{
|
||||
NbConfig: config,
|
||||
DNSDomain: dnsDomain,
|
||||
MgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
|
||||
MgmtPort: mgmtPort,
|
||||
MgmtMetricsPort: mgmtMetricsPort,
|
||||
DisableLegacyManagementPort: disableLegacyManagementPort,
|
||||
DisableMetrics: disableMetrics,
|
||||
DisableGeoliteUpdate: disableGeoliteUpdate,
|
||||
UserDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
})
|
||||
srv := newServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
|
||||
go func() {
|
||||
if err := srv.Start(cmd.Context()); err != nil {
|
||||
log.Fatalf("Server error: %v", err)
|
||||
@@ -145,35 +135,35 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func LoadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
|
||||
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
|
||||
loadedConfig := &nbconfig.Config{}
|
||||
if _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ApplyCommandLineOverrides(loadedConfig)
|
||||
applyCommandLineOverrides(loadedConfig)
|
||||
|
||||
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
|
||||
err := ApplyEmbeddedIdPConfig(ctx, loadedConfig)
|
||||
err := applyEmbeddedIdPConfig(ctx, loadedConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := ApplyOIDCConfig(ctx, loadedConfig); err != nil {
|
||||
if err := applyOIDCConfig(ctx, loadedConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
LogConfigInfo(loadedConfig)
|
||||
logConfigInfo(loadedConfig)
|
||||
|
||||
if err := EnsureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
|
||||
if err := ensureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loadedConfig, nil
|
||||
}
|
||||
|
||||
// ApplyCommandLineOverrides applies command-line flag overrides to the config
|
||||
func ApplyCommandLineOverrides(cfg *nbconfig.Config) {
|
||||
// applyCommandLineOverrides applies command-line flag overrides to the config
|
||||
func applyCommandLineOverrides(cfg *nbconfig.Config) {
|
||||
if mgmtLetsencryptDomain != "" {
|
||||
cfg.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain
|
||||
}
|
||||
@@ -186,9 +176,9 @@ func ApplyCommandLineOverrides(cfg *nbconfig.Config) {
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
|
||||
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
|
||||
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
|
||||
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
|
||||
return nil
|
||||
}
|
||||
@@ -237,8 +227,8 @@ func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
|
||||
func ApplyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
|
||||
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
|
||||
if oidcEndpoint == "" {
|
||||
return nil
|
||||
@@ -264,16 +254,16 @@ func ApplyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
oidcConfig.JwksURI, cfg.HttpConfig.AuthKeysLocation)
|
||||
cfg.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
|
||||
|
||||
if err := ApplyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
|
||||
if err := applyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
ApplyPKCEFlowConfig(ctx, cfg, &oidcConfig)
|
||||
applyPKCEFlowConfig(ctx, cfg, &oidcConfig)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
|
||||
func ApplyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
|
||||
// applyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
|
||||
func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
|
||||
if cfg.DeviceAuthorizationFlow == nil || strings.ToLower(cfg.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE) {
|
||||
return nil
|
||||
}
|
||||
@@ -300,8 +290,8 @@ func ApplyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcCo
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
|
||||
func ApplyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
|
||||
// applyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
|
||||
func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
|
||||
if cfg.PKCEAuthorizationFlow == nil {
|
||||
return
|
||||
}
|
||||
@@ -314,8 +304,8 @@ func ApplyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *
|
||||
cfg.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
|
||||
}
|
||||
|
||||
// LogConfigInfo logs informational messages about the loaded configuration
|
||||
func LogConfigInfo(cfg *nbconfig.Config) {
|
||||
// logConfigInfo logs informational messages about the loaded configuration
|
||||
func logConfigInfo(cfg *nbconfig.Config) {
|
||||
if cfg.EmbeddedIdP != nil {
|
||||
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
|
||||
}
|
||||
@@ -324,8 +314,8 @@ func LogConfigInfo(cfg *nbconfig.Config) {
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
|
||||
func EnsureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
|
||||
// ensureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
|
||||
func ensureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
|
||||
if cfg.DataStoreEncryptionKey != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ func Test_loadMgmtConfig(t *testing.T) {
|
||||
t.Fatalf("failed to create config: %s", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadMgmtConfig(context.Background(), tmpFile)
|
||||
cfg, err := loadMgmtConfig(context.Background(), tmpFile)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load management config: %s", err)
|
||||
}
|
||||
|
||||
@@ -16,22 +16,21 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
dnsDomain string
|
||||
mgmtDataDir string
|
||||
logLevel string
|
||||
logFile string
|
||||
disableMetrics bool
|
||||
disableSingleAccMode bool
|
||||
disableGeoliteUpdate bool
|
||||
idpSignKeyRefreshEnabled bool
|
||||
userDeleteFromIDPEnabled bool
|
||||
mgmtPort int
|
||||
mgmtMetricsPort int
|
||||
disableLegacyManagementPort bool
|
||||
mgmtLetsencryptDomain string
|
||||
mgmtSingleAccModeDomain string
|
||||
certFile string
|
||||
certKey string
|
||||
dnsDomain string
|
||||
mgmtDataDir string
|
||||
logLevel string
|
||||
logFile string
|
||||
disableMetrics bool
|
||||
disableSingleAccMode bool
|
||||
disableGeoliteUpdate bool
|
||||
idpSignKeyRefreshEnabled bool
|
||||
userDeleteFromIDPEnabled bool
|
||||
mgmtPort int
|
||||
mgmtMetricsPort int
|
||||
mgmtLetsencryptDomain string
|
||||
mgmtSingleAccModeDomain string
|
||||
certFile string
|
||||
certKey string
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird-mgmt",
|
||||
@@ -56,7 +55,6 @@ func Execute() error {
|
||||
|
||||
func init() {
|
||||
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
||||
mgmtCmd.Flags().BoolVar(&disableLegacyManagementPort, "disable-legacy-port", false, "disabling the old legacy port (33073)")
|
||||
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
||||
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
|
||||
mgmtCmd.Flags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
|
||||
@@ -83,7 +81,9 @@ func init() {
|
||||
|
||||
rootCmd.AddCommand(migrationCmd)
|
||||
|
||||
tc := newTokenCommands()
|
||||
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
rootCmd.AddCommand(tc)
|
||||
tokenCmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
tokenCmd.AddCommand(tokenCreateCmd)
|
||||
tokenCmd.AddCommand(tokenListCmd)
|
||||
tokenCmd.AddCommand(tokenRevokeCmd)
|
||||
rootCmd.AddCommand(tokenCmd)
|
||||
}
|
||||
|
||||
@@ -3,24 +3,62 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var tokenDatadir string
|
||||
var (
|
||||
tokenName string
|
||||
tokenExpireIn string
|
||||
tokenDatadir string
|
||||
|
||||
// newTokenCommands creates the token command tree with management-specific store opener.
|
||||
func newTokenCommands() *cobra.Command {
|
||||
cmd := tokencmd.NewCommands(withTokenStore)
|
||||
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||
return cmd
|
||||
tokenCmd = &cobra.Command{
|
||||
Use: "token",
|
||||
Short: "Manage proxy access tokens",
|
||||
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
|
||||
}
|
||||
|
||||
tokenCreateCmd = &cobra.Command{
|
||||
Use: "create",
|
||||
Short: "Create a new proxy access token",
|
||||
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
|
||||
RunE: tokenCreateRun,
|
||||
}
|
||||
|
||||
tokenListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List all proxy access tokens",
|
||||
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
|
||||
RunE: tokenListRun,
|
||||
}
|
||||
|
||||
tokenRevokeCmd = &cobra.Command{
|
||||
Use: "revoke [token-id]",
|
||||
Short: "Revoke a proxy access token",
|
||||
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: tokenRevokeRun,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||
|
||||
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
|
||||
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
|
||||
tokenCreateCmd.MarkFlagRequired("name") //nolint
|
||||
}
|
||||
|
||||
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
||||
@@ -29,9 +67,10 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
//nolint
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
|
||||
|
||||
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
config, err := loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
@@ -53,3 +92,118 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto
|
||||
|
||||
return fn(ctx, s)
|
||||
}
|
||||
|
||||
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
|
||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||
expiresIn, err := parseDuration(tokenExpireIn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse expiration: %w", err)
|
||||
}
|
||||
|
||||
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
|
||||
return fmt.Errorf("save token: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("Token created successfully!") //nolint:forbidigo
|
||||
fmt.Printf("Token: %s\n", generated.PlainToken) //nolint:forbidigo
|
||||
fmt.Println() //nolint:forbidigo
|
||||
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.") //nolint:forbidigo
|
||||
fmt.Printf("Token ID: %s\n", generated.ID) //nolint:forbidigo
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func tokenListRun(cmd *cobra.Command, _ []string) error {
|
||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list tokens: %w", err)
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
fmt.Println("No proxy access tokens found.") //nolint:forbidigo
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
|
||||
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
|
||||
|
||||
for _, t := range tokens {
|
||||
expires := "never"
|
||||
if t.ExpiresAt != nil {
|
||||
expires = t.ExpiresAt.Format("2006-01-02")
|
||||
}
|
||||
|
||||
lastUsed := "never"
|
||||
if t.LastUsed != nil {
|
||||
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
|
||||
}
|
||||
|
||||
revoked := "no"
|
||||
if t.Revoked {
|
||||
revoked = "yes"
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
t.ID,
|
||||
t.Name,
|
||||
t.CreatedAt.Format("2006-01-02"),
|
||||
expires,
|
||||
lastUsed,
|
||||
revoked,
|
||||
)
|
||||
}
|
||||
|
||||
w.Flush()
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
|
||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||
tokenID := args[0]
|
||||
|
||||
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
|
||||
return fmt.Errorf("revoke token: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Token %s revoked successfully.\n", tokenID) //nolint:forbidigo
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
|
||||
// An empty string returns zero duration (no expiration).
|
||||
func parseDuration(s string) (time.Duration, error) {
|
||||
if len(s) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if s[len(s)-1] == 'd' {
|
||||
d, err := strconv.Atoi(s[:len(s)-1])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid day format: %s", s)
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||
}
|
||||
return time.Duration(d) * 24 * time.Hour, nil
|
||||
}
|
||||
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
@@ -1,185 +0,0 @@
|
||||
// Package tokencmd provides reusable cobra commands for managing proxy access tokens.
|
||||
// Both the management and combined binaries use these commands, each providing
|
||||
// their own StoreOpener to handle config loading and store initialization.
|
||||
package tokencmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// StoreOpener initializes a store from the command context and calls fn.
|
||||
type StoreOpener func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error
|
||||
|
||||
// NewCommands creates the token command tree with the given store opener.
|
||||
// Returns the parent "token" command with create, list, and revoke subcommands.
|
||||
func NewCommands(opener StoreOpener) *cobra.Command {
|
||||
var (
|
||||
tokenName string
|
||||
tokenExpireIn string
|
||||
)
|
||||
|
||||
tokenCmd := &cobra.Command{
|
||||
Use: "token",
|
||||
Short: "Manage proxy access tokens",
|
||||
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
|
||||
}
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create",
|
||||
Short: "Create a new proxy access token",
|
||||
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, s store.Store) error {
|
||||
return runCreate(ctx, s, cmd.OutOrStdout(), tokenName, tokenExpireIn)
|
||||
})
|
||||
},
|
||||
}
|
||||
createCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
|
||||
createCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
|
||||
if err := createCmd.MarkFlagRequired("name"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List all proxy access tokens",
|
||||
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, s store.Store) error {
|
||||
return runList(ctx, s, cmd.OutOrStdout())
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
revokeCmd := &cobra.Command{
|
||||
Use: "revoke [token-id]",
|
||||
Short: "Revoke a proxy access token",
|
||||
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return opener(cmd, func(ctx context.Context, s store.Store) error {
|
||||
return runRevoke(ctx, s, cmd.OutOrStdout(), args[0])
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
tokenCmd.AddCommand(createCmd, listCmd, revokeCmd)
|
||||
return tokenCmd
|
||||
}
|
||||
|
||||
func runCreate(ctx context.Context, s store.Store, w io.Writer, name string, expireIn string) error {
|
||||
expiresIn, err := ParseDuration(expireIn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse expiration: %w", err)
|
||||
}
|
||||
|
||||
generated, err := types.CreateNewProxyAccessToken(name, expiresIn, nil, "CLI")
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
|
||||
return fmt.Errorf("save token: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(w, "Token created successfully!")
|
||||
_, _ = fmt.Fprintf(w, "Token: %s\n", generated.PlainToken)
|
||||
_, _ = fmt.Fprintln(w)
|
||||
_, _ = fmt.Fprintln(w, "IMPORTANT: Save this token now. It will not be shown again.")
|
||||
_, _ = fmt.Fprintf(w, "Token ID: %s\n", generated.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runList(ctx context.Context, s store.Store, out io.Writer) error {
|
||||
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list tokens: %w", err)
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
_, _ = fmt.Fprintln(out, "No proxy access tokens found.")
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0)
|
||||
_, _ = fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
|
||||
_, _ = fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
|
||||
|
||||
for _, t := range tokens {
|
||||
expires := "never"
|
||||
if t.ExpiresAt != nil {
|
||||
expires = t.ExpiresAt.Format("2006-01-02")
|
||||
}
|
||||
|
||||
lastUsed := "never"
|
||||
if t.LastUsed != nil {
|
||||
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
|
||||
}
|
||||
|
||||
revoked := "no"
|
||||
if t.Revoked {
|
||||
revoked = "yes"
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
t.ID,
|
||||
t.Name,
|
||||
t.CreatedAt.Format("2006-01-02"),
|
||||
expires,
|
||||
lastUsed,
|
||||
revoked,
|
||||
)
|
||||
}
|
||||
|
||||
w.Flush()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runRevoke(ctx context.Context, s store.Store, w io.Writer, tokenID string) error {
|
||||
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
|
||||
return fmt.Errorf("revoke token: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "Token %s revoked successfully.\n", tokenID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseDuration parses a duration string with support for days (e.g., "30d", "365d").
|
||||
// An empty string returns zero duration (no expiration).
|
||||
func ParseDuration(s string) (time.Duration, error) {
|
||||
if len(s) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if s[len(s)-1] == 'd' {
|
||||
d, err := strconv.Atoi(s[:len(s)-1])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid day format: %s", s)
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||
}
|
||||
return time.Duration(d) * 24 * time.Hour, nil
|
||||
}
|
||||
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokencmd
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -89,7 +89,7 @@ func TestParseDuration(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := ParseDuration(tt.input)
|
||||
result, err := parseDuration(tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
@@ -39,63 +39,78 @@ type AccessLogFilter struct {
|
||||
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
|
||||
queryParams := r.URL.Query()
|
||||
|
||||
f.Page = parsePositiveInt(queryParams.Get("page"), 1)
|
||||
f.PageSize = min(parsePositiveInt(queryParams.Get("page_size"), DefaultPageSize), MaxPageSize)
|
||||
f.Page = 1
|
||||
if pageStr := queryParams.Get("page"); pageStr != "" {
|
||||
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
|
||||
f.Page = page
|
||||
}
|
||||
}
|
||||
|
||||
f.Search = parseOptionalString(queryParams.Get("search"))
|
||||
f.SourceIP = parseOptionalString(queryParams.Get("source_ip"))
|
||||
f.Host = parseOptionalString(queryParams.Get("host"))
|
||||
f.Path = parseOptionalString(queryParams.Get("path"))
|
||||
f.UserID = parseOptionalString(queryParams.Get("user_id"))
|
||||
f.UserEmail = parseOptionalString(queryParams.Get("user_email"))
|
||||
f.UserName = parseOptionalString(queryParams.Get("user_name"))
|
||||
f.Method = parseOptionalString(queryParams.Get("method"))
|
||||
f.Status = parseOptionalString(queryParams.Get("status"))
|
||||
f.StatusCode = parseOptionalInt(queryParams.Get("status_code"))
|
||||
f.StartDate = parseOptionalRFC3339(queryParams.Get("start_date"))
|
||||
f.EndDate = parseOptionalRFC3339(queryParams.Get("end_date"))
|
||||
}
|
||||
f.PageSize = DefaultPageSize
|
||||
if pageSizeStr := queryParams.Get("page_size"); pageSizeStr != "" {
|
||||
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
|
||||
f.PageSize = pageSize
|
||||
if f.PageSize > MaxPageSize {
|
||||
f.PageSize = MaxPageSize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parsePositiveInt parses a positive integer from a string, returning defaultValue if invalid
|
||||
func parsePositiveInt(s string, defaultValue int) int {
|
||||
if s == "" {
|
||||
return defaultValue
|
||||
if search := queryParams.Get("search"); search != "" {
|
||||
f.Search = &search
|
||||
}
|
||||
if val, err := strconv.Atoi(s); err == nil && val > 0 {
|
||||
return val
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// parseOptionalString returns a pointer to the string if non-empty, otherwise nil
|
||||
func parseOptionalString(s string) *string {
|
||||
if s == "" {
|
||||
return nil
|
||||
if sourceIP := queryParams.Get("source_ip"); sourceIP != "" {
|
||||
f.SourceIP = &sourceIP
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
// parseOptionalInt parses an optional positive integer from a string
|
||||
func parseOptionalInt(s string) *int {
|
||||
if s == "" {
|
||||
return nil
|
||||
if host := queryParams.Get("host"); host != "" {
|
||||
f.Host = &host
|
||||
}
|
||||
if val, err := strconv.Atoi(s); err == nil && val > 0 {
|
||||
v := val
|
||||
return &v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseOptionalRFC3339 parses an optional RFC3339 timestamp from a string
|
||||
func parseOptionalRFC3339(s string) *time.Time {
|
||||
if s == "" {
|
||||
return nil
|
||||
if path := queryParams.Get("path"); path != "" {
|
||||
f.Path = &path
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
return &t
|
||||
|
||||
if userID := queryParams.Get("user_id"); userID != "" {
|
||||
f.UserID = &userID
|
||||
}
|
||||
|
||||
if userEmail := queryParams.Get("user_email"); userEmail != "" {
|
||||
f.UserEmail = &userEmail
|
||||
}
|
||||
|
||||
if userName := queryParams.Get("user_name"); userName != "" {
|
||||
f.UserName = &userName
|
||||
}
|
||||
|
||||
if method := queryParams.Get("method"); method != "" {
|
||||
f.Method = &method
|
||||
}
|
||||
|
||||
if status := queryParams.Get("status"); status != "" {
|
||||
f.Status = &status
|
||||
}
|
||||
|
||||
if statusCodeStr := queryParams.Get("status_code"); statusCodeStr != "" {
|
||||
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil && statusCode > 0 {
|
||||
f.StatusCode = &statusCode
|
||||
}
|
||||
}
|
||||
|
||||
if startDate := queryParams.Get("start_date"); startDate != "" {
|
||||
parsedStartDate, err := time.Parse(time.RFC3339, startDate)
|
||||
if err == nil {
|
||||
f.StartDate = &parsedStartDate
|
||||
}
|
||||
}
|
||||
|
||||
if endDate := queryParams.Get("end_date"); endDate != "" {
|
||||
parsedEndDate, err := time.Parse(time.RFC3339, endDate)
|
||||
if err == nil {
|
||||
f.EndDate = &parsedEndDate
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOffset calculates the database offset for pagination
|
||||
|
||||
@@ -4,10 +4,8 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccessLogFilter_ParseFromRequest(t *testing.T) {
|
||||
@@ -161,211 +159,3 @@ func TestAccessLogFilter_GetLimit(t *testing.T) {
|
||||
limit := filter.GetLimit()
|
||||
assert.Equal(t, 25, limit, "GetLimit should return PageSize")
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_ParseFromRequest_FilterParams(t *testing.T) {
|
||||
startDate := "2024-01-15T10:30:00Z"
|
||||
endDate := "2024-01-16T15:45:00Z"
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
q := req.URL.Query()
|
||||
q.Set("search", "test query")
|
||||
q.Set("source_ip", "192.168.1.1")
|
||||
q.Set("host", "example.com")
|
||||
q.Set("path", "/api/users")
|
||||
q.Set("user_id", "user123")
|
||||
q.Set("user_email", "user@example.com")
|
||||
q.Set("user_name", "John Doe")
|
||||
q.Set("method", "GET")
|
||||
q.Set("status", "success")
|
||||
q.Set("status_code", "200")
|
||||
q.Set("start_date", startDate)
|
||||
q.Set("end_date", endDate)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
require.NotNil(t, filter.Search)
|
||||
assert.Equal(t, "test query", *filter.Search)
|
||||
|
||||
require.NotNil(t, filter.SourceIP)
|
||||
assert.Equal(t, "192.168.1.1", *filter.SourceIP)
|
||||
|
||||
require.NotNil(t, filter.Host)
|
||||
assert.Equal(t, "example.com", *filter.Host)
|
||||
|
||||
require.NotNil(t, filter.Path)
|
||||
assert.Equal(t, "/api/users", *filter.Path)
|
||||
|
||||
require.NotNil(t, filter.UserID)
|
||||
assert.Equal(t, "user123", *filter.UserID)
|
||||
|
||||
require.NotNil(t, filter.UserEmail)
|
||||
assert.Equal(t, "user@example.com", *filter.UserEmail)
|
||||
|
||||
require.NotNil(t, filter.UserName)
|
||||
assert.Equal(t, "John Doe", *filter.UserName)
|
||||
|
||||
require.NotNil(t, filter.Method)
|
||||
assert.Equal(t, "GET", *filter.Method)
|
||||
|
||||
require.NotNil(t, filter.Status)
|
||||
assert.Equal(t, "success", *filter.Status)
|
||||
|
||||
require.NotNil(t, filter.StatusCode)
|
||||
assert.Equal(t, 200, *filter.StatusCode)
|
||||
|
||||
require.NotNil(t, filter.StartDate)
|
||||
expectedStart, _ := time.Parse(time.RFC3339, startDate)
|
||||
assert.Equal(t, expectedStart, *filter.StartDate)
|
||||
|
||||
require.NotNil(t, filter.EndDate)
|
||||
expectedEnd, _ := time.Parse(time.RFC3339, endDate)
|
||||
assert.Equal(t, expectedEnd, *filter.EndDate)
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_ParseFromRequest_EmptyFilters(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Nil(t, filter.Search)
|
||||
assert.Nil(t, filter.SourceIP)
|
||||
assert.Nil(t, filter.Host)
|
||||
assert.Nil(t, filter.Path)
|
||||
assert.Nil(t, filter.UserID)
|
||||
assert.Nil(t, filter.UserEmail)
|
||||
assert.Nil(t, filter.UserName)
|
||||
assert.Nil(t, filter.Method)
|
||||
assert.Nil(t, filter.Status)
|
||||
assert.Nil(t, filter.StatusCode)
|
||||
assert.Nil(t, filter.StartDate)
|
||||
assert.Nil(t, filter.EndDate)
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_ParseFromRequest_InvalidFilters(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
q := req.URL.Query()
|
||||
q.Set("status_code", "invalid")
|
||||
q.Set("start_date", "not-a-date")
|
||||
q.Set("end_date", "2024-99-99")
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Nil(t, filter.StatusCode, "invalid status_code should be nil")
|
||||
assert.Nil(t, filter.StartDate, "invalid start_date should be nil")
|
||||
assert.Nil(t, filter.EndDate, "invalid end_date should be nil")
|
||||
}
|
||||
|
||||
func TestParsePositiveInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
defaultValue int
|
||||
expected int
|
||||
}{
|
||||
{"empty string", "", 10, 10},
|
||||
{"valid positive int", "25", 10, 25},
|
||||
{"zero", "0", 10, 10},
|
||||
{"negative", "-5", 10, 10},
|
||||
{"invalid string", "abc", 10, 10},
|
||||
{"float", "3.14", 10, 10},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parsePositiveInt(tt.input, tt.defaultValue)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOptionalString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected *string
|
||||
}{
|
||||
{"empty string", "", nil},
|
||||
{"valid string", "hello", strPtr("hello")},
|
||||
{"whitespace", " ", strPtr(" ")},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseOptionalString(tt.input)
|
||||
if tt.expected == nil {
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, *tt.expected, *result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOptionalInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected *int
|
||||
}{
|
||||
{"empty string", "", nil},
|
||||
{"valid positive int", "42", intPtr(42)},
|
||||
{"zero", "0", nil},
|
||||
{"negative", "-10", nil},
|
||||
{"invalid string", "abc", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseOptionalInt(tt.input)
|
||||
if tt.expected == nil {
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, *tt.expected, *result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOptionalRFC3339(t *testing.T) {
|
||||
validDate := "2024-01-15T10:30:00Z"
|
||||
expectedTime, _ := time.Parse(time.RFC3339, validDate)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected *time.Time
|
||||
}{
|
||||
{"empty string", "", nil},
|
||||
{"valid RFC3339", validDate, &expectedTime},
|
||||
{"invalid format", "2024-01-15", nil},
|
||||
{"invalid date", "not-a-date", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseOptionalRFC3339(tt.input)
|
||||
if tt.expected == nil {
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, *tt.expected, *result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for creating pointers
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func intPtr(i int) *int {
|
||||
return &i
|
||||
}
|
||||
|
||||
@@ -135,11 +135,54 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
|
||||
return nil, err
|
||||
var proxyCluster string
|
||||
if m.clusterDeriver != nil {
|
||||
proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
|
||||
return nil, status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.persistNewService(ctx, accountID, service); err != nil {
|
||||
service.AccountID = accountID
|
||||
service.ProxyCluster = proxyCluster
|
||||
service.InitNewRecord()
|
||||
err = service.Auth.HashSecrets()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
// Generate session JWT signing keys
|
||||
keyPair, err := sessionkey.GenerateKeyPair()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate session keys: %w", err)
|
||||
}
|
||||
service.SessionPrivateKey = keyPair.PrivateKey
|
||||
service.SessionPublicKey = keyPair.PublicKey
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
// Check for duplicate domain
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing service: %w", err)
|
||||
}
|
||||
}
|
||||
if existingService != nil {
|
||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
|
||||
}
|
||||
|
||||
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.CreateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -157,67 +200,6 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
if m.clusterDeriver != nil {
|
||||
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
|
||||
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
|
||||
}
|
||||
service.ProxyCluster = proxyCluster
|
||||
}
|
||||
|
||||
service.AccountID = accountID
|
||||
service.InitNewRecord()
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
return fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
keyPair, err := sessionkey.GenerateKeyPair()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate session keys: %w", err)
|
||||
}
|
||||
service.SessionPrivateKey = keyPair.PrivateKey
|
||||
service.SessionPublicKey = keyPair.PublicKey
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing service: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if existingService != nil && existingService.ID != excludeServiceID {
|
||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||
if err != nil {
|
||||
@@ -227,122 +209,99 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
var oldCluster string
|
||||
var domainChanged bool
|
||||
var serviceEnabledChanged bool
|
||||
|
||||
err = service.Auth.HashSecrets()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldCluster = existingService.ProxyCluster
|
||||
|
||||
if existingService.Domain != service.Domain {
|
||||
domainChanged = true
|
||||
conflictService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("check existing service: %w", err)
|
||||
}
|
||||
}
|
||||
if conflictService != nil && conflictService.ID != service.ID {
|
||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
|
||||
}
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
|
||||
}
|
||||
service.ProxyCluster = newCluster
|
||||
}
|
||||
} else {
|
||||
service.ProxyCluster = existingService.ProxyCluster
|
||||
}
|
||||
|
||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||
service.Auth.PasswordAuth.Password == "" {
|
||||
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
||||
}
|
||||
|
||||
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
||||
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
||||
service.Auth.PinAuth.Pin == "" {
|
||||
service.Auth.PinAuth = existingService.Auth.PinAuth
|
||||
}
|
||||
|
||||
service.Meta = existingService.Meta
|
||||
service.SessionPrivateKey = existingService.SessionPrivateKey
|
||||
service.SessionPublicKey = existingService.SessionPublicKey
|
||||
serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
|
||||
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
|
||||
|
||||
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.sendServiceUpdateNotifications(service, updateInfo)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
type serviceUpdateInfo struct {
|
||||
oldCluster string
|
||||
domainChanged bool
|
||||
serviceEnabledChanged bool
|
||||
}
|
||||
|
||||
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
|
||||
var updateInfo serviceUpdateInfo
|
||||
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateInfo.oldCluster = existingService.ProxyCluster
|
||||
updateInfo.domainChanged = existingService.Domain != service.Domain
|
||||
|
||||
if updateInfo.domainChanged {
|
||||
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
service.ProxyCluster = existingService.ProxyCluster
|
||||
}
|
||||
|
||||
m.preserveExistingAuthSecrets(service, existingService)
|
||||
m.preserveServiceMetadata(service, existingService)
|
||||
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return &updateInfo, err
|
||||
}
|
||||
|
||||
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
|
||||
} else {
|
||||
service.ProxyCluster = newCluster
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
|
||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||
service.Auth.PasswordAuth.Password == "" {
|
||||
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
||||
}
|
||||
|
||||
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
||||
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
||||
service.Auth.PinAuth.Pin == "" {
|
||||
service.Auth.PinAuth = existingService.Auth.PinAuth
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
|
||||
service.Meta = existingService.Meta
|
||||
service.SessionPrivateKey = existingService.SessionPrivateKey
|
||||
service.SessionPublicKey = existingService.SessionPublicKey
|
||||
}
|
||||
|
||||
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
|
||||
switch {
|
||||
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster)
|
||||
case domainChanged && oldCluster != service.ProxyCluster:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), oldCluster)
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
||||
case !service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
case !service.Enabled && serviceEnabledChanged:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
|
||||
case service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
case service.Enabled && serviceEnabledChanged:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
||||
default:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
|
||||
}
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
||||
@@ -473,6 +432,8 @@ func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
}
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,375 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func TestInitializeServiceForCreate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("successful initialization without cluster deriver", func(t *testing.T) {
|
||||
mgr := &managerImpl{
|
||||
clusterDeriver: nil,
|
||||
}
|
||||
|
||||
service := &reverseproxy.Service{
|
||||
Domain: "example.com",
|
||||
Auth: reverseproxy.AuthConfig{},
|
||||
}
|
||||
|
||||
err := mgr.initializeServiceForCreate(ctx, accountID, service)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, accountID, service.AccountID)
|
||||
assert.Empty(t, service.ProxyCluster, "proxy cluster should be empty when no deriver")
|
||||
assert.NotEmpty(t, service.ID, "service ID should be initialized")
|
||||
assert.NotEmpty(t, service.SessionPrivateKey, "session private key should be generated")
|
||||
assert.NotEmpty(t, service.SessionPublicKey, "session public key should be generated")
|
||||
})
|
||||
|
||||
t.Run("verifies session keys are different", func(t *testing.T) {
|
||||
mgr := &managerImpl{
|
||||
clusterDeriver: nil,
|
||||
}
|
||||
|
||||
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
|
||||
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}
|
||||
|
||||
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
|
||||
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
|
||||
|
||||
assert.NoError(t, err1)
|
||||
assert.NoError(t, err2)
|
||||
assert.NotEqual(t, service1.SessionPrivateKey, service2.SessionPrivateKey, "private keys should be unique")
|
||||
assert.NotEqual(t, service1.SessionPublicKey, service2.SessionPublicKey, "public keys should be unique")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckDomainAvailable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
excludeServiceID string
|
||||
setupMock func(*store.MockStore)
|
||||
expectedError bool
|
||||
errorType status.Type
|
||||
}{
|
||||
{
|
||||
name: "domain available - not found",
|
||||
domain: "available.com",
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "available.com").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "domain already exists",
|
||||
domain: "exists.com",
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
errorType: status.AlreadyExists,
|
||||
},
|
||||
{
|
||||
name: "domain exists but excluded (same ID)",
|
||||
domain: "exists.com",
|
||||
excludeServiceID: "service-123",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "domain exists with different ID",
|
||||
domain: "exists.com",
|
||||
excludeServiceID: "service-456",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
errorType: status.AlreadyExists,
|
||||
},
|
||||
{
|
||||
name: "store error (non-NotFound)",
|
||||
domain: "error.com",
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "error.com").
|
||||
Return(nil, errors.New("database error"))
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
tt.setupMock(mockStore)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
|
||||
|
||||
if tt.expectedError {
|
||||
require.Error(t, err)
|
||||
if tt.errorType != 0 {
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok, "error should be a status error")
|
||||
assert.Equal(t, tt.errorType, sErr.Type())
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("empty domain", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty exclude ID with existing service", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "test.com").
|
||||
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
|
||||
|
||||
assert.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.AlreadyExists, sErr.Type())
|
||||
})
|
||||
|
||||
t.Run("nil existing service with nil error", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "nil.com").
|
||||
Return(nil, nil)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPersistNewService(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("successful service creation with no targets", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
service := &reverseproxy.Service{
|
||||
ID: "service-123",
|
||||
Domain: "new.com",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
}
|
||||
|
||||
// Mock ExecuteInTransaction to execute the function immediately
|
||||
mockStore.EXPECT().
|
||||
ExecuteInTransaction(ctx, gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
|
||||
// Create another mock for the transaction
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "new.com").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
txMock.EXPECT().
|
||||
CreateService(ctx, service).
|
||||
Return(nil)
|
||||
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{store: mockStore}
|
||||
err := mgr.persistNewService(ctx, accountID, service)
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("domain already exists", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
service := &reverseproxy.Service{
|
||||
ID: "service-123",
|
||||
Domain: "existing.com",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
ExecuteInTransaction(ctx, gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "existing.com").
|
||||
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)
|
||||
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{store: mockStore}
|
||||
err := mgr.persistNewService(ctx, accountID, service)
|
||||
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.AlreadyExists, sErr.Type())
|
||||
})
|
||||
}
|
||||
func TestPreserveExistingAuthSecrets(t *testing.T) {
|
||||
mgr := &managerImpl{}
|
||||
|
||||
t.Run("preserve password when empty", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "hashed-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr.preserveExistingAuthSecrets(updated, existing)
|
||||
|
||||
assert.Equal(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
|
||||
})
|
||||
|
||||
t.Run("preserve pin when empty", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PinAuth: &reverseproxy.PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "hashed-pin",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PinAuth: &reverseproxy.PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr.preserveExistingAuthSecrets(updated, existing)
|
||||
|
||||
assert.Equal(t, existing.Auth.PinAuth, updated.Auth.PinAuth)
|
||||
})
|
||||
|
||||
t.Run("do not preserve when password is provided", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "old-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "new-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr.preserveExistingAuthSecrets(updated, existing)
|
||||
|
||||
assert.Equal(t, "new-password", updated.Auth.PasswordAuth.Password)
|
||||
assert.NotEqual(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPreserveServiceMetadata(t *testing.T) {
|
||||
mgr := &managerImpl{}
|
||||
|
||||
existing := &reverseproxy.Service{
|
||||
Meta: reverseproxy.ServiceMeta{
|
||||
CertificateIssuedAt: time.Now(),
|
||||
Status: "active",
|
||||
},
|
||||
SessionPrivateKey: "private-key",
|
||||
SessionPublicKey: "public-key",
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Domain: "updated.com",
|
||||
}
|
||||
|
||||
mgr.preserveServiceMetadata(updated, existing)
|
||||
|
||||
assert.Equal(t, existing.Meta, updated.Meta)
|
||||
assert.Equal(t, existing.SessionPrivateKey, updated.SessionPrivateKey)
|
||||
assert.Equal(t, existing.SessionPublicKey, updated.SessionPublicKey)
|
||||
}
|
||||
@@ -50,14 +50,13 @@ type BaseServer struct {
|
||||
// AfterInit is a function that will be called after the server is initialized
|
||||
afterInit []func(s *BaseServer)
|
||||
|
||||
disableMetrics bool
|
||||
dnsDomain string
|
||||
disableGeoliteUpdate bool
|
||||
userDeleteFromIDPEnabled bool
|
||||
mgmtSingleAccModeDomain string
|
||||
mgmtMetricsPort int
|
||||
mgmtPort int
|
||||
disableLegacyManagementPort bool
|
||||
disableMetrics bool
|
||||
dnsDomain string
|
||||
disableGeoliteUpdate bool
|
||||
userDeleteFromIDPEnabled bool
|
||||
mgmtSingleAccModeDomain string
|
||||
mgmtMetricsPort int
|
||||
mgmtPort int
|
||||
|
||||
proxyAuthClose func()
|
||||
|
||||
@@ -70,32 +69,18 @@ type BaseServer struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Config holds the configuration parameters for creating a new server
|
||||
type Config struct {
|
||||
NbConfig *nbconfig.Config
|
||||
DNSDomain string
|
||||
MgmtSingleAccModeDomain string
|
||||
MgmtPort int
|
||||
MgmtMetricsPort int
|
||||
DisableLegacyManagementPort bool
|
||||
DisableMetrics bool
|
||||
DisableGeoliteUpdate bool
|
||||
UserDeleteFromIDPEnabled bool
|
||||
}
|
||||
|
||||
// NewServer initializes and configures a new Server instance
|
||||
func NewServer(cfg *Config) *BaseServer {
|
||||
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
|
||||
return &BaseServer{
|
||||
Config: cfg.NbConfig,
|
||||
container: make(map[string]any),
|
||||
dnsDomain: cfg.DNSDomain,
|
||||
mgmtSingleAccModeDomain: cfg.MgmtSingleAccModeDomain,
|
||||
disableMetrics: cfg.DisableMetrics,
|
||||
disableGeoliteUpdate: cfg.DisableGeoliteUpdate,
|
||||
userDeleteFromIDPEnabled: cfg.UserDeleteFromIDPEnabled,
|
||||
mgmtPort: cfg.MgmtPort,
|
||||
disableLegacyManagementPort: cfg.DisableLegacyManagementPort,
|
||||
mgmtMetricsPort: cfg.MgmtMetricsPort,
|
||||
Config: config,
|
||||
container: make(map[string]any),
|
||||
dnsDomain: dnsDomain,
|
||||
mgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
|
||||
disableMetrics: disableMetrics,
|
||||
disableGeoliteUpdate: disableGeoliteUpdate,
|
||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
mgmtPort: mgmtPort,
|
||||
mgmtMetricsPort: mgmtMetricsPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,19 +140,8 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
go metricsWorker.Run(srvCtx)
|
||||
}
|
||||
|
||||
// Eagerly create the gRPC server so that all AfterInit hooks are registered
|
||||
// before we iterate them. Lazy creation after the loop would miss hooks
|
||||
// registered during GRPCServer() construction (e.g., SetProxyManager).
|
||||
s.GRPCServer()
|
||||
|
||||
for _, fn := range s.afterInit {
|
||||
if fn != nil {
|
||||
fn(s)
|
||||
}
|
||||
}
|
||||
|
||||
var compatListener net.Listener
|
||||
if s.mgmtPort != ManagementLegacyPort && !s.disableLegacyManagementPort {
|
||||
if s.mgmtPort != ManagementLegacyPort {
|
||||
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
|
||||
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
|
||||
compatListener, err = s.serveGRPC(srvCtx, s.GRPCServer(), ManagementLegacyPort)
|
||||
@@ -206,6 +180,12 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
for _, fn := range s.afterInit {
|
||||
if fn != nil {
|
||||
fn(s)
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
|
||||
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
|
||||
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
|
||||
@@ -282,23 +262,7 @@ func (s *BaseServer) SetContainer(key string, container any) {
|
||||
log.Tracef("container with key %s set successfully", key)
|
||||
}
|
||||
|
||||
// SetHandlerFunc allows overriding the default HTTP handler function.
|
||||
// This is useful for multiplexing additional services on the same port.
|
||||
func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
|
||||
s.container["customHandler"] = handler
|
||||
log.Tracef("custom handler set successfully")
|
||||
}
|
||||
|
||||
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||
// Check if a custom handler was set (for multiplexing additional services)
|
||||
if customHandler, ok := s.GetContainer("customHandler"); ok {
|
||||
if handler, ok := customHandler.(http.Handler); ok {
|
||||
log.Tracef("using custom handler")
|
||||
return handler
|
||||
}
|
||||
}
|
||||
|
||||
// Use default handler
|
||||
wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))
|
||||
|
||||
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
|
||||
@@ -15,14 +15,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
@@ -520,11 +519,61 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
|
||||
}
|
||||
|
||||
authenticated, userId, method := s.authenticateRequest(ctx, req, service)
|
||||
var authenticated bool
|
||||
var userId string
|
||||
var method proxyauth.Method
|
||||
switch v := req.GetRequest().(type) {
|
||||
case *proto.AuthenticateRequest_Pin:
|
||||
auth := service.Auth.PinAuth
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", req.GetId())
|
||||
break
|
||||
}
|
||||
err = argon2id.Verify(v.Pin.GetPin(), auth.Pin)
|
||||
if err != nil {
|
||||
if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
||||
log.WithContext(ctx).Tracef("PIN authentication failed: invalid PIN")
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("PIN authentication error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
authenticated = true
|
||||
userId = "pin-user"
|
||||
method = proxyauth.MethodPIN
|
||||
case *proto.AuthenticateRequest_Password:
|
||||
auth := service.Auth.PasswordAuth
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", req.GetId())
|
||||
break
|
||||
}
|
||||
err = argon2id.Verify(v.Password.GetPassword(), auth.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
||||
log.WithContext(ctx).Tracef("Password authentication failed: invalid password")
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("Password authentication error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
authenticated = true
|
||||
userId = "password-user"
|
||||
method = proxyauth.MethodPassword
|
||||
}
|
||||
|
||||
token, err := s.generateSessionToken(ctx, authenticated, service, userId, method)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var token string
|
||||
if authenticated && service.SessionPrivateKey != "" {
|
||||
token, err = sessionkey.SignToken(
|
||||
service.SessionPrivateKey,
|
||||
userId,
|
||||
service.Domain,
|
||||
method,
|
||||
proxyauth.DefaultSessionExpiry,
|
||||
)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to sign session token")
|
||||
return nil, status.Errorf(codes.Internal, "sign session token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.AuthenticateResponse{
|
||||
@@ -533,73 +582,6 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) {
|
||||
switch v := req.GetRequest().(type) {
|
||||
case *proto.AuthenticateRequest_Pin:
|
||||
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
|
||||
case *proto.AuthenticateRequest_Password:
|
||||
return s.authenticatePassword(ctx, req.GetId(), v, service.Auth.PasswordAuth)
|
||||
default:
|
||||
return false, "", ""
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) {
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID)
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
if err := argon2id.Verify(req.Pin.GetPin(), auth.Pin); err != nil {
|
||||
s.logAuthenticationError(ctx, err, "PIN")
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
return true, "pin-user", proxyauth.MethodPIN
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) {
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID)
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
if err := argon2id.Verify(req.Password.GetPassword(), auth.Password); err != nil {
|
||||
s.logAuthenticationError(ctx, err, "Password")
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
return true, "password-user", proxyauth.MethodPassword
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err error, authType string) {
|
||||
if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
||||
log.WithContext(ctx).Tracef("%s authentication failed: invalid credentials", authType)
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("%s authentication error: %v", authType, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) {
|
||||
if !authenticated || service.SessionPrivateKey == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
token, err := sessionkey.SignToken(
|
||||
service.SessionPrivateKey,
|
||||
userId,
|
||||
service.Domain,
|
||||
method,
|
||||
proxyauth.DefaultSessionExpiry,
|
||||
)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to sign session token")
|
||||
return "", status.Errorf(codes.Internal, "sign session token: %v", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// SendStatusUpdate handles status updates from proxy clients
|
||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||
accountID := req.GetAccountId()
|
||||
|
||||
@@ -224,7 +224,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
s.syncSem.Add(1)
|
||||
|
||||
reqStart := time.Now()
|
||||
syncStart := reqStart.UTC()
|
||||
|
||||
ctx := srv.Context()
|
||||
|
||||
@@ -301,7 +300,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
metahash := metaHash(peerMeta, realIP.String())
|
||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
@@ -312,7 +311,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -320,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -337,7 +336,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
|
||||
s.syncSem.Add(-1)
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, syncStart)
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
|
||||
}
|
||||
|
||||
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
||||
|
||||
@@ -297,7 +297,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
var oldSettings *types.Settings
|
||||
var updateAccountPeers bool
|
||||
var groupChangesAffectPeers bool
|
||||
var reloadReverseProxy bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var groupsUpdated bool
|
||||
@@ -328,7 +327,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
||||
return err
|
||||
}
|
||||
reloadReverseProxy = true
|
||||
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
|
||||
}
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
@@ -393,11 +394,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
|
||||
}
|
||||
if reloadReverseProxy {
|
||||
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
|
||||
go am.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
@@ -3918,36 +3918,3 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
|
||||
assert.False(t, user.PendingApproval, "User should not be pending approval")
|
||||
assert.Equal(t, existingAccountID, user.AccountID)
|
||||
}
|
||||
|
||||
// TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange verifies that
|
||||
// changing NetworkRange via UpdateAccountSettings does not deadlock.
|
||||
// The deadlock occurs because ReloadAllServicesForAccount is called inside a DB
|
||||
// transaction but uses the main store connection, which blocks on the transaction lock.
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Use a channel to detect if the call completes or hangs
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
NetworkRange: netip.MustParsePrefix("10.100.0.0/16"),
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.NoError(t, err, "UpdateAccountSettings should complete without error")
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,8 +18,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
@@ -378,9 +376,9 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read)
|
||||
err = h.permissionsManager.ValidateAccountAccess(r.Context(), accountID, user, false)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
|
||||
util.WriteError(r.Context(), status.NewPermissionDeniedError(), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -390,12 +388,9 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed && !userAuth.IsChild {
|
||||
if account.Settings.RegularUsersViewBlocked {
|
||||
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
|
||||
return
|
||||
}
|
||||
|
||||
// If the user is regular user and does not own the peer
|
||||
// with the given peerID return an empty list
|
||||
if !user.HasAdminPower() && !user.IsServiceUser && !userAuth.IsChild {
|
||||
peer, ok := account.Peers[peerID]
|
||||
if !ok {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
|
||||
|
||||
@@ -22,8 +22,6 @@ import (
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -117,16 +115,6 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
ctrl2 := gomock.NewController(t)
|
||||
permissionsManager := permissions.NewMockManager(ctrl2)
|
||||
permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
permissionsManager.EXPECT().
|
||||
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)).
|
||||
DoAndReturn(func(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error) {
|
||||
user, ok := account.Users[userID]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("user not found")
|
||||
}
|
||||
return user.HasAdminPower() || user.IsServiceUser, nil
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
return &Handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
@@ -395,11 +383,12 @@ func TestGetAccessiblePeers(t *testing.T) {
|
||||
UserID: regularUser,
|
||||
}
|
||||
|
||||
p := initTestMetaData(t, peer1, peer2, peer3)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
peerID string
|
||||
callerUserID string
|
||||
viewBlocked bool
|
||||
expectedStatus int
|
||||
expectedPeers []string
|
||||
}{
|
||||
@@ -438,56 +427,10 @@ func TestGetAccessiblePeers(t *testing.T) {
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer1", "peer2"},
|
||||
},
|
||||
{
|
||||
name: "regular user gets empty for owned peer list when view blocked",
|
||||
peerID: "peer1",
|
||||
callerUserID: regularUser,
|
||||
viewBlocked: true,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{},
|
||||
},
|
||||
{
|
||||
name: "regular user gets empty list for unowned peer when view blocked",
|
||||
peerID: "peer2",
|
||||
callerUserID: regularUser,
|
||||
viewBlocked: true,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{},
|
||||
},
|
||||
{
|
||||
name: "admin user still sees accessible peers when view blocked",
|
||||
peerID: "peer2",
|
||||
callerUserID: adminUser,
|
||||
viewBlocked: true,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer1", "peer3"},
|
||||
},
|
||||
{
|
||||
name: "service user still sees accessible peers when view blocked",
|
||||
peerID: "peer3",
|
||||
callerUserID: serviceUser,
|
||||
viewBlocked: true,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer1", "peer2"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := initTestMetaData(t, peer1, peer2, peer3)
|
||||
|
||||
if tc.viewBlocked {
|
||||
mockAM := p.accountManager.(*mock_server.MockAccountManager)
|
||||
originalGetAccountByIDFunc := mockAM.GetAccountByIDFunc
|
||||
mockAM.GetAccountByIDFunc = func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
|
||||
account, err := originalGetAccountByIDFunc(ctx, accountID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
account.Settings.RegularUsersViewBlocked = true
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
|
||||
|
||||
@@ -561,99 +561,6 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri
|
||||
return account.Network.Copy(), err
|
||||
}
|
||||
|
||||
type peerAddAuthConfig struct {
|
||||
AccountID string
|
||||
SetupKeyID string
|
||||
SetupKeyName string
|
||||
GroupsToAdd []string
|
||||
AllowExtraDNSLabels bool
|
||||
Ephemeral bool
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) processPeerAddAuth(ctx context.Context, accountID, userID, encodedHashedKey string, peer *nbpeer.Peer, temporary, addedByUser, addedBySetupKey bool, opEvent *activity.Event) (*peerAddAuthConfig, error) {
|
||||
config := &peerAddAuthConfig{
|
||||
AccountID: accountID,
|
||||
Ephemeral: peer.Ephemeral,
|
||||
}
|
||||
|
||||
switch {
|
||||
case addedByUser:
|
||||
if err := am.handleUserAddedPeer(ctx, accountID, userID, temporary, opEvent, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case addedBySetupKey:
|
||||
if err := am.handleSetupKeyAddedPeer(ctx, encodedHashedKey, peer, opEvent, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
if peer.ProxyMeta.Embedded {
|
||||
log.WithContext(ctx).Debugf("adding peer for proxy embedded, accountID: %s", accountID)
|
||||
} else {
|
||||
log.WithContext(ctx).Warnf("adding peer without setup key or userID, accountID: %s", accountID)
|
||||
}
|
||||
}
|
||||
|
||||
opEvent.AccountID = config.AccountID
|
||||
if temporary {
|
||||
config.Ephemeral = true
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleUserAddedPeer(ctx context.Context, accountID, userID string, temporary bool, opEvent *activity.Event, config *peerAddAuthConfig) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "failed adding new peer: user not found")
|
||||
}
|
||||
if user.PendingApproval {
|
||||
return status.Errorf(status.PermissionDenied, "user pending approval cannot add peers")
|
||||
}
|
||||
|
||||
if temporary {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
} else {
|
||||
config.AccountID = user.AccountID
|
||||
config.GroupsToAdd = user.AutoGroups
|
||||
}
|
||||
|
||||
opEvent.InitiatorID = userID
|
||||
opEvent.Activity = activity.PeerAddedByUser
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleSetupKeyAddedPeer(ctx context.Context, encodedHashedKey string, peer *nbpeer.Peer, opEvent *activity.Event, config *peerAddAuthConfig) error {
|
||||
sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
if !sk.IsValid() {
|
||||
return status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
||||
}
|
||||
|
||||
opEvent.InitiatorID = sk.Id
|
||||
opEvent.Activity = activity.PeerAddedWithSetupKey
|
||||
config.GroupsToAdd = sk.AutoGroups
|
||||
config.Ephemeral = sk.Ephemeral
|
||||
config.SetupKeyID = sk.Id
|
||||
config.SetupKeyName = sk.Name
|
||||
config.AllowExtraDNSLabels = sk.AllowExtraDNSLabels
|
||||
config.AccountID = sk.AccountID
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeer adds a new peer to the Store.
|
||||
// Each Account has a list of pre-authorized SetupKey and if no Account has a given key err with a code status.PermissionDenied
|
||||
// will be returned, meaning the setup key is invalid or not found.
|
||||
@@ -689,12 +596,70 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
|
||||
var newPeer *nbpeer.Peer
|
||||
|
||||
peerAddConfig, err := am.processPeerAddAuth(ctx, accountID, userID, encodedHashedKey, peer, temporary, addedByUser, addedBySetupKey, opEvent)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
var setupKeyID string
|
||||
var setupKeyName string
|
||||
var groupsToAdd []string
|
||||
var allowExtraDNSLabels bool
|
||||
ephemeral := peer.Ephemeral
|
||||
switch {
|
||||
case addedByUser:
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: user not found")
|
||||
}
|
||||
if user.PendingApproval {
|
||||
return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers")
|
||||
}
|
||||
if temporary {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, nil, nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
} else {
|
||||
accountID = user.AccountID
|
||||
groupsToAdd = user.AutoGroups
|
||||
}
|
||||
opEvent.InitiatorID = userID
|
||||
opEvent.Activity = activity.PeerAddedByUser
|
||||
case addedBySetupKey:
|
||||
// Validate the setup key
|
||||
sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
// we will check key twice for early return
|
||||
if !sk.IsValid() {
|
||||
return nil, nil, nil, status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
opEvent.InitiatorID = sk.Id
|
||||
opEvent.Activity = activity.PeerAddedWithSetupKey
|
||||
groupsToAdd = sk.AutoGroups
|
||||
ephemeral = sk.Ephemeral
|
||||
setupKeyID = sk.Id
|
||||
setupKeyName = sk.Name
|
||||
allowExtraDNSLabels = sk.AllowExtraDNSLabels
|
||||
accountID = sk.AccountID
|
||||
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
||||
}
|
||||
default:
|
||||
if peer.ProxyMeta.Embedded {
|
||||
log.WithContext(ctx).Debugf("adding peer for proxy embedded, accountID: %s", accountID)
|
||||
} else {
|
||||
log.WithContext(ctx).Warnf("adding peer without setup key or userID, accountID: %s", accountID)
|
||||
}
|
||||
}
|
||||
opEvent.AccountID = accountID
|
||||
|
||||
if temporary {
|
||||
ephemeral = true
|
||||
}
|
||||
accountID = peerAddConfig.AccountID
|
||||
ephemeral := peerAddConfig.Ephemeral
|
||||
|
||||
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
||||
if am.idpManager != nil {
|
||||
@@ -728,7 +693,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
Location: peer.Location,
|
||||
InactivityExpirationEnabled: addedByUser && !temporary,
|
||||
ExtraDNSLabels: peer.ExtraDNSLabels,
|
||||
AllowExtraDNSLabels: peerAddConfig.AllowExtraDNSLabels,
|
||||
AllowExtraDNSLabels: allowExtraDNSLabels,
|
||||
}
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
@@ -746,7 +711,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
}
|
||||
}
|
||||
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, peerAddConfig.GroupsToAdd, settings.Extra, temporary)
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary)
|
||||
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
@@ -782,8 +747,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
return err
|
||||
}
|
||||
|
||||
if len(peerAddConfig.GroupsToAdd) > 0 {
|
||||
for _, g := range peerAddConfig.GroupsToAdd {
|
||||
if len(groupsToAdd) > 0 {
|
||||
for _, g := range groupsToAdd {
|
||||
err = transaction.AddPeerToGroup(ctx, newPeer.AccountID, newPeer.ID, g)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -815,7 +780,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
err = transaction.IncrementSetupKeyUsage(ctx, peerAddConfig.SetupKeyID)
|
||||
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
||||
}
|
||||
@@ -856,7 +821,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
opEvent.TargetID = newPeer.ID
|
||||
opEvent.Meta = newPeer.EventMeta(am.networkMapController.GetDNSDomain(settings))
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName
|
||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
@@ -2489,252 +2489,3 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
|
||||
_, _, _, err = manager.LoginPeer(context.Background(), login)
|
||||
require.NoError(t, err, "Regular user should be able to login peers")
|
||||
}
|
||||
|
||||
func TestHandleUserAddedPeer(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("regular user can add peer", func(t *testing.T) {
|
||||
regularUser := types.NewRegularUser("regular-user-1", "", "")
|
||||
regularUser.AccountID = account.Id
|
||||
regularUser.AutoGroups = []string{"group1", "group2"}
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
|
||||
err = manager.handleUserAddedPeer(context.Background(), account.Id, regularUser.Id, false, opEvent, config)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, account.Id, config.AccountID)
|
||||
assert.Equal(t, regularUser.AutoGroups, config.GroupsToAdd)
|
||||
assert.Equal(t, regularUser.Id, opEvent.InitiatorID)
|
||||
assert.Equal(t, activity.PeerAddedByUser, opEvent.Activity)
|
||||
})
|
||||
|
||||
t.Run("pending approval user cannot add peer", func(t *testing.T) {
|
||||
pendingUser := types.NewRegularUser("pending-user", "", "")
|
||||
pendingUser.AccountID = account.Id
|
||||
pendingUser.PendingApproval = true
|
||||
err = manager.Store.SaveUser(context.Background(), pendingUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
|
||||
err = manager.handleUserAddedPeer(context.Background(), account.Id, pendingUser.Id, false, opEvent, config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user pending approval cannot add peers")
|
||||
})
|
||||
|
||||
t.Run("user not found", func(t *testing.T) {
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
|
||||
err = manager.handleUserAddedPeer(context.Background(), account.Id, "non-existent-user", false, opEvent, config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user not found")
|
||||
})
|
||||
|
||||
t.Run("temporary peer requires permissions", func(t *testing.T) {
|
||||
regularUser := types.NewRegularUser("regular-user-2", "", "")
|
||||
regularUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
|
||||
// Should fail because user doesn't have permissions for temporary peers
|
||||
err = manager.handleUserAddedPeer(context.Background(), account.Id, regularUser.Id, true, opEvent, config)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleSetupKeyAddedPeer(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create admin user for setup key creation
|
||||
adminUser := types.NewAdminUser("admin-user")
|
||||
adminUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), adminUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("valid setup key", func(t *testing.T) {
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, []string{}, 0, adminUser.Id, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
upperKey := strings.ToUpper(setupKey.Key)
|
||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
peer := &nbpeer.Peer{ExtraDNSLabels: []string{}}
|
||||
|
||||
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, setupKey.Id, config.SetupKeyID)
|
||||
assert.Equal(t, setupKey.Name, config.SetupKeyName)
|
||||
assert.Equal(t, setupKey.AutoGroups, config.GroupsToAdd)
|
||||
assert.Equal(t, setupKey.Ephemeral, config.Ephemeral)
|
||||
assert.Equal(t, setupKey.Id, opEvent.InitiatorID)
|
||||
assert.Equal(t, activity.PeerAddedWithSetupKey, opEvent.Activity)
|
||||
})
|
||||
|
||||
t.Run("invalid setup key", func(t *testing.T) {
|
||||
invalidKey := "invalid-key"
|
||||
hashedKey := sha256.Sum256([]byte(invalidKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
peer := &nbpeer.Peer{}
|
||||
|
||||
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "setup key is invalid")
|
||||
})
|
||||
|
||||
t.Run("expired setup key", func(t *testing.T) {
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "expired-key", types.SetupKeyReusable, time.Millisecond, []string{}, 0, adminUser.Id, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for key to expire
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
upperKey := strings.ToUpper(setupKey.Key)
|
||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
peer := &nbpeer.Peer{}
|
||||
|
||||
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "setup key is invalid")
|
||||
})
|
||||
|
||||
t.Run("extra DNS labels not allowed", func(t *testing.T) {
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "no-dns-key", types.SetupKeyReusable, time.Hour, []string{}, 0, adminUser.Id, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
upperKey := strings.ToUpper(setupKey.Key)
|
||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
peer := &nbpeer.Peer{ExtraDNSLabels: []string{"custom.label"}}
|
||||
|
||||
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "doesn't allow extra DNS labels")
|
||||
})
|
||||
|
||||
t.Run("extra DNS labels allowed", func(t *testing.T) {
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "dns-key", types.SetupKeyReusable, time.Hour, []string{}, 0, adminUser.Id, false, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
upperKey := strings.ToUpper(setupKey.Key)
|
||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
opEvent := &activity.Event{}
|
||||
config := &peerAddAuthConfig{}
|
||||
peer := &nbpeer.Peer{ExtraDNSLabels: []string{"custom.label"}}
|
||||
|
||||
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, config.AllowExtraDNSLabels)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProcessPeerAddAuth(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
adminUser := types.NewAdminUser("admin")
|
||||
adminUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), adminUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("user authentication flow", func(t *testing.T) {
|
||||
regularUser := types.NewRegularUser("user-auth-test", "", "")
|
||||
regularUser.AccountID = account.Id
|
||||
regularUser.AutoGroups = []string{"group1"}
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
opEvent := &activity.Event{Timestamp: time.Now()}
|
||||
peer := &nbpeer.Peer{Ephemeral: false}
|
||||
|
||||
config, err := manager.processPeerAddAuth(context.Background(), account.Id, regularUser.Id, "", peer, false, true, false, opEvent)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, account.Id, config.AccountID)
|
||||
assert.False(t, config.Ephemeral)
|
||||
assert.Equal(t, regularUser.AutoGroups, config.GroupsToAdd)
|
||||
assert.Equal(t, account.Id, opEvent.AccountID)
|
||||
})
|
||||
|
||||
t.Run("setup key authentication flow", func(t *testing.T) {
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "auth-test-key", types.SetupKeyReusable, time.Hour, []string{}, 0, adminUser.Id, true, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
upperKey := strings.ToUpper(setupKey.Key)
|
||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
opEvent := &activity.Event{Timestamp: time.Now()}
|
||||
peer := &nbpeer.Peer{Ephemeral: false}
|
||||
|
||||
config, err := manager.processPeerAddAuth(context.Background(), account.Id, "", encodedHashedKey, peer, false, false, true, opEvent)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, account.Id, config.AccountID)
|
||||
assert.True(t, config.Ephemeral) // setupKey.Ephemeral is true
|
||||
assert.Equal(t, setupKey.AutoGroups, config.GroupsToAdd)
|
||||
assert.Equal(t, account.Id, opEvent.AccountID)
|
||||
})
|
||||
|
||||
t.Run("temporary flag overrides ephemeral", func(t *testing.T) {
|
||||
regularUser := types.NewRegularUser("temp-user", "", "")
|
||||
regularUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
opEvent := &activity.Event{Timestamp: time.Now()}
|
||||
peer := &nbpeer.Peer{Ephemeral: false}
|
||||
|
||||
config, err := manager.processPeerAddAuth(context.Background(), account.Id, regularUser.Id, "", peer, true, true, false, opEvent)
|
||||
require.Error(t, err) // Will fail permission check but that's expected
|
||||
_ = config // avoid unused warning
|
||||
})
|
||||
|
||||
t.Run("proxy embedded peer (no auth)", func(t *testing.T) {
|
||||
opEvent := &activity.Event{Timestamp: time.Now()}
|
||||
peer := &nbpeer.Peer{
|
||||
Ephemeral: false,
|
||||
ProxyMeta: nbpeer.ProxyMeta{Embedded: true},
|
||||
}
|
||||
|
||||
config, err := manager.processPeerAddAuth(context.Background(), account.Id, "", "", peer, false, false, false, opEvent)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, account.Id, config.AccountID)
|
||||
assert.False(t, config.Ephemeral)
|
||||
assert.Empty(t, config.GroupsToAdd)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2792,7 +2792,7 @@ func getGormConfig() *gorm.Config {
|
||||
|
||||
// newPostgresStore initializes a new Postgres store.
|
||||
func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics, skipMigration bool) (Store, error) {
|
||||
dsn, ok := lookupDSNEnv(postgresDsnEnv, postgresDsnEnvLegacy)
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s is not set", postgresDsnEnv)
|
||||
}
|
||||
@@ -2801,7 +2801,7 @@ func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics, skipMig
|
||||
|
||||
// newMysqlStore initializes a new MySQL store.
|
||||
func newMysqlStore(ctx context.Context, metrics telemetry.AppMetrics, skipMigration bool) (Store, error) {
|
||||
dsn, ok := lookupDSNEnv(mysqlDsnEnv, mysqlDsnEnvLegacy)
|
||||
dsn, ok := os.LookupEnv(mysqlDsnEnv)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
package store
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package store -destination=store_mock.go -source=./store.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
@@ -273,20 +271,10 @@ type Store interface {
|
||||
}
|
||||
|
||||
const (
|
||||
postgresDsnEnv = "NB_STORE_ENGINE_POSTGRES_DSN"
|
||||
postgresDsnEnvLegacy = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
mysqlDsnEnv = "NB_STORE_ENGINE_MYSQL_DSN"
|
||||
mysqlDsnEnvLegacy = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
|
||||
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
|
||||
)
|
||||
|
||||
// lookupDSNEnv checks the NB_ env var first, then falls back to the legacy NETBIRD_ env var.
|
||||
func lookupDSNEnv(nbKey, legacyKey string) (string, bool) {
|
||||
if v, ok := os.LookupEnv(nbKey); ok {
|
||||
return v, true
|
||||
}
|
||||
return os.LookupEnv(legacyKey)
|
||||
}
|
||||
|
||||
var supportedEngines = []types.Engine{types.SqliteStoreEngine, types.PostgresStoreEngine, types.MysqlStoreEngine}
|
||||
|
||||
func getStoreEngineFromEnv() types.Engine {
|
||||
@@ -571,7 +559,7 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine)
|
||||
}
|
||||
|
||||
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
|
||||
dsn, ok := lookupDSNEnv(postgresDsnEnv, postgresDsnEnvLegacy)
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok || dsn == "" {
|
||||
var err error
|
||||
_, dsn, err = testutil.CreatePostgresTestContainer()
|
||||
@@ -609,7 +597,7 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng
|
||||
}
|
||||
|
||||
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
|
||||
dsn, ok := lookupDSNEnv(mysqlDsnEnv, mysqlDsnEnvLegacy)
|
||||
dsn, ok := os.LookupEnv(mysqlDsnEnv)
|
||||
if !ok || dsn == "" {
|
||||
var err error
|
||||
_, dsn, err = testutil.CreateMysqlTestContainer()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -122,7 +122,6 @@ type defaultAppMetrics struct {
|
||||
Meter metric2.Meter
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
externallyManaged bool
|
||||
idpMetrics *IDPMetrics
|
||||
httpMiddleware *HTTPMiddleware
|
||||
grpcMetrics *GRPCMetrics
|
||||
@@ -172,9 +171,6 @@ func (appMetrics *defaultAppMetrics) Close() error {
|
||||
// Expose metrics on a given port and endpoint. If endpoint is empty a defaultEndpoint one will be used.
|
||||
// Exposes metrics in the Prometheus format https://prometheus.io/
|
||||
func (appMetrics *defaultAppMetrics) Expose(ctx context.Context, port int, endpoint string) error {
|
||||
if appMetrics.externallyManaged {
|
||||
return nil
|
||||
}
|
||||
if endpoint == "" {
|
||||
endpoint = defaultEndpoint
|
||||
}
|
||||
@@ -256,49 +252,3 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
|
||||
accountManagerMetrics: accountManagerMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewAppMetricsWithMeter creates AppMetrics using an externally provided meter.
|
||||
// The caller is responsible for exposing metrics via HTTP. Expose() and Close() are no-ops.
|
||||
func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetrics, error) {
|
||||
idpMetrics, err := NewIDPMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize IDP metrics: %w", err)
|
||||
}
|
||||
|
||||
middleware, err := NewMetricsMiddleware(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize HTTP middleware metrics: %w", err)
|
||||
}
|
||||
|
||||
grpcMetrics, err := NewGRPCMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize gRPC metrics: %w", err)
|
||||
}
|
||||
|
||||
storeMetrics, err := NewStoreMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize store metrics: %w", err)
|
||||
}
|
||||
|
||||
updateChannelMetrics, err := NewUpdateChannelMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize update channel metrics: %w", err)
|
||||
}
|
||||
|
||||
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
|
||||
}
|
||||
|
||||
return &defaultAppMetrics{
|
||||
Meter: meter,
|
||||
ctx: ctx,
|
||||
externallyManaged: true,
|
||||
idpMetrics: idpMetrics,
|
||||
httpMiddleware: middleware,
|
||||
grpcMetrics: grpcMetrics,
|
||||
storeMetrics: storeMetrics,
|
||||
updateChannelMetrics: updateChannelMetrics,
|
||||
accountManagerMetrics: accountManagerMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -374,6 +374,74 @@ func (a *Account) GetPeerNetworkMap(
|
||||
return nm
|
||||
}
|
||||
|
||||
// GetProxyConnectionResources returns ACL peers for the proxy-embedded peer based on exposed services.
|
||||
// No firewall rules are generated here; the proxy peer is always a new on-demand client with a stateful
|
||||
// firewall, so OUT rules are unnecessary. Inbound rules are handled on the target/router peer side.
|
||||
func (a *Account) GetProxyConnectionResources(ctx context.Context, exposedServices map[string][]*reverseproxy.Service) []*nbpeer.Peer {
|
||||
var aclPeers []*nbpeer.Peer
|
||||
|
||||
for _, peerServices := range exposedServices {
|
||||
for _, service := range peerServices {
|
||||
if !service.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, target := range service.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
}
|
||||
if target.TargetType == reverseproxy.TargetTypePeer {
|
||||
tpeer := a.GetPeer(target.TargetId)
|
||||
if tpeer == nil {
|
||||
continue
|
||||
}
|
||||
aclPeers = append(aclPeers, tpeer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return aclPeers
|
||||
}
|
||||
|
||||
// GetPeerProxyResources returns ACL peers and inbound firewall rules for a peer that is targeted by reverse proxy services.
|
||||
// Only IN rules are generated; OUT rules are omitted since proxy peers are always new clients with stateful firewalls.
|
||||
// Rules use PortRange only (not the legacy Port field) as this feature only targets current peer versions.
|
||||
func (a *Account) GetPeerProxyResources(peerID string, services []*reverseproxy.Service, proxyPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
var aclPeers []*nbpeer.Peer
|
||||
var firewallRules []*FirewallRule
|
||||
|
||||
for _, service := range services {
|
||||
if !service.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, target := range service.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
aclPeers = proxyPeers
|
||||
|
||||
needsPeerRules := (target.TargetType == reverseproxy.TargetTypePeer && target.TargetId == peerID) ||
|
||||
(target.TargetType == reverseproxy.TargetTypeHost || target.TargetType == reverseproxy.TargetTypeSubnet || target.TargetType == reverseproxy.TargetTypeDomain)
|
||||
|
||||
if needsPeerRules {
|
||||
for _, proxyPeer := range proxyPeers {
|
||||
firewallRules = append(firewallRules, &FirewallRule{
|
||||
PolicyID: "proxy-" + service.ID,
|
||||
PeerIP: proxyPeer.IP.String(),
|
||||
Direction: FirewallRuleDirectionIN,
|
||||
Action: "allow",
|
||||
Protocol: string(PolicyRuleProtocolTCP),
|
||||
PortRange: RulePortRange{Start: uint16(target.Port), End: uint16(target.Port)},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return aclPeers, firewallRules
|
||||
}
|
||||
|
||||
func (a *Account) addNetworksRoutingPeers(
|
||||
networkResourcesRoutes []*route.Route,
|
||||
peer *nbpeer.Peer,
|
||||
@@ -1796,6 +1864,71 @@ func (a *Account) GetProxyPeers() map[string][]*nbpeer.Peer {
|
||||
return proxyPeers
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerProxyRoutes(ctx context.Context, peer *nbpeer.Peer, proxies map[string][]*reverseproxy.Service, resourcesMap map[string]*resourceTypes.NetworkResource, routers map[string]map[string]*routerTypes.NetworkRouter, proxyPeers []*nbpeer.Peer) ([]*route.Route, []*RouteFirewallRule, []*nbpeer.Peer) {
|
||||
sourceRanges := make([]string, 0, len(proxyPeers))
|
||||
for _, proxyPeer := range proxyPeers {
|
||||
sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, proxyPeer.IP))
|
||||
}
|
||||
peers := make(map[string]*nbpeer.Peer, len(resourcesMap))
|
||||
|
||||
var routes []*route.Route
|
||||
var firewallRules []*RouteFirewallRule
|
||||
for _, proxyPerResource := range proxies {
|
||||
for _, proxy := range proxyPerResource {
|
||||
for _, target := range proxy.Targets {
|
||||
if target.TargetType == reverseproxy.TargetTypeHost || target.TargetType == reverseproxy.TargetTypeSubnet || target.TargetType == reverseproxy.TargetTypeDomain {
|
||||
resource, ok := resourcesMap[target.TargetId]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Warnf("proxy target %s not found in resources map", target.TargetId)
|
||||
continue
|
||||
}
|
||||
networkRouters, ok := routers[resource.NetworkID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Warnf("proxy target %s not found in routers map", target.TargetId)
|
||||
continue
|
||||
}
|
||||
for peerID, router := range networkRouters {
|
||||
routePeer := a.GetPeer(peerID)
|
||||
route := resource.ToRoute(routePeer, router)
|
||||
routes = append(routes, route)
|
||||
rule := RouteFirewallRule{
|
||||
PolicyID: fmt.Sprintf("proxy-%s-%s", proxy.ID, route.ID),
|
||||
RouteID: route.ID,
|
||||
SourceRanges: sourceRanges,
|
||||
Action: string(PolicyTrafficActionAccept),
|
||||
Destination: route.Network.String(),
|
||||
Protocol: string(PolicyRuleProtocolTCP),
|
||||
Domains: route.Domains,
|
||||
IsDynamic: route.IsDynamic(),
|
||||
PortRange: RulePortRange{
|
||||
Start: uint16(target.Port),
|
||||
End: uint16(target.Port),
|
||||
},
|
||||
}
|
||||
firewallRules = append(firewallRules, &rule)
|
||||
peers[peerID] = routePeer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resultPeers := make([]*nbpeer.Peer, 0, len(peers))
|
||||
for _, peer := range peers {
|
||||
resultPeers = append(resultPeers, peer)
|
||||
}
|
||||
|
||||
return routes, firewallRules, resultPeers
|
||||
}
|
||||
|
||||
func (a *Account) GetResourcesMap() map[string]*resourceTypes.NetworkResource {
|
||||
resourcesMap := make(map[string]*resourceTypes.NetworkResource, len(a.NetworkResources))
|
||||
for _, resource := range a.NetworkResources {
|
||||
resourcesMap[resource.ID] = resource
|
||||
}
|
||||
return resourcesMap
|
||||
}
|
||||
|
||||
func (a *Account) InjectProxyPolicies(ctx context.Context) {
|
||||
if len(a.Services) == 0 {
|
||||
return
|
||||
@@ -1810,83 +1943,62 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) {
|
||||
if !service.Enabled {
|
||||
continue
|
||||
}
|
||||
a.injectServiceProxyPolicies(ctx, service, proxyPeersByCluster)
|
||||
}
|
||||
}
|
||||
for _, target := range service.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
|
||||
for _, target := range service.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
}
|
||||
a.injectTargetProxyPolicies(ctx, service, target, proxyPeersByCluster[service.ProxyCluster])
|
||||
}
|
||||
}
|
||||
for _, proxyPeer := range proxyPeersByCluster[service.ProxyCluster] {
|
||||
port := target.Port
|
||||
if port == 0 {
|
||||
switch target.Protocol {
|
||||
case "https":
|
||||
port = 443
|
||||
case "http":
|
||||
port = 80
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("unsupported protocol %s for proxy target %s, skipping policy injection", target.Protocol, target.TargetId)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) {
|
||||
port, ok := a.resolveTargetPort(ctx, target)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
path := ""
|
||||
if target.Path != nil {
|
||||
path = *target.Path
|
||||
}
|
||||
|
||||
for _, proxyPeer := range proxyPeers {
|
||||
policy := a.createProxyPolicy(service, target, proxyPeer, port, path)
|
||||
a.Policies = append(a.Policies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) {
|
||||
if target.Port != 0 {
|
||||
return target.Port, true
|
||||
}
|
||||
|
||||
switch target.Protocol {
|
||||
case "https":
|
||||
return 443, true
|
||||
case "http":
|
||||
return 80, true
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("unsupported protocol %s for proxy target %s, skipping policy injection", target.Protocol, target.TargetId)
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
|
||||
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
|
||||
return &Policy{
|
||||
ID: policyID,
|
||||
Name: fmt.Sprintf("Proxy Access to %s", service.Name),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: policyID,
|
||||
PolicyID: policyID,
|
||||
Name: fmt.Sprintf("Allow access to %s", service.Name),
|
||||
Enabled: true,
|
||||
SourceResource: Resource{
|
||||
ID: proxyPeer.ID,
|
||||
Type: ResourceTypePeer,
|
||||
},
|
||||
DestinationResource: Resource{
|
||||
ID: target.TargetId,
|
||||
Type: ResourceType(target.TargetType),
|
||||
},
|
||||
Bidirectional: false,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
PortRanges: []RulePortRange{
|
||||
{
|
||||
Start: uint16(port),
|
||||
End: uint16(port),
|
||||
path := ""
|
||||
if target.Path != nil {
|
||||
path = *target.Path
|
||||
}
|
||||
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
|
||||
a.Policies = append(a.Policies, &Policy{
|
||||
ID: policyID,
|
||||
Name: fmt.Sprintf("Proxy Access to %s", service.Name),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: policyID,
|
||||
PolicyID: policyID,
|
||||
Name: fmt.Sprintf("Allow access to %s", service.Name),
|
||||
Enabled: true,
|
||||
SourceResource: Resource{
|
||||
ID: proxyPeer.ID,
|
||||
Type: ResourceTypePeer,
|
||||
},
|
||||
DestinationResource: Resource{
|
||||
ID: target.TargetId,
|
||||
Type: ResourceType(target.TargetType),
|
||||
},
|
||||
Bidirectional: false,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
PortRanges: []RulePortRange{
|
||||
{
|
||||
Start: uint16(port),
|
||||
End: uint16(port),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,19 +1,25 @@
|
||||
FROM golang:1.25-alpine AS builder
|
||||
WORKDIR /app
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o netbird-proxy ./proxy/cmd/proxy
|
||||
|
||||
RUN echo "netbird:x:1000:1000:netbird:/var/lib/netbird:/sbin/nologin" > /tmp/passwd && \
|
||||
echo "netbird:x:1000:netbird" > /tmp/group && \
|
||||
mkdir -p /tmp/var/lib/netbird && \
|
||||
mkdir -p /tmp/certs
|
||||
|
||||
FROM gcr.io/distroless/base:debug
|
||||
COPY netbird-proxy /go/bin/netbird-proxy
|
||||
COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy
|
||||
COPY --from=builder /tmp/passwd /etc/passwd
|
||||
COPY --from=builder /tmp/group /etc/group
|
||||
COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird
|
||||
COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs
|
||||
COPY --from=builder --chown=1000:1000 /tmp/certs /certs
|
||||
USER netbird:netbird
|
||||
ENV HOME=/var/lib/netbird
|
||||
ENV NB_PROXY_ADDRESS=":8443"
|
||||
EXPOSE 8443
|
||||
ENTRYPOINT ["/go/bin/netbird-proxy"]
|
||||
ENTRYPOINT ["/usr/bin/netbird-proxy"]
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
FROM golang:1.25-alpine AS builder
|
||||
WORKDIR /app
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY client ./client
|
||||
COPY dns ./dns
|
||||
COPY encryption ./encryption
|
||||
COPY flow ./flow
|
||||
COPY formatter ./formatter
|
||||
COPY monotime ./monotime
|
||||
COPY proxy ./proxy
|
||||
COPY route ./route
|
||||
COPY shared ./shared
|
||||
COPY sharedsock ./sharedsock
|
||||
COPY upload-server ./upload-server
|
||||
COPY util ./util
|
||||
COPY version ./version
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o netbird-proxy ./proxy/cmd/proxy
|
||||
|
||||
RUN echo "netbird:x:1000:1000:netbird:/var/lib/netbird:/sbin/nologin" > /tmp/passwd && \
|
||||
echo "netbird:x:1000:netbird" > /tmp/group && \
|
||||
mkdir -p /tmp/var/lib/netbird && \
|
||||
mkdir -p /tmp/certs
|
||||
|
||||
FROM gcr.io/distroless/base:debug
|
||||
COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy
|
||||
COPY --from=builder /tmp/passwd /etc/passwd
|
||||
COPY --from=builder /tmp/group /etc/group
|
||||
COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird
|
||||
COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs
|
||||
USER netbird:netbird
|
||||
ENV HOME=/var/lib/netbird
|
||||
ENV NB_PROXY_ADDRESS=":8443"
|
||||
EXPOSE 8443
|
||||
ENTRYPOINT ["/usr/bin/netbird-proxy"]
|
||||
@@ -6,14 +6,14 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/acme"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy"
|
||||
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@@ -39,20 +39,23 @@ var (
|
||||
addr string
|
||||
proxyDomain string
|
||||
certDir string
|
||||
acmeCerts bool
|
||||
acmeAddr string
|
||||
acmeDir string
|
||||
acmeChallengeType string
|
||||
acmeCerts bool
|
||||
acmeAddr string
|
||||
acmeDir string
|
||||
acmeChallengeType string
|
||||
debugEndpoint bool
|
||||
debugEndpointAddr string
|
||||
healthAddr string
|
||||
oidcClientID string
|
||||
oidcClientSecret string
|
||||
oidcEndpoint string
|
||||
oidcScopes string
|
||||
forwardedProto string
|
||||
trustedProxies string
|
||||
certFile string
|
||||
certKeyFile string
|
||||
certLockMethod string
|
||||
wgPort int
|
||||
proxyProtocol bool
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
@@ -77,13 +80,16 @@ func init() {
|
||||
rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint")
|
||||
rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint")
|
||||
rootCmd.Flags().StringVar(&healthAddr, "health-addr", envStringOrDefault("NB_PROXY_HEALTH_ADDRESS", "localhost:8080"), "Address for the health probe endpoint (liveness/readiness/startup)")
|
||||
rootCmd.Flags().StringVar(&oidcClientID, "oidc-id", envStringOrDefault("NB_PROXY_OIDC_CLIENT_ID", "netbird-proxy"), "The OAuth2 Client ID for OIDC User Authentication")
|
||||
rootCmd.Flags().StringVar(&oidcClientSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication")
|
||||
rootCmd.Flags().StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication")
|
||||
rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
|
||||
rootCmd.Flags().StringVar(&forwardedProto, "forwarded-proto", envStringOrDefault("NB_PROXY_FORWARDED_PROTO", "auto"), "X-Forwarded-Proto value for backends: auto, http, or https")
|
||||
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", envStringOrDefault("NB_PROXY_TRUSTED_PROXIES", ""), "Comma-separated list of trusted upstream proxy CIDR ranges (e.g. '10.0.0.0/8,192.168.1.1')")
|
||||
rootCmd.Flags().StringVar(&certFile, "cert-file", envStringOrDefault("NB_PROXY_CERTIFICATE_FILE", "tls.crt"), "TLS certificate filename within the certificate directory")
|
||||
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
|
||||
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
|
||||
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
|
||||
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
|
||||
}
|
||||
|
||||
// Execute runs the root command.
|
||||
@@ -117,7 +123,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
|
||||
_ = util.InitLogger(logger, level, util.LogConsole)
|
||||
|
||||
logger.Infof("configured log level: %s", level)
|
||||
log.Infof("configured log level: %s", level)
|
||||
|
||||
switch forwardedProto {
|
||||
case "auto", "http", "https":
|
||||
@@ -151,18 +157,21 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
DebugEndpointEnabled: debugEndpoint,
|
||||
DebugEndpointAddress: debugEndpointAddr,
|
||||
HealthAddress: healthAddr,
|
||||
OIDCClientId: oidcClientID,
|
||||
OIDCClientSecret: oidcClientSecret,
|
||||
OIDCEndpoint: oidcEndpoint,
|
||||
OIDCScopes: strings.Split(oidcScopes, ","),
|
||||
ForwardedProto: forwardedProto,
|
||||
TrustedProxies: parsedTrustedProxies,
|
||||
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
||||
WireguardPort: wgPort,
|
||||
ProxyProtocol: proxyProtocol,
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||
defer stop()
|
||||
|
||||
if err := srv.ListenAndServe(ctx, addr); err != nil {
|
||||
logger.Error(err)
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
@@ -28,8 +27,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
|
||||
// Use a response writer wrapper so we can access the status code later.
|
||||
sw := &statusWriter{
|
||||
PassthroughWriter: responsewriter.New(w),
|
||||
status: http.StatusOK,
|
||||
w: w,
|
||||
status: http.StatusOK,
|
||||
}
|
||||
|
||||
// Resolve the source IP using trusted proxy configuration before passing
|
||||
|
||||
@@ -1,18 +1,26 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// statusWriter captures the HTTP status code from WriteHeader calls.
|
||||
// It embeds responsewriter.PassthroughWriter which handles all the optional
|
||||
// interfaces (Hijacker, Flusher, Pusher) automatically.
|
||||
// statusWriter is a simple wrapper around an http.ResponseWriter
|
||||
// that captures the setting of the status code via the WriteHeader
|
||||
// function and stores it so that it can be retrieved later.
|
||||
type statusWriter struct {
|
||||
*responsewriter.PassthroughWriter
|
||||
w http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *statusWriter) Header() http.Header {
|
||||
return w.w.Header()
|
||||
}
|
||||
|
||||
func (w *statusWriter) Write(data []byte) (int, error) {
|
||||
return w.w.Write(data)
|
||||
}
|
||||
|
||||
func (w *statusWriter) WriteHeader(status int) {
|
||||
w.status = status
|
||||
w.PassthroughWriter.WriteHeader(status)
|
||||
w.w.WriteHeader(status)
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func (l *flockLocker) Lock(ctx context.Context, domain string) (func(), error) {
|
||||
|
||||
// nil lockFile means locking is not supported (non-unix).
|
||||
if lockFile == nil {
|
||||
return func() { /* no-op: locking unsupported on this platform */ }, nil
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
return func() {
|
||||
@@ -98,5 +98,5 @@ type noopLocker struct{}
|
||||
|
||||
// Lock is a no-op that always succeeds immediately.
|
||||
func (noopLocker) Lock(context.Context, string) (func(), error) {
|
||||
return func() { /* no-op: locker disabled */ }, nil
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
@@ -90,8 +90,10 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
if err != nil {
|
||||
host = r.Host
|
||||
}
|
||||
mw.domainsMux.RLock()
|
||||
config, exists := mw.domains[host]
|
||||
mw.domainsMux.RUnlock()
|
||||
|
||||
config, exists := mw.getDomainConfig(host)
|
||||
mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)
|
||||
|
||||
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
|
||||
@@ -101,160 +103,115 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// Set account and service IDs in captured data for access logging.
|
||||
setCapturedIDs(r, config)
|
||||
|
||||
if mw.handleOAuthCallbackError(w, r) {
|
||||
return
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetAccountId(types.AccountID(config.AccountID))
|
||||
cd.SetServiceId(config.ServiceID)
|
||||
}
|
||||
|
||||
if mw.forwardWithSessionCookie(w, r, host, config, next) {
|
||||
return
|
||||
}
|
||||
|
||||
mw.authenticateWithSchemes(w, r, host, config)
|
||||
})
|
||||
}
|
||||
|
||||
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
|
||||
mw.domainsMux.RLock()
|
||||
defer mw.domainsMux.RUnlock()
|
||||
config, exists := mw.domains[host]
|
||||
return config, exists
|
||||
}
|
||||
|
||||
func setCapturedIDs(r *http.Request, config DomainConfig) {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetAccountId(types.AccountID(config.AccountID))
|
||||
cd.SetServiceId(config.ServiceID)
|
||||
}
|
||||
}
|
||||
|
||||
// handleOAuthCallbackError checks for error query parameters from an OAuth
|
||||
// callback and renders the access denied page if present.
|
||||
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
|
||||
errCode := r.URL.Query().Get("error")
|
||||
if errCode == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(auth.MethodOIDC.String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
errDesc := r.URL.Query().Get("error_description")
|
||||
if errDesc == "" {
|
||||
errDesc = "An error occurred during authentication"
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardWithSessionCookie checks for a valid session cookie and, if found,
|
||||
// sets the user identity on the request context and forwards to the next handler.
|
||||
func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
|
||||
cookie, err := r.Cookie(auth.SessionCookieName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(userID)
|
||||
cd.SetAuthMethod(method)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
// authenticateWithSchemes tries each configured auth scheme in order.
|
||||
// On success it sets a session cookie and redirects; on failure it renders the login page.
|
||||
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
|
||||
methods := make(map[string]string)
|
||||
var attemptedMethod string
|
||||
|
||||
for _, scheme := range config.Schemes {
|
||||
token, promptData, err := scheme.Authenticate(r)
|
||||
if err != nil {
|
||||
mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
|
||||
// Check for error from OAuth callback (e.g., access denied)
|
||||
if errCode := r.URL.Query().Get("error"); errCode != "" {
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(auth.MethodOIDC.String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
|
||||
errDesc := r.URL.Query().Get("error_description")
|
||||
if errDesc == "" {
|
||||
errDesc = "An error occurred during authentication"
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
|
||||
return
|
||||
}
|
||||
|
||||
// Track if credentials were submitted but auth failed
|
||||
if token == "" && wasCredentialSubmitted(r, scheme.Type()) {
|
||||
attemptedMethod = scheme.Type().String()
|
||||
// Check for an existing session cookie (contains JWT)
|
||||
if cookie, err := r.Cookie(auth.SessionCookieName); err == nil {
|
||||
if userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey); err == nil {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(userID)
|
||||
cd.SetAuthMethod(method)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
mw.handleAuthenticatedToken(w, r, host, token, config, scheme)
|
||||
return
|
||||
}
|
||||
methods[scheme.Type().String()] = promptData
|
||||
}
|
||||
// Try to authenticate with each scheme.
|
||||
methods := make(map[string]string)
|
||||
var attemptedMethod string
|
||||
for _, scheme := range config.Schemes {
|
||||
token, promptData, err := scheme.Authenticate(r)
|
||||
if err != nil {
|
||||
mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
}
|
||||
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
if attemptedMethod != "" {
|
||||
cd.SetAuthMethod(attemptedMethod)
|
||||
}
|
||||
}
|
||||
web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized)
|
||||
}
|
||||
// Track if credentials were submitted but auth failed
|
||||
if token == "" && wasCredentialSubmitted(r, scheme.Type()) {
|
||||
attemptedMethod = scheme.Type().String()
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
|
||||
if err != nil {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !result.Valid {
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
|
||||
return
|
||||
}
|
||||
|
||||
expiration := config.SessionExpiration
|
||||
if expiration == 0 {
|
||||
expiration = auth.DefaultSessionExpiry
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.SessionCookieName,
|
||||
Value: token,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(expiration.Seconds()),
|
||||
})
|
||||
|
||||
// Redirect instead of forwarding the auth POST to the backend.
|
||||
// The browser will follow with a GET carrying the new session cookie.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
methods[scheme.Type().String()] = promptData
|
||||
}
|
||||
|
||||
// handleAuthenticatedToken validates the token, handles denied access, and on
|
||||
// success sets a session cookie and redirects to the original URL.
|
||||
func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Request, host, token string, config DomainConfig, scheme Scheme) {
|
||||
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
|
||||
if err != nil {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
if attemptedMethod != "" {
|
||||
cd.SetAuthMethod(attemptedMethod)
|
||||
}
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
|
||||
return
|
||||
}
|
||||
|
||||
expiration := config.SessionExpiration
|
||||
if expiration == 0 {
|
||||
expiration = auth.DefaultSessionExpiry
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.SessionCookieName,
|
||||
Value: token,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(expiration.Seconds()),
|
||||
web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
// Redirect instead of forwarding the auth POST to the backend.
|
||||
// The browser will follow with a GET carrying the new session cookie.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// wasCredentialSubmitted checks if credentials were submitted for the given auth method.
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// trackedConn wraps a net.Conn and removes itself from the tracker on Close.
|
||||
type trackedConn struct {
|
||||
net.Conn
|
||||
tracker *HijackTracker
|
||||
}
|
||||
|
||||
func (c *trackedConn) Close() error {
|
||||
c.tracker.conns.Delete(c)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// trackingWriter wraps an http.ResponseWriter and intercepts Hijack calls
|
||||
// to replace the raw connection with a trackedConn that auto-deregisters.
|
||||
type trackingWriter struct {
|
||||
http.ResponseWriter
|
||||
tracker *HijackTracker
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := w.ResponseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
conn, buf, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tc := &trackedConn{Conn: conn, tracker: w.tracker}
|
||||
w.tracker.conns.Store(tc, struct{}{})
|
||||
return tc, buf, nil
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// HijackTracker tracks connections that have been hijacked (e.g. WebSocket
|
||||
// upgrades). http.Server.Shutdown does not close hijacked connections, so
|
||||
// they must be tracked and closed explicitly during graceful shutdown.
|
||||
//
|
||||
// Use Middleware as the outermost HTTP middleware to ensure hijacked
|
||||
// connections are tracked and automatically deregistered when closed.
|
||||
type HijackTracker struct {
|
||||
conns sync.Map // net.Conn → struct{}
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that wraps the ResponseWriter so that
|
||||
// hijacked connections are tracked and automatically deregistered from the
|
||||
// tracker when closed. This should be the outermost middleware in the chain.
|
||||
func (t *HijackTracker) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r)
|
||||
})
|
||||
}
|
||||
|
||||
// CloseAll closes all tracked hijacked connections and returns the number
|
||||
// of connections that were closed.
|
||||
func (t *HijackTracker) CloseAll() int {
|
||||
var count int
|
||||
t.conns.Range(func(key, _ any) bool {
|
||||
if conn, ok := key.(net.Conn); ok {
|
||||
_ = conn.Close()
|
||||
count++
|
||||
}
|
||||
t.conns.Delete(key)
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
@@ -83,10 +83,6 @@ func (c *Client) printHealth(data map[string]any) {
|
||||
}
|
||||
}
|
||||
|
||||
c.printHealthClients(data)
|
||||
}
|
||||
|
||||
func (c *Client) printHealthClients(data map[string]any) {
|
||||
clients, ok := data["clients"].(map[string]any)
|
||||
if !ok || len(clients) == 0 {
|
||||
return
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPrintHealth_WithCertsAndClients(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
c := NewClient("localhost:8444", false, &buf)
|
||||
|
||||
data := map[string]any{
|
||||
"status": "ok",
|
||||
"uptime": "1h30m",
|
||||
"management_connected": true,
|
||||
"all_clients_healthy": true,
|
||||
"certs_total": float64(3),
|
||||
"certs_ready": float64(2),
|
||||
"certs_pending": float64(1),
|
||||
"certs_failed": float64(0),
|
||||
"certs_ready_domains": []any{"a.example.com", "b.example.com"},
|
||||
"certs_pending_domains": []any{"c.example.com"},
|
||||
"clients": map[string]any{
|
||||
"acc-1": map[string]any{
|
||||
"healthy": true,
|
||||
"management_connected": true,
|
||||
"signal_connected": true,
|
||||
"relays_connected": float64(1),
|
||||
"relays_total": float64(2),
|
||||
"peers_connected": float64(3),
|
||||
"peers_total": float64(5),
|
||||
"peers_p2p": float64(2),
|
||||
"peers_relayed": float64(1),
|
||||
"peers_degraded": float64(0),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
c.printHealth(data)
|
||||
out := buf.String()
|
||||
|
||||
assert.Contains(t, out, "Status: ok")
|
||||
assert.Contains(t, out, "Uptime: 1h30m")
|
||||
assert.Contains(t, out, "yes") // management_connected
|
||||
assert.Contains(t, out, "2 ready, 1 pending, 0 failed (3 total)")
|
||||
assert.Contains(t, out, "a.example.com")
|
||||
assert.Contains(t, out, "c.example.com")
|
||||
assert.Contains(t, out, "acc-1")
|
||||
}
|
||||
|
||||
func TestPrintHealth_Minimal(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
c := NewClient("localhost:8444", false, &buf)
|
||||
|
||||
data := map[string]any{
|
||||
"status": "ok",
|
||||
"uptime": "5m",
|
||||
"management_connected": false,
|
||||
"all_clients_healthy": false,
|
||||
}
|
||||
|
||||
c.printHealth(data)
|
||||
out := buf.String()
|
||||
|
||||
assert.Contains(t, out, "Status: ok")
|
||||
assert.Contains(t, out, "Uptime: 5m")
|
||||
assert.NotContains(t, out, "Certificates")
|
||||
assert.NotContains(t, out, "ACCOUNT ID")
|
||||
}
|
||||
@@ -17,11 +17,11 @@
|
||||
<h2>Client Control</h2>
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<span> </span>
|
||||
<label> </label>
|
||||
<button onclick="startClient()">Start</button>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<span> </span>
|
||||
<label> </label>
|
||||
<button onclick="stopClient()">Stop</button>
|
||||
</div>
|
||||
</div>
|
||||
@@ -30,7 +30,7 @@
|
||||
<h2>Log Level</h2>
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<label for="log-level">Level</label>
|
||||
<label>Level</label>
|
||||
<select id="log-level" style="width: 120px;">
|
||||
<option value="trace">trace</option>
|
||||
<option value="debug">debug</option>
|
||||
@@ -40,7 +40,7 @@
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<span> </span>
|
||||
<label> </label>
|
||||
<button onclick="setLogLevel()">Set Level</button>
|
||||
</div>
|
||||
</div>
|
||||
@@ -49,15 +49,15 @@
|
||||
<h2>TCP Ping</h2>
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<label for="tcp-host">Host</label>
|
||||
<label>Host</label>
|
||||
<input type="text" id="tcp-host" placeholder="100.0.0.1 or hostname.netbird.cloud" style="width: 300px;">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="tcp-port">Port</label>
|
||||
<label>Port</label>
|
||||
<input type="number" id="tcp-port" placeholder="80" style="width: 80px;">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<span> </span>
|
||||
<label> </label>
|
||||
<button onclick="doTcpPing()">Connect</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
// that no lock was acquired; callers must treat a nil file as "proceed
|
||||
// without lock" rather than "lock held by someone else."
|
||||
func Lock(_ context.Context, _ string) (*os.File, error) {
|
||||
return nil, nil //nolint:nilnil // intentional: nil file signals locking unsupported on this platform
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Unlock is a no-op on non-Unix platforms.
|
||||
|
||||
@@ -323,7 +323,7 @@ func NewServer(addr string, checker *Checker, logger *log.Logger, metricsHandler
|
||||
if metricsHandler != nil {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/metrics", metricsHandler)
|
||||
mux.Handle("/", handler)
|
||||
mux.Handle("/", checker.Handler())
|
||||
handler = mux
|
||||
}
|
||||
|
||||
|
||||
@@ -404,70 +404,3 @@ func TestChecker_Handler_Full(t *testing.T) {
|
||||
// Clients may be empty map when no clients exist.
|
||||
assert.Empty(t, resp.Clients)
|
||||
}
|
||||
|
||||
func TestChecker_SetShuttingDown(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
|
||||
assert.True(t, checker.ReadinessProbe(), "should be ready before shutdown")
|
||||
|
||||
checker.SetShuttingDown()
|
||||
|
||||
assert.False(t, checker.ReadinessProbe(), "should not be ready after shutdown")
|
||||
}
|
||||
|
||||
func TestChecker_Handler_Readiness_ShuttingDown(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetShuttingDown()
|
||||
handler := checker.Handler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
|
||||
var resp ProbeResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
assert.Equal(t, "fail", resp.Status)
|
||||
}
|
||||
|
||||
func TestNewServer_WithMetricsHandler(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
|
||||
metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("metrics"))
|
||||
})
|
||||
|
||||
srv := NewServer(":0", checker, nil, metricsHandler)
|
||||
require.NotNil(t, srv)
|
||||
|
||||
// Verify health endpoint still works through the mux.
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/live", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.server.Handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Verify metrics endpoint is mounted.
|
||||
req = httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
srv.server.Handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "metrics", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestNewServer_WithoutMetricsHandler(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
|
||||
srv := NewServer(":0", checker, nil, nil)
|
||||
require.NotNil(t, srv)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/live", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.server.Handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
@@ -5,11 +5,9 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
@@ -62,18 +60,18 @@ func New(reg prometheus.Registerer) *Metrics {
|
||||
}
|
||||
|
||||
type responseInterceptor struct {
|
||||
*responsewriter.PassthroughWriter
|
||||
http.ResponseWriter
|
||||
status int
|
||||
size int
|
||||
}
|
||||
|
||||
func (w *responseInterceptor) WriteHeader(status int) {
|
||||
w.status = status
|
||||
w.PassthroughWriter.WriteHeader(status)
|
||||
w.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
func (w *responseInterceptor) Write(b []byte) (int, error) {
|
||||
size, err := w.PassthroughWriter.Write(b)
|
||||
size, err := w.ResponseWriter.Write(b)
|
||||
w.size += size
|
||||
return size, err
|
||||
}
|
||||
@@ -83,7 +81,7 @@ func (m *Metrics) Middleware(next http.Handler) http.Handler {
|
||||
m.requestsTotal.Inc()
|
||||
m.activeRequests.Inc()
|
||||
|
||||
interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)}
|
||||
interceptor := &responseInterceptor{ResponseWriter: w}
|
||||
|
||||
start := time.Now()
|
||||
next.ServeHTTP(interceptor, r)
|
||||
|
||||
@@ -53,9 +53,6 @@ func TestMetrics_RoundTripper(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
rt := m.RoundTripper(test.roundTripper)
|
||||
res, err := rt.RoundTrip(test.request)
|
||||
if res != nil && res.Body != nil {
|
||||
defer res.Body.Close()
|
||||
}
|
||||
if diff := cmp.Diff(test.err, err); diff != "" {
|
||||
t.Errorf("Incorrect error (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -320,8 +320,7 @@ func getRequestID(r *http.Request) string {
|
||||
// status code, and component status based on the error type.
|
||||
func classifyProxyError(err error) (title, message string, code int, status web.ErrorStatus) {
|
||||
switch {
|
||||
case errors.Is(err, context.DeadlineExceeded),
|
||||
isNetTimeout(err):
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
return "Request Timeout",
|
||||
"The request timed out while trying to reach the service. Please refresh the page and try again.",
|
||||
http.StatusGatewayTimeout,
|
||||
@@ -346,12 +345,6 @@ func classifyProxyError(err error) (title, message string, code int, status web.
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: false, Destination: false}
|
||||
|
||||
case errors.Is(err, roundtrip.ErrTooManyInflight):
|
||||
return "Service Overloaded",
|
||||
"The service is currently handling too many requests. Please try again shortly.",
|
||||
http.StatusServiceUnavailable,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
|
||||
case isConnectionRefused(err):
|
||||
return "Service Unavailable",
|
||||
"The connection to the service was refused. Please verify that the service is running and try again.",
|
||||
@@ -363,6 +356,12 @@ func classifyProxyError(err error) (title, message string, code int, status web.
|
||||
"The connection to the peer could not be established. Please ensure the peer is running and connected to the NetBird network.",
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
|
||||
case isNetTimeout(err):
|
||||
return "Request Timeout",
|
||||
"The request timed out while trying to reach the service. Please refresh the page and try again.",
|
||||
http.StatusGatewayTimeout,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
}
|
||||
|
||||
return "Connection Error",
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
package responsewriter
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// PassthroughWriter wraps an http.ResponseWriter and preserves optional
|
||||
// interfaces like Hijacker, Flusher, and Pusher by delegating to the underlying
|
||||
// ResponseWriter if it supports them.
|
||||
//
|
||||
// This is the standard pattern for Go middleware that needs to wrap ResponseWriter
|
||||
// while maintaining support for protocol upgrades (WebSocket), streaming (Flusher),
|
||||
// and HTTP/2 server push.
|
||||
type PassthroughWriter struct {
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
// New creates a new wrapper around the given ResponseWriter.
|
||||
func New(w http.ResponseWriter) *PassthroughWriter {
|
||||
return &PassthroughWriter{ResponseWriter: w}
|
||||
}
|
||||
|
||||
// Hijack implements http.Hijacker interface if the underlying ResponseWriter supports it.
|
||||
// This is required for WebSocket connections and other protocol upgrades.
|
||||
func (w *PassthroughWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacker, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
// Flush implements http.Flusher interface if the underlying ResponseWriter supports it.
|
||||
func (w *PassthroughWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Push implements http.Pusher interface if the underlying ResponseWriter supports it.
|
||||
func (w *PassthroughWriter) Push(target string, opts *http.PushOptions) error {
|
||||
if pusher, ok := w.ResponseWriter.(http.Pusher); ok {
|
||||
return pusher.Push(target, opts)
|
||||
}
|
||||
return http.ErrNotSupported
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying ResponseWriter.
|
||||
// This is required for http.ResponseController (Go 1.20+) to work correctly.
|
||||
func (w *PassthroughWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
@@ -24,9 +24,6 @@ import (
|
||||
|
||||
const deviceNamePrefix = "ingress-proxy-"
|
||||
|
||||
// backendKey identifies a backend by its host:port from the target URL.
|
||||
type backendKey = string
|
||||
|
||||
var (
|
||||
// ErrNoAccountID is returned when a request context is missing the account ID.
|
||||
ErrNoAccountID = errors.New("no account ID in request context")
|
||||
@@ -34,8 +31,6 @@ var (
|
||||
ErrNoPeerConnection = errors.New("no peer connection found")
|
||||
// ErrClientStartFailed is returned when the embedded client fails to start.
|
||||
ErrClientStartFailed = errors.New("client start failed")
|
||||
// ErrTooManyInflight is returned when the per-backend in-flight limit is reached.
|
||||
ErrTooManyInflight = errors.New("too many in-flight requests")
|
||||
)
|
||||
|
||||
// domainInfo holds metadata about a registered domain.
|
||||
@@ -43,11 +38,6 @@ type domainInfo struct {
|
||||
serviceID string
|
||||
}
|
||||
|
||||
type domainNotification struct {
|
||||
domain domain.Domain
|
||||
serviceID string
|
||||
}
|
||||
|
||||
// clientEntry holds an embedded NetBird client and tracks which domains use it.
|
||||
type clientEntry struct {
|
||||
client *embed.Client
|
||||
@@ -55,35 +45,6 @@ type clientEntry struct {
|
||||
domains map[domain.Domain]domainInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
// Per-backend in-flight limiting keyed by target host:port.
|
||||
// TODO: clean up stale entries when backend targets change.
|
||||
inflightMu sync.Mutex
|
||||
inflightMap map[backendKey]chan struct{}
|
||||
maxInflight int
|
||||
}
|
||||
|
||||
// acquireInflight attempts to acquire an in-flight slot for the given backend.
|
||||
// It returns a release function that must always be called, and true on success.
|
||||
func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bool) {
|
||||
noop := func() {}
|
||||
if e.maxInflight <= 0 {
|
||||
return noop, true
|
||||
}
|
||||
|
||||
e.inflightMu.Lock()
|
||||
sem, exists := e.inflightMap[backend]
|
||||
if !exists {
|
||||
sem = make(chan struct{}, e.maxInflight)
|
||||
e.inflightMap[backend] = sem
|
||||
}
|
||||
e.inflightMu.Unlock()
|
||||
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
return func() { <-sem }, true
|
||||
default:
|
||||
return noop, false
|
||||
}
|
||||
}
|
||||
|
||||
type statusNotifier interface {
|
||||
@@ -98,13 +59,12 @@ type managementClient interface {
|
||||
// backed by underlying NetBird connections.
|
||||
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
|
||||
type NetBird struct {
|
||||
mgmtAddr string
|
||||
proxyID string
|
||||
proxyAddr string
|
||||
wgPort int
|
||||
logger *log.Logger
|
||||
mgmtClient managementClient
|
||||
transportCfg transportConfig
|
||||
mgmtAddr string
|
||||
proxyID string
|
||||
proxyAddr string
|
||||
wgPort int
|
||||
logger *log.Logger
|
||||
mgmtClient managementClient
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[types.AccountID]*clientEntry
|
||||
@@ -154,30 +114,6 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
|
||||
return nil
|
||||
}
|
||||
|
||||
entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
// retry on the first request via RoundTrip.
|
||||
go n.runClientStartup(ctx, accountID, entry.client)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": serviceID,
|
||||
@@ -185,7 +121,8 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
|
||||
privateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate wireguard private key: %w", err)
|
||||
n.clientsMux.Unlock()
|
||||
return fmt.Errorf("generate wireguard private key: %w", err)
|
||||
}
|
||||
publicKey := privateKey.PublicKey()
|
||||
|
||||
@@ -195,6 +132,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
"public_key": publicKey.String(),
|
||||
}).Debug("authenticating new proxy peer with management")
|
||||
|
||||
// Authenticate with management using the one-time token and send public key
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
|
||||
ServiceId: serviceID,
|
||||
AccountId: string(accountID),
|
||||
@@ -203,14 +141,16 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
Cluster: n.proxyAddr,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate proxy peer with management: %w", err)
|
||||
n.clientsMux.Unlock()
|
||||
return fmt.Errorf("authenticate proxy peer with management: %w", err)
|
||||
}
|
||||
if resp != nil && !resp.GetSuccess() {
|
||||
n.clientsMux.Unlock()
|
||||
errMsg := "unknown error"
|
||||
if resp.ErrorMessage != nil {
|
||||
errMsg = *resp.ErrorMessage
|
||||
}
|
||||
return nil, fmt.Errorf("proxy peer authentication failed: %s", errMsg)
|
||||
return fmt.Errorf("proxy peer authentication failed: %s", errMsg)
|
||||
}
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -236,80 +176,95 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
WireguardPort: &n.wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create netbird client: %w", err)
|
||||
n.clientsMux.Unlock()
|
||||
return fmt.Errorf("create netbird client: %w", err)
|
||||
}
|
||||
|
||||
// Create a transport using the client dialer. We do this instead of using
|
||||
// the client's HTTPClient to avoid issues with request validation that do
|
||||
// not work with reverse proxied requests.
|
||||
return &clientEntry{
|
||||
entry = &clientEntry{
|
||||
client: client,
|
||||
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
|
||||
transport: &http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: n.transportCfg.maxIdleConns,
|
||||
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
|
||||
IdleConnTimeout: n.transportCfg.idleConnTimeout,
|
||||
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
|
||||
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
|
||||
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
|
||||
WriteBufferSize: n.transportCfg.writeBufferSize,
|
||||
ReadBufferSize: n.transportCfg.readBufferSize,
|
||||
DisableCompression: n.transportCfg.disableCompression,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
started: false,
|
||||
inflightMap: make(map[backendKey]chan struct{}),
|
||||
maxInflight: n.transportCfg.maxInflight,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// runClientStartup starts the client and notifies registered domains on success.
|
||||
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
n.logger.WithField("account_id", accountID).Warn("netbird client start timed out, will retry on first request")
|
||||
} else {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Error("failed to start netbird client")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Mark client as started and collect domains to notify outside the lock.
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.started = true
|
||||
}
|
||||
var domainsToNotify []domainNotification
|
||||
if exists {
|
||||
for dom, info := range entry.domains {
|
||||
domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
|
||||
}
|
||||
createdAt: time.Now(),
|
||||
started: false,
|
||||
}
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
}
|
||||
for _, dn := range domainsToNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": dn.domain,
|
||||
}).WithError(err).Warn("failed to notify tunnel connection status")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": dn.domain,
|
||||
}).Info("notified management about tunnel connection")
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background, if this fails
|
||||
// then it is not ideal, but it isn't the end of the world because
|
||||
// we will try to start the client again before we use it.
|
||||
go func() {
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).Warn("netbird client start timed out, will retry on first request")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Error("failed to start netbird client")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Mark client as started and notify all registered domains
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.started = true
|
||||
}
|
||||
// Copy domain info while holding lock
|
||||
var domainsToNotify []struct {
|
||||
domain domain.Domain
|
||||
serviceID string
|
||||
}
|
||||
if exists {
|
||||
for dom, info := range entry.domains {
|
||||
domainsToNotify = append(domainsToNotify, struct {
|
||||
domain domain.Domain
|
||||
serviceID string
|
||||
}{domain: dom, serviceID: info.serviceID})
|
||||
}
|
||||
}
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// Notify all domains that they're connected
|
||||
if n.statusNotifier != nil {
|
||||
for _, domInfo := range domainsToNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(domInfo.domain), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": domInfo.domain,
|
||||
}).WithError(err).Warn("failed to notify tunnel connection status")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": domInfo.domain,
|
||||
}).Info("notified management about tunnel connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePeer unregisters a domain from an account. The client is only stopped
|
||||
@@ -410,12 +365,6 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
transport := entry.transport
|
||||
n.clientsMux.RUnlock()
|
||||
|
||||
release, ok := entry.acquireInflight(req.URL.Host)
|
||||
defer release()
|
||||
if !ok {
|
||||
return nil, ErrTooManyInflight
|
||||
}
|
||||
|
||||
// Attempt to start the client, if the client is already running then
|
||||
// it will return an error that we ignore, if this hits a timeout then
|
||||
// this request is unprocessable.
|
||||
@@ -552,7 +501,6 @@ func NewNetBird(mgmtAddr, proxyID, proxyAddr string, wgPort int, logger *log.Log
|
||||
clients: make(map[types.AccountID]*clientEntry),
|
||||
statusNotifier: notifier,
|
||||
mgmtClient: mgmtClient,
|
||||
transportCfg: loadTransportConfig(logger),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package roundtrip
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -21,31 +20,6 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
type mockStatusNotifier struct {
|
||||
mu sync.Mutex
|
||||
statuses []statusCall
|
||||
}
|
||||
|
||||
type statusCall struct {
|
||||
accountID string
|
||||
serviceID string
|
||||
domain string
|
||||
connected bool
|
||||
}
|
||||
|
||||
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStatusNotifier) calls() []statusCall {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return append([]statusCall{}, m.statuses...)
|
||||
}
|
||||
|
||||
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||
// It uses an invalid management URL to prevent real connections.
|
||||
func mockNetBird() *NetBird {
|
||||
@@ -279,50 +253,3 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer connection found for account")
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
notifier := &mockStatusNotifier{}
|
||||
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first domain — creates a new client entry.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually mark client as started to simulate background startup completing.
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID].started = true
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
// Add second domain — should notify immediately since client is already started.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
|
||||
require.NoError(t, err)
|
||||
|
||||
calls := notifier.calls()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, string(accountID), calls[0].accountID)
|
||||
assert.Equal(t, "svc-2", calls[0].serviceID)
|
||||
assert.Equal(t, "domain2.test", calls[0].domain)
|
||||
assert.True(t, calls[0].connected)
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
notifier := &mockStatusNotifier{}
|
||||
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove one domain — client stays, but disconnection notification fires.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
|
||||
calls := notifier.calls()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "domain1.test", calls[0].domain)
|
||||
assert.False(t, calls[0].connected)
|
||||
}
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Environment variable names for tuning the backend HTTP transport.
|
||||
const (
|
||||
EnvMaxIdleConns = "NB_PROXY_MAX_IDLE_CONNS"
|
||||
EnvMaxIdleConnsPerHost = "NB_PROXY_MAX_IDLE_CONNS_PER_HOST"
|
||||
EnvMaxConnsPerHost = "NB_PROXY_MAX_CONNS_PER_HOST"
|
||||
EnvIdleConnTimeout = "NB_PROXY_IDLE_CONN_TIMEOUT"
|
||||
EnvTLSHandshakeTimeout = "NB_PROXY_TLS_HANDSHAKE_TIMEOUT"
|
||||
EnvExpectContinueTimeout = "NB_PROXY_EXPECT_CONTINUE_TIMEOUT"
|
||||
EnvResponseHeaderTimeout = "NB_PROXY_RESPONSE_HEADER_TIMEOUT"
|
||||
EnvWriteBufferSize = "NB_PROXY_WRITE_BUFFER_SIZE"
|
||||
EnvReadBufferSize = "NB_PROXY_READ_BUFFER_SIZE"
|
||||
EnvDisableCompression = "NB_PROXY_DISABLE_COMPRESSION"
|
||||
EnvMaxInflight = "NB_PROXY_MAX_INFLIGHT"
|
||||
)
|
||||
|
||||
// transportConfig holds tunable parameters for the per-account HTTP transport.
|
||||
type transportConfig struct {
|
||||
maxIdleConns int
|
||||
maxIdleConnsPerHost int
|
||||
maxConnsPerHost int
|
||||
idleConnTimeout time.Duration
|
||||
tlsHandshakeTimeout time.Duration
|
||||
expectContinueTimeout time.Duration
|
||||
responseHeaderTimeout time.Duration
|
||||
writeBufferSize int
|
||||
readBufferSize int
|
||||
disableCompression bool
|
||||
// maxInflight limits per-backend concurrent requests. 0 means unlimited.
|
||||
maxInflight int
|
||||
}
|
||||
|
||||
func defaultTransportConfig() transportConfig {
|
||||
return transportConfig{
|
||||
maxIdleConns: 100,
|
||||
maxIdleConnsPerHost: 100,
|
||||
maxConnsPerHost: 0, // unlimited
|
||||
idleConnTimeout: 90 * time.Second,
|
||||
tlsHandshakeTimeout: 10 * time.Second,
|
||||
expectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func loadTransportConfig(logger *log.Logger) transportConfig {
|
||||
cfg := defaultTransportConfig()
|
||||
|
||||
if v, ok := envInt(EnvMaxIdleConns, logger); ok {
|
||||
cfg.maxIdleConns = v
|
||||
}
|
||||
if v, ok := envInt(EnvMaxIdleConnsPerHost, logger); ok {
|
||||
cfg.maxIdleConnsPerHost = v
|
||||
}
|
||||
if v, ok := envInt(EnvMaxConnsPerHost, logger); ok {
|
||||
cfg.maxConnsPerHost = v
|
||||
}
|
||||
if v, ok := envDuration(EnvIdleConnTimeout, logger); ok {
|
||||
cfg.idleConnTimeout = v
|
||||
}
|
||||
if v, ok := envDuration(EnvTLSHandshakeTimeout, logger); ok {
|
||||
cfg.tlsHandshakeTimeout = v
|
||||
}
|
||||
if v, ok := envDuration(EnvExpectContinueTimeout, logger); ok {
|
||||
cfg.expectContinueTimeout = v
|
||||
}
|
||||
if v, ok := envDuration(EnvResponseHeaderTimeout, logger); ok {
|
||||
cfg.responseHeaderTimeout = v
|
||||
}
|
||||
if v, ok := envInt(EnvWriteBufferSize, logger); ok {
|
||||
cfg.writeBufferSize = v
|
||||
}
|
||||
if v, ok := envInt(EnvReadBufferSize, logger); ok {
|
||||
cfg.readBufferSize = v
|
||||
}
|
||||
if v, ok := envBool(EnvDisableCompression, logger); ok {
|
||||
cfg.disableCompression = v
|
||||
}
|
||||
if v, ok := envInt(EnvMaxInflight, logger); ok {
|
||||
cfg.maxInflight = v
|
||||
}
|
||||
|
||||
logger.WithFields(log.Fields{
|
||||
"max_idle_conns": cfg.maxIdleConns,
|
||||
"max_idle_conns_per_host": cfg.maxIdleConnsPerHost,
|
||||
"max_conns_per_host": cfg.maxConnsPerHost,
|
||||
"idle_conn_timeout": cfg.idleConnTimeout,
|
||||
"tls_handshake_timeout": cfg.tlsHandshakeTimeout,
|
||||
"expect_continue_timeout": cfg.expectContinueTimeout,
|
||||
"response_header_timeout": cfg.responseHeaderTimeout,
|
||||
"write_buffer_size": cfg.writeBufferSize,
|
||||
"read_buffer_size": cfg.readBufferSize,
|
||||
"disable_compression": cfg.disableCompression,
|
||||
"max_inflight": cfg.maxInflight,
|
||||
}).Debug("backend transport configuration")
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func envInt(key string, logger *log.Logger) (int, bool) {
|
||||
s := os.Getenv(key)
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
v, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
logger.Warnf("failed to parse %s=%q as int: %v", key, s, err)
|
||||
return 0, false
|
||||
}
|
||||
if v < 0 {
|
||||
logger.Warnf("ignoring negative value for %s=%d", key, v)
|
||||
return 0, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
func envDuration(key string, logger *log.Logger) (time.Duration, bool) {
|
||||
s := os.Getenv(key)
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
v, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
logger.Warnf("failed to parse %s=%q as duration: %v", key, s, err)
|
||||
return 0, false
|
||||
}
|
||||
if v < 0 {
|
||||
logger.Warnf("ignoring negative value for %s=%s", key, v)
|
||||
return 0, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
func envBool(key string, logger *log.Logger) (bool, bool) {
|
||||
s := os.Getenv(key)
|
||||
if s == "" {
|
||||
return false, false
|
||||
}
|
||||
v, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
logger.Warnf("failed to parse %s=%q as bool: %v", key, s, err)
|
||||
return false, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
21
proxy/log.go
21
proxy/log.go
@@ -1,21 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
stdlog "log"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// HTTP server type identifiers for logging
|
||||
logtagFieldHTTPServer = "http-server"
|
||||
logtagValueHTTPS = "https"
|
||||
logtagValueACME = "acme"
|
||||
logtagValueDebug = "debug"
|
||||
)
|
||||
|
||||
// newHTTPServerLogger creates a standard library logger that writes to logrus
|
||||
// with the specified server type field.
|
||||
func newHTTPServerLogger(logger *log.Logger, serverType string) *stdlog.Logger {
|
||||
return stdlog.New(logger.WithField(logtagFieldHTTPServer, serverType).WriterLevel(log.WarnLevel), "", 0)
|
||||
}
|
||||
@@ -3,9 +3,9 @@ package proxy
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"errors"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -251,7 +251,7 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "test-proxy-1",
|
||||
Version: "test-v1",
|
||||
Address: "test.proxy.io",
|
||||
Address: "https://test.proxy.io",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -293,7 +293,7 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
clusterAddress := "test.proxy.io"
|
||||
clusterAddress := "https://test.proxy.io"
|
||||
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "test-proxy-cluster",
|
||||
@@ -328,7 +328,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
clusterAddress := "test.proxy.io"
|
||||
clusterAddress := "https://test.proxy.io"
|
||||
proxyID := "test-proxy-reconnect"
|
||||
|
||||
// Helper to receive all mappings from a stream
|
||||
@@ -401,7 +401,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
|
||||
authMw := auth.NewMiddleware(logger, nil)
|
||||
proxyHandler := proxy.NewReverseProxy(nil, "auto", nil, logger)
|
||||
|
||||
clusterAddress := "test.proxy.io"
|
||||
clusterAddress := "https://test.proxy.io"
|
||||
proxyID := "test-proxy-idempotent"
|
||||
|
||||
var addMappingCalls atomic.Int32
|
||||
@@ -497,7 +497,7 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
|
||||
setup := setupIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
clusterAddress := "test.proxy.io"
|
||||
clusterAddress := "https://test.proxy.io"
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
proxyproto "github.com/pires/go-proxyproto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWrapProxyProtocol_OverridesRemoteAddr(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")},
|
||||
ProxyProtocol: true,
|
||||
}
|
||||
|
||||
raw, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer raw.Close()
|
||||
|
||||
ln := srv.wrapProxyProtocol(raw)
|
||||
|
||||
realClientIP := "203.0.113.50"
|
||||
realClientPort := uint16(54321)
|
||||
|
||||
accepted := make(chan net.Conn, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
accepted <- conn
|
||||
}()
|
||||
|
||||
// Connect and send a PROXY v2 header.
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
header := &proxyproto.Header{
|
||||
Version: 2,
|
||||
Command: proxyproto.PROXY,
|
||||
TransportProtocol: proxyproto.TCPv4,
|
||||
SourceAddr: &net.TCPAddr{IP: net.ParseIP(realClientIP), Port: int(realClientPort)},
|
||||
DestinationAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443},
|
||||
}
|
||||
_, err = header.WriteTo(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case accepted := <-accepted:
|
||||
defer accepted.Close()
|
||||
host, _, err := net.SplitHostPort(accepted.RemoteAddr().String())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, realClientIP, host, "RemoteAddr should reflect the PROXY header source IP")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_TrustedRequires(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 1234},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.REQUIRE, policy, "trusted source should require PROXY header")
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_UntrustedIgnores(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.TCPAddr{IP: net.ParseIP("203.0.113.50"), Port: 1234},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.IGNORE, policy, "untrusted source should have PROXY header ignored")
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_InvalidIPRejects(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.REJECT, policy, "unparsable address should be rejected")
|
||||
}
|
||||
473
proxy/server.go
473
proxy/server.go
@@ -23,7 +23,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/pires/go-proxyproto"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -37,7 +36,6 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
||||
"github.com/netbirdio/netbird/proxy/internal/conntrack"
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
@@ -65,11 +63,6 @@ type Server struct {
|
||||
healthChecker *health.Checker
|
||||
meter *metrics.Metrics
|
||||
|
||||
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
|
||||
// so they can be closed during graceful shutdown, since http.Server.Shutdown
|
||||
// does not handle them.
|
||||
hijackTracker conntrack.HijackTracker
|
||||
|
||||
// Mostly used for debugging on management.
|
||||
startTime time.Time
|
||||
|
||||
@@ -89,13 +82,17 @@ type Server struct {
|
||||
ACMEChallengeType string
|
||||
// CertLockMethod controls how ACME certificate locks are coordinated
|
||||
// across replicas. Default: CertLockAuto (detect environment).
|
||||
CertLockMethod acme.CertLockMethod
|
||||
CertLockMethod acme.CertLockMethod
|
||||
OIDCClientId string
|
||||
OIDCClientSecret string
|
||||
OIDCEndpoint string
|
||||
OIDCScopes []string
|
||||
|
||||
// DebugEndpointEnabled enables the debug HTTP endpoint.
|
||||
DebugEndpointEnabled bool
|
||||
// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
|
||||
DebugEndpointAddress string
|
||||
// HealthAddress is the address for the health probe endpoint.
|
||||
// HealthAddress is the address for the health probe endpoint (default: "localhost:8080").
|
||||
HealthAddress string
|
||||
// ProxyToken is the access token for authenticating with the management server.
|
||||
ProxyToken string
|
||||
@@ -110,10 +107,6 @@ type Server struct {
|
||||
// random OS-assigned port. A fixed port only works with single-account
|
||||
// deployments; multiple accounts will fail to bind the same port.
|
||||
WireguardPort int
|
||||
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
|
||||
// When enabled, the real client IP is extracted from the PROXY header
|
||||
// sent by upstream L4 proxies that support PROXY protocol.
|
||||
ProxyProtocol bool
|
||||
}
|
||||
|
||||
// NotifyStatus sends a status update to management about tunnel connectivity
|
||||
@@ -144,95 +137,6 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service
|
||||
}
|
||||
|
||||
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.initDefaults()
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
s.meter = metrics.New(reg)
|
||||
|
||||
mgmtConn, err := s.dialManagement()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := mgmtConn.Close(); err != nil {
|
||||
s.Logger.Debugf("management connection close: %v", err)
|
||||
}
|
||||
}()
|
||||
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
|
||||
go s.newManagementMappingWorker(ctx, s.mgmtClient)
|
||||
|
||||
// Initialize the netbird client, this is required to build peer connections
|
||||
// to proxy over.
|
||||
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient)
|
||||
|
||||
tlsConfig, err := s.configureTLS(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
|
||||
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
|
||||
|
||||
// Configure the authentication middleware with session validator for OIDC group checks.
|
||||
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient)
|
||||
|
||||
// Configure Access logs to management server.
|
||||
accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
|
||||
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
s.startDebugEndpoint()
|
||||
|
||||
if err := s.startHealthServer(reg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build the handler chain from inside out.
|
||||
handler := http.Handler(s.proxy)
|
||||
handler = s.auth.Protect(handler)
|
||||
handler = web.AssetHandler(handler)
|
||||
handler = accessLog.Middleware(handler)
|
||||
handler = s.meter.Middleware(handler)
|
||||
handler = s.hijackTracker.Middleware(handler)
|
||||
|
||||
// Start the reverse proxy HTTPS server.
|
||||
s.https = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
TLSConfig: tlsConfig,
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
|
||||
}
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", addr, err)
|
||||
}
|
||||
if s.ProxyProtocol {
|
||||
ln = s.wrapProxyProtocol(ln)
|
||||
}
|
||||
|
||||
httpsErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debugf("starting reverse proxy server on %s", addr)
|
||||
httpsErr <- s.https.ServeTLS(ln, "", "")
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-httpsErr:
|
||||
s.shutdownServices()
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("https server: %w", err)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
s.gracefulShutdown()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// initDefaults sets fallback values for optional Server fields.
|
||||
func (s *Server) initDefaults() {
|
||||
s.startTime = time.Now()
|
||||
|
||||
// If no ID is set then one can be generated.
|
||||
@@ -248,36 +152,141 @@ func (s *Server) initDefaults() {
|
||||
if s.Logger == nil {
|
||||
s.Logger = log.StandardLogger()
|
||||
}
|
||||
}
|
||||
|
||||
// startDebugEndpoint launches the debug HTTP server if enabled.
|
||||
func (s *Server) startDebugEndpoint() {
|
||||
if !s.DebugEndpointEnabled {
|
||||
return
|
||||
// Start up metrics gathering
|
||||
reg := prometheus.NewRegistry()
|
||||
s.meter = metrics.New(reg)
|
||||
|
||||
// The very first thing to do should be to connect to the Management server.
|
||||
// Without this connection, the Proxy cannot do anything.
|
||||
mgmtURL, err := url.Parse(s.ManagementAddress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse management address: %w", err)
|
||||
}
|
||||
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
|
||||
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
|
||||
if s.acme != nil {
|
||||
debugHandler.SetCertStatus(s.acme)
|
||||
creds := insecure.NewCredentials()
|
||||
// Simple TLS check using management URL.
|
||||
// Assume management TLS is enabled for gRPC as well if using HTTPS for the API.
|
||||
if mgmtURL.Scheme == "https" {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
// Fall back to embedded CAs if no OS-provided ones are available.
|
||||
certPool = embeddedroots.Get()
|
||||
}
|
||||
|
||||
creds = credentials.NewTLS(&tls.Config{
|
||||
RootCAs: certPool,
|
||||
})
|
||||
}
|
||||
s.debug = &http.Server{
|
||||
Addr: debugAddr,
|
||||
Handler: debugHandler,
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug),
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"gRPC_address": mgmtURL.Host,
|
||||
"TLS_enabled": mgmtURL.Scheme == "https",
|
||||
}).Debug("starting management gRPC client")
|
||||
mgmtConn, err := grpc.NewClient(mgmtURL.Host,
|
||||
grpc.WithTransportCredentials(creds),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 20 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}),
|
||||
proxygrpc.WithProxyToken(s.ProxyToken),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create management connection: %w", err)
|
||||
}
|
||||
go func() {
|
||||
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
|
||||
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Errorf("debug endpoint error: %v", err)
|
||||
defer func() {
|
||||
if err := mgmtConn.Close(); err != nil {
|
||||
s.Logger.Debugf("management connection close: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
|
||||
go s.newManagementMappingWorker(ctx, s.mgmtClient)
|
||||
|
||||
// startHealthServer launches the health probe and metrics server.
|
||||
func (s *Server) startHealthServer(reg *prometheus.Registry) error {
|
||||
// Initialize the netbird client, this is required to build peer connections
|
||||
// to proxy over.
|
||||
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient)
|
||||
|
||||
// When generating ACME certificates, start a challenge server.
|
||||
tlsConfig := &tls.Config{}
|
||||
if s.GenerateACMECertificates {
|
||||
// Default to TLS-ALPN-01 challenge if not specified
|
||||
if s.ACMEChallengeType == "" {
|
||||
s.ACMEChallengeType = "tls-alpn-01"
|
||||
}
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"acme_server": s.ACMEDirectory,
|
||||
"challenge_type": s.ACMEChallengeType,
|
||||
}).Debug("ACME certificates enabled, configuring certificate manager")
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)
|
||||
|
||||
// Only start HTTP server for HTTP-01 challenge type
|
||||
if s.ACMEChallengeType == "http-01" {
|
||||
s.http = &http.Server{
|
||||
Addr: s.ACMEChallengeAddress,
|
||||
Handler: s.acme.HTTPHandler(nil),
|
||||
}
|
||||
go func() {
|
||||
if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
|
||||
}
|
||||
}()
|
||||
}
|
||||
tlsConfig = s.acme.TLSConfig()
|
||||
|
||||
// ServerName needs to be set to allow for ACME to work correctly
|
||||
// when using CNAME URLs to access the proxy.
|
||||
tlsConfig.ServerName = s.ProxyURL
|
||||
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"ServerName": s.ProxyURL,
|
||||
"challenge_type": s.ACMEChallengeType,
|
||||
}).Debug("ACME certificate manager configured")
|
||||
} else {
|
||||
s.Logger.Debug("ACME certificates disabled, using static certificates with file watching")
|
||||
certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile)
|
||||
keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile)
|
||||
|
||||
certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize certificate watcher: %w", err)
|
||||
}
|
||||
go certWatcher.Watch(ctx)
|
||||
|
||||
tlsConfig.GetCertificate = certWatcher.GetCertificate
|
||||
}
|
||||
|
||||
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
|
||||
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
|
||||
|
||||
// Configure the authentication middleware with session validator for OIDC group checks.
|
||||
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient)
|
||||
|
||||
// Configure Access logs to management server.
|
||||
accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
|
||||
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
if s.DebugEndpointEnabled {
|
||||
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
|
||||
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
|
||||
if s.acme != nil {
|
||||
debugHandler.SetCertStatus(s.acme)
|
||||
}
|
||||
s.debug = &http.Server{
|
||||
Addr: debugAddr,
|
||||
Handler: debugHandler,
|
||||
}
|
||||
go func() {
|
||||
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
|
||||
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Errorf("debug endpoint error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Start health probe server.
|
||||
healthAddr := s.HealthAddress
|
||||
if healthAddr == "" {
|
||||
healthAddr = defaultHealthAddr
|
||||
healthAddr = "localhost:8080"
|
||||
}
|
||||
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
||||
healthListener, err := net.Listen("tcp", healthAddr)
|
||||
@@ -289,57 +298,34 @@ func (s *Server) startHealthServer(reg *prometheus.Registry) error {
|
||||
s.Logger.Errorf("health probe server: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// wrapProxyProtocol wraps a listener with PROXY protocol support.
|
||||
// When TrustedProxies is configured, only those sources may send PROXY headers;
|
||||
// connections from untrusted sources have any PROXY header ignored.
|
||||
func (s *Server) wrapProxyProtocol(ln net.Listener) net.Listener {
|
||||
ppListener := &proxyproto.Listener{
|
||||
Listener: ln,
|
||||
ReadHeaderTimeout: proxyProtoHeaderTimeout,
|
||||
// Start the reverse proxy HTTPS server.
|
||||
s.https = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: s.meter.Middleware(accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy)))),
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
if len(s.TrustedProxies) > 0 {
|
||||
ppListener.ConnPolicy = s.proxyProtocolPolicy
|
||||
} else {
|
||||
s.Logger.Warn("PROXY protocol enabled without trusted proxies; any source may send PROXY headers")
|
||||
}
|
||||
s.Logger.Info("PROXY protocol enabled on listener")
|
||||
return ppListener
|
||||
}
|
||||
|
||||
// proxyProtocolPolicy returns whether to require, skip, or reject the PROXY
|
||||
// header based on whether the connection source is in TrustedProxies.
|
||||
func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) {
|
||||
// No logging on reject to prevent abuse
|
||||
tcpAddr, ok := opts.Upstream.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return proxyproto.REJECT, nil
|
||||
}
|
||||
addr, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||
if !ok {
|
||||
return proxyproto.REJECT, nil
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
httpsErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debugf("starting reverse proxy server on %s", addr)
|
||||
httpsErr <- s.https.ListenAndServeTLS("", "")
|
||||
}()
|
||||
|
||||
// called per accept
|
||||
for _, prefix := range s.TrustedProxies {
|
||||
if prefix.Contains(addr) {
|
||||
return proxyproto.REQUIRE, nil
|
||||
select {
|
||||
case err := <-httpsErr:
|
||||
s.shutdownServices()
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("https server: %w", err)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
s.gracefulShutdown()
|
||||
return nil
|
||||
}
|
||||
return proxyproto.IGNORE, nil
|
||||
}
|
||||
|
||||
const (
|
||||
defaultHealthAddr = "localhost:8080"
|
||||
defaultDebugAddr = "localhost:8444"
|
||||
|
||||
// proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol
|
||||
// header after accepting a connection.
|
||||
proxyProtoHeaderTimeout = 5 * time.Second
|
||||
|
||||
// shutdownPreStopDelay is the time to wait after receiving a shutdown signal
|
||||
// before draining connections. This allows the load balancer to propagate
|
||||
// the endpoint removal.
|
||||
@@ -354,92 +340,6 @@ const (
|
||||
shutdownServiceTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
func (s *Server) dialManagement() (*grpc.ClientConn, error) {
|
||||
mgmtURL, err := url.Parse(s.ManagementAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse management address: %w", err)
|
||||
}
|
||||
creds := insecure.NewCredentials()
|
||||
// Assume management TLS is enabled for gRPC as well if using HTTPS for the API.
|
||||
if mgmtURL.Scheme == "https" {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
// Fall back to embedded CAs if no OS-provided ones are available.
|
||||
certPool = embeddedroots.Get()
|
||||
}
|
||||
creds = credentials.NewTLS(&tls.Config{
|
||||
RootCAs: certPool,
|
||||
})
|
||||
}
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"gRPC_address": mgmtURL.Host,
|
||||
"TLS_enabled": mgmtURL.Scheme == "https",
|
||||
}).Debug("starting management gRPC client")
|
||||
conn, err := grpc.NewClient(mgmtURL.Host,
|
||||
grpc.WithTransportCredentials(creds),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 20 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}),
|
||||
proxygrpc.WithProxyToken(s.ProxyToken),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create management connection: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
|
||||
tlsConfig := &tls.Config{}
|
||||
if !s.GenerateACMECertificates {
|
||||
s.Logger.Debug("ACME certificates disabled, using static certificates with file watching")
|
||||
certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile)
|
||||
keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile)
|
||||
|
||||
certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initialize certificate watcher: %w", err)
|
||||
}
|
||||
go certWatcher.Watch(ctx)
|
||||
tlsConfig.GetCertificate = certWatcher.GetCertificate
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
if s.ACMEChallengeType == "" {
|
||||
s.ACMEChallengeType = "tls-alpn-01"
|
||||
}
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"acme_server": s.ACMEDirectory,
|
||||
"challenge_type": s.ACMEChallengeType,
|
||||
}).Debug("ACME certificates enabled, configuring certificate manager")
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)
|
||||
|
||||
if s.ACMEChallengeType == "http-01" {
|
||||
s.http = &http.Server{
|
||||
Addr: s.ACMEChallengeAddress,
|
||||
Handler: s.acme.HTTPHandler(nil),
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueACME),
|
||||
}
|
||||
go func() {
|
||||
if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
|
||||
}
|
||||
}()
|
||||
}
|
||||
tlsConfig = s.acme.TLSConfig()
|
||||
|
||||
// ServerName needs to be set to allow for ACME to work correctly
|
||||
// when using CNAME URLs to access the proxy.
|
||||
tlsConfig.ServerName = s.ProxyURL
|
||||
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"ServerName": s.ProxyURL,
|
||||
"challenge_type": s.ACMEChallengeType,
|
||||
}).Debug("ACME certificate manager configured")
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// gracefulShutdown performs a zero-downtime shutdown sequence. It marks the
|
||||
// readiness probe as failing, waits for load balancer propagation, drains
|
||||
// in-flight connections, and then stops all background services.
|
||||
@@ -467,12 +367,7 @@ func (s *Server) gracefulShutdown() {
|
||||
s.Logger.Warnf("https server drain: %v", err)
|
||||
}
|
||||
|
||||
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
|
||||
if n := s.hijackTracker.CloseAll(); n > 0 {
|
||||
s.Logger.Infof("closed %d hijacked connection(s)", n)
|
||||
}
|
||||
|
||||
// Step 5: Stop all remaining background services.
|
||||
// Step 4: Stop all remaining background services.
|
||||
s.shutdownServices()
|
||||
s.Logger.Info("graceful shutdown complete")
|
||||
}
|
||||
@@ -597,7 +492,36 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
return fmt.Errorf("receive msg: %w", err)
|
||||
}
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
// Process msg updates sequentially to avoid conflict, so block
|
||||
// additional receiving until this processing is completed.
|
||||
for _, mapping := range msg.GetMapping() {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"type": mapping.GetType(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"path": mapping.GetPath(),
|
||||
"id": mapping.GetId(),
|
||||
}).Debug("Processing mapping update")
|
||||
switch mapping.GetType() {
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
|
||||
if err := s.addMapping(ctx, mapping); err != nil {
|
||||
// TODO: Retry this? Or maybe notify the management server that this mapping has failed?
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"error": err,
|
||||
}).Error("Error adding new mapping, ignoring this mapping and continuing processing")
|
||||
}
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
|
||||
if err := s.updateMapping(ctx, mapping); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
}).Errorf("failed to update mapping: %v", err)
|
||||
}
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED:
|
||||
s.removeMapping(ctx, mapping)
|
||||
}
|
||||
}
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
|
||||
if !*initialSyncDone && msg.GetInitialSyncComplete() {
|
||||
@@ -611,37 +535,6 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
for _, mapping := range mappings {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"type": mapping.GetType(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"path": mapping.GetPath(),
|
||||
"id": mapping.GetId(),
|
||||
}).Debug("Processing mapping update")
|
||||
switch mapping.GetType() {
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
|
||||
if err := s.addMapping(ctx, mapping); err != nil {
|
||||
// TODO: Retry this? Or maybe notify the management server that this mapping has failed?
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"error": err,
|
||||
}).Error("Error adding new mapping, ignoring this mapping and continuing processing")
|
||||
}
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
|
||||
if err := s.updateMapping(ctx, mapping); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
}).Errorf("failed to update mapping: %v", err)
|
||||
}
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED:
|
||||
s.removeMapping(ctx, mapping)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
||||
d := domain.Domain(mapping.GetDomain())
|
||||
accountID := types.AccountID(mapping.GetAccountId())
|
||||
@@ -740,7 +633,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
||||
// If addr is empty, it defaults to localhost:8444 for security.
|
||||
func debugEndpointAddr(addr string) string {
|
||||
if addr == "" {
|
||||
return defaultDebugAddr
|
||||
return "localhost:8444"
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
6
proxy/web/dist/assets/index.js
vendored
6
proxy/web/dist/assets/index.js
vendored
File diff suppressed because one or more lines are too long
2
proxy/web/dist/assets/style.css
vendored
2
proxy/web/dist/assets/style.css
vendored
File diff suppressed because one or more lines are too long
@@ -59,7 +59,7 @@ function App() {
|
||||
formData.append(methods.pin!, value);
|
||||
}
|
||||
|
||||
fetch(globalThis.location.href, {
|
||||
fetch(window.location.href, {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
redirect: "manual",
|
||||
@@ -67,7 +67,7 @@ function App() {
|
||||
.then((res) => {
|
||||
if (res.type === "opaqueredirect" || res.status === 0) {
|
||||
setSubmitting("redirect");
|
||||
globalThis.location.reload();
|
||||
window.location.reload();
|
||||
} else {
|
||||
handleAuthError(method, "Authentication failed. Please try again.");
|
||||
}
|
||||
@@ -92,7 +92,6 @@ function App() {
|
||||
|
||||
const hasCredentialAuth = methods.password || methods.pin;
|
||||
const hasBothCredentials = methods.password && methods.pin;
|
||||
const buttonLabel = activeTab === "password" ? "Sign in" : "Submit";
|
||||
|
||||
if (submitting === "redirect") {
|
||||
return (
|
||||
@@ -125,7 +124,7 @@ function App() {
|
||||
<Button
|
||||
variant="primary"
|
||||
className="w-full"
|
||||
onClick={() => { globalThis.location.href = methods.oidc!; }}
|
||||
onClick={() => (window.location.href = methods.oidc!)}
|
||||
>
|
||||
<LogIn size={16} />
|
||||
Sign in with SSO
|
||||
@@ -171,7 +170,7 @@ function App() {
|
||||
<div className="mb-4">
|
||||
{methods.password && (activeTab === "password" || !methods.pin) && (
|
||||
<>
|
||||
{!hasBothCredentials && <Label htmlFor="password">Password</Label>}
|
||||
{!hasBothCredentials && <Label>Password</Label>}
|
||||
<Input
|
||||
ref={passwordRef}
|
||||
type="password"
|
||||
@@ -187,7 +186,7 @@ function App() {
|
||||
)}
|
||||
{methods.pin && (activeTab === "pin" || !methods.password) && (
|
||||
<>
|
||||
{!hasBothCredentials && <Label htmlFor="pin-0">Enter PIN Code</Label>}
|
||||
{!hasBothCredentials && <Label>Enter PIN Code</Label>}
|
||||
<PinCodeInput
|
||||
ref={pinRef}
|
||||
value={pin}
|
||||
@@ -205,13 +204,13 @@ function App() {
|
||||
variant="secondary"
|
||||
className="w-full"
|
||||
>
|
||||
{submitting === null ? (
|
||||
buttonLabel
|
||||
) : (
|
||||
{submitting !== null ? (
|
||||
<>
|
||||
<Loader2 className="animate-spin" size={16} />
|
||||
Verifying...
|
||||
</>
|
||||
) : (
|
||||
activeTab === "password" ? "Sign in" : "Submit"
|
||||
)}
|
||||
</Button>
|
||||
</form>
|
||||
|
||||
@@ -7,7 +7,7 @@ import { PoweredByNetBird } from "@/components/PoweredByNetBird";
|
||||
import { StatusCard } from "@/components/StatusCard";
|
||||
import type { ErrorData } from "@/data";
|
||||
|
||||
export function ErrorPage({ code, title, message, proxy = true, destination = true, requestId, simple = false, retryUrl }: Readonly<ErrorData>) {
|
||||
export function ErrorPage({ code, title, message, proxy = true, destination = true, requestId, simple = false, retryUrl }: ErrorData) {
|
||||
useEffect(() => {
|
||||
document.title = `${title} - NetBird Service`;
|
||||
}, [title]);
|
||||
@@ -38,19 +38,13 @@ export function ErrorPage({ code, title, message, proxy = true, destination = tr
|
||||
|
||||
{/* Buttons */}
|
||||
<div className="flex gap-3 justify-center items-center mb-6 z-10 relative">
|
||||
<Button variant="primary" onClick={() => {
|
||||
if (retryUrl) {
|
||||
globalThis.location.href = retryUrl;
|
||||
} else {
|
||||
globalThis.location.reload();
|
||||
}
|
||||
}}>
|
||||
<Button variant="primary" onClick={() => retryUrl ? window.location.href = retryUrl : window.location.reload()}>
|
||||
<RotateCw size={16} />
|
||||
Refresh Page
|
||||
</Button>
|
||||
<Button
|
||||
variant="secondary"
|
||||
onClick={() => globalThis.open("https://docs.netbird.io", "_blank", "noopener,noreferrer")}
|
||||
onClick={() => window.open("https://docs.netbird.io", "_blank", "noopener,noreferrer")}
|
||||
>
|
||||
<BookText size={16} />
|
||||
Documentation
|
||||
|
||||
@@ -4,7 +4,7 @@ interface ConnectionLineProps {
|
||||
success?: boolean;
|
||||
}
|
||||
|
||||
export function ConnectionLine({ success = true }: Readonly<ConnectionLineProps>) {
|
||||
export function ConnectionLine({ success = true }: ConnectionLineProps) {
|
||||
if (success) {
|
||||
return (
|
||||
<div className="flex-1 flex items-center justify-center h-12 w-full px-5">
|
||||
|
||||
@@ -5,7 +5,7 @@ type Props = {
|
||||
className?: string;
|
||||
};
|
||||
|
||||
export function Description({ children, className }: Readonly<Props>) {
|
||||
export function Description({ children, className }: Props) {
|
||||
return (
|
||||
<div className={cn("text-sm text-nb-gray-300 font-light mt-2 block text-center z-10 relative", className)}>
|
||||
{children}
|
||||
|
||||
@@ -5,7 +5,7 @@ interface HelpTextProps {
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function HelpText({ children, className }: Readonly<HelpTextProps>) {
|
||||
export default function HelpText({ children, className }: HelpTextProps) {
|
||||
return (
|
||||
<span
|
||||
className={cn(
|
||||
|
||||
@@ -2,10 +2,9 @@ import { cn } from "@/utils/helpers";
|
||||
|
||||
type LabelProps = React.LabelHTMLAttributes<HTMLLabelElement>;
|
||||
|
||||
export function Label({ className, htmlFor, ...props }: Readonly<LabelProps>) {
|
||||
export function Label({ className, ...props }: LabelProps) {
|
||||
return (
|
||||
<label
|
||||
htmlFor={htmlFor}
|
||||
className={cn(
|
||||
"text-sm font-medium tracking-wider leading-none",
|
||||
"peer-disabled:cursor-not-allowed peer-disabled:opacity-70",
|
||||
|
||||
@@ -20,7 +20,7 @@ interface Props {
|
||||
autoFocus?: boolean;
|
||||
}
|
||||
|
||||
const PinCodeInput = forwardRef<PinCodeInputRef, Readonly<Props>>(function PinCodeInput(
|
||||
const PinCodeInput = forwardRef<PinCodeInputRef, Props>(function PinCodeInput(
|
||||
{ value, onChange, length = 6, disabled = false, className, autoFocus = false },
|
||||
ref,
|
||||
) {
|
||||
@@ -32,15 +32,14 @@ const PinCodeInput = forwardRef<PinCodeInputRef, Readonly<Props>>(function PinCo
|
||||
},
|
||||
}));
|
||||
|
||||
const digits = value.split("").concat(new Array(length).fill("")).slice(0, length);
|
||||
const slotIds = Array.from({ length }, (_, i) => `pin-${i}`);
|
||||
const digits = value.split("").concat(Array(length).fill("")).slice(0, length);
|
||||
|
||||
const handleChange = (index: number, digit: string) => {
|
||||
if (!/^\d*$/.test(digit)) return;
|
||||
|
||||
const newDigits = [...digits];
|
||||
newDigits[index] = digit.slice(-1);
|
||||
const newValue = newDigits.join("").replaceAll(/\s/g, "");
|
||||
const newValue = newDigits.join("").replace(/\s/g, "");
|
||||
onChange(newValue);
|
||||
|
||||
if (digit && index < length - 1) {
|
||||
@@ -62,7 +61,7 @@ const PinCodeInput = forwardRef<PinCodeInputRef, Readonly<Props>>(function PinCo
|
||||
|
||||
const handlePaste = (e: ClipboardEvent<HTMLInputElement>) => {
|
||||
e.preventDefault();
|
||||
const pastedData = e.clipboardData.getData("text").replaceAll(/\D/g, "").slice(0, length);
|
||||
const pastedData = e.clipboardData.getData("text").replace(/\D/g, "").slice(0, length);
|
||||
onChange(pastedData);
|
||||
|
||||
const nextIndex = Math.min(pastedData.length, length - 1);
|
||||
@@ -77,8 +76,7 @@ const PinCodeInput = forwardRef<PinCodeInputRef, Readonly<Props>>(function PinCo
|
||||
<div className={cn("flex gap-2 w-full min-w-0", className)}>
|
||||
{digits.map((digit, index) => (
|
||||
<input
|
||||
key={slotIds[index]}
|
||||
id={slotIds[index]}
|
||||
key={index}
|
||||
ref={(el) => {
|
||||
inputRefs.current[index] = el;
|
||||
}}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { cn } from "@/utils/helpers";
|
||||
import { useState, useMemo, useCallback } from "react";
|
||||
import { useState } from "react";
|
||||
import { TabContext, useTabContext } from "./TabContext";
|
||||
|
||||
type TabsProps = {
|
||||
@@ -11,24 +11,19 @@ type TabsProps = {
|
||||
| ((context: { value: string; onChange: (value: string) => void }) => React.ReactNode);
|
||||
};
|
||||
|
||||
function SegmentedTabs({ value, defaultValue, onChange, children }: Readonly<TabsProps>) {
|
||||
const [internalValue, setInternalValue] = useState(defaultValue ?? "");
|
||||
const currentValue = value ?? internalValue;
|
||||
function SegmentedTabs({ value, defaultValue, onChange, children }: TabsProps) {
|
||||
const [internalValue, setInternalValue] = useState(defaultValue || "");
|
||||
const currentValue = value !== undefined ? value : internalValue;
|
||||
|
||||
const handleChange = useCallback((newValue: string) => {
|
||||
const handleChange = (newValue: string) => {
|
||||
if (value === undefined) {
|
||||
setInternalValue(newValue);
|
||||
}
|
||||
onChange?.(newValue);
|
||||
}, [value, onChange]);
|
||||
|
||||
const contextValue = useMemo(
|
||||
() => ({ value: currentValue, onChange: handleChange }),
|
||||
[currentValue, handleChange],
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<TabContext.Provider value={contextValue}>
|
||||
<TabContext.Provider value={{ value: currentValue, onChange: handleChange }}>
|
||||
<div>
|
||||
{typeof children === "function"
|
||||
? children({ value: currentValue, onChange: handleChange })
|
||||
@@ -41,10 +36,10 @@ function SegmentedTabs({ value, defaultValue, onChange, children }: Readonly<Tab
|
||||
function List({
|
||||
children,
|
||||
className,
|
||||
}: Readonly<{
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
className?: string;
|
||||
}>) {
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
role="tablist"
|
||||
@@ -65,23 +60,16 @@ function Trigger({
|
||||
className,
|
||||
selected,
|
||||
onClick,
|
||||
}: Readonly<{
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
value: string;
|
||||
disabled?: boolean;
|
||||
className?: string;
|
||||
selected?: boolean;
|
||||
onClick?: () => void;
|
||||
}>) {
|
||||
}) {
|
||||
const context = useTabContext();
|
||||
const isSelected = selected ?? value === context.value;
|
||||
|
||||
let stateClassName = "";
|
||||
if (isSelected) {
|
||||
stateClassName = "bg-nb-gray-900 text-white";
|
||||
} else if (!disabled) {
|
||||
stateClassName = "text-nb-gray-400 hover:bg-nb-gray-900/50";
|
||||
}
|
||||
const isSelected = selected !== undefined ? selected : value === context.value;
|
||||
|
||||
const handleClick = () => {
|
||||
context.onChange(value);
|
||||
@@ -98,7 +86,11 @@ function Trigger({
|
||||
className={cn(
|
||||
"px-4 py-2 text-sm rounded-md w-full transition-all cursor-pointer",
|
||||
disabled && "opacity-30 cursor-not-allowed",
|
||||
stateClassName,
|
||||
isSelected
|
||||
? "bg-nb-gray-900 text-white"
|
||||
: disabled
|
||||
? ""
|
||||
: "text-nb-gray-400 hover:bg-nb-gray-900/50",
|
||||
className
|
||||
)}
|
||||
>
|
||||
@@ -114,14 +106,14 @@ function Content({
|
||||
value,
|
||||
className,
|
||||
visible,
|
||||
}: Readonly<{
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
value: string;
|
||||
className?: string;
|
||||
visible?: boolean;
|
||||
}>) {
|
||||
}) {
|
||||
const context = useTabContext();
|
||||
const isVisible = visible ?? value === context.value;
|
||||
const isVisible = visible !== undefined ? visible : value === context.value;
|
||||
|
||||
if (!isVisible) return null;
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user