mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-28 18:59:57 +00:00
Compare commits
38 Commits
client_lif
...
v0.74.0-rc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
db0f5fa4f5 | ||
|
|
7e4e26a83e | ||
|
|
d8980ee2f8 | ||
|
|
aeda36f713 | ||
|
|
a2275c24d1 | ||
|
|
0744faf938 | ||
|
|
b416063bcc | ||
|
|
615631567a | ||
|
|
f4daf59bcd | ||
|
|
ff2787e184 | ||
|
|
e20b62ad65 | ||
|
|
18b38943aa | ||
|
|
a400828b89 | ||
|
|
e2bb328a34 | ||
|
|
221b9c012c | ||
|
|
17b2044596 | ||
|
|
07101c59ac | ||
|
|
51b6f6291b | ||
|
|
2ebf26006a | ||
|
|
211a26019a | ||
|
|
6c26178ad5 | ||
|
|
af3b7e4497 | ||
|
|
e84f6527f7 | ||
|
|
ac9529ea8c | ||
|
|
f736ef9647 | ||
|
|
cf58bf1ba9 | ||
|
|
522b8ed969 | ||
|
|
c9e99659ea | ||
|
|
58c79f5878 | ||
|
|
15a0504fb1 | ||
|
|
883a1a8961 | ||
|
|
54192a94b7 | ||
|
|
8511687270 | ||
|
|
35b465fa4a | ||
|
|
fb87f751a5 | ||
|
|
679c7182a4 | ||
|
|
8c031ea6f0 | ||
|
|
60a9544656 |
@@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -59,12 +59,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: true
|
||||
|
||||
2
.github/workflows/git-town.yml
vendored
2
.github/workflows/git-town.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
|
||||
|
||||
6
.github/workflows/golang-test-darwin.yml
vendored
6
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,12 +16,12 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
|
||||
4
.github/workflows/golang-test-freebsd.yml
vendored
4
.github/workflows/golang-test-freebsd.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
id: test
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
|
||||
58
.github/workflows/golang-test-linux.yml
vendored
58
.github/workflows/golang-test-linux.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -119,12 +119,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -162,7 +162,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -175,12 +175,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -246,12 +246,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -290,7 +290,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -306,12 +306,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -347,7 +347,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -363,12 +363,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -407,7 +407,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -424,12 +424,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -484,7 +484,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -529,12 +529,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -579,10 +579,11 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags devcert -run=^$ -bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
|
||||
env:
|
||||
GIT_BRANCH: ${{ github.ref_name }}
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
@@ -623,12 +624,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -673,12 +674,13 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags=benchmark \
|
||||
-run=^$ \
|
||||
-bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/server/http/...
|
||||
env:
|
||||
GIT_BRANCH: ${{ github.ref_name }}
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
@@ -692,12 +694,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -734,7 +736,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
|
||||
4
.github/workflows/golang-test-windows.yml
vendored
4
.github/workflows/golang-test-windows.yml
vendored
@@ -18,12 +18,12 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
id: go
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
8
.github/workflows/golangci-lint.yml
vendored
8
.github/workflows/golangci-lint.yml
vendored
@@ -15,13 +15,13 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals,flate,recordin,unparseable
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Check for duplicate constants
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
2
.github/workflows/install-script-test.yml
vendored
2
.github/workflows/install-script-test.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
10
.github/workflows/mobile-build-validation.yml
vendored
10
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,11 +16,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
||||
uses: actions/setup-java@ad2b38190b15e4d6bdf0c97fb4fca8412226d287
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
@@ -54,11 +54,11 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
|
||||
32
.github/workflows/release.yml
vendored
32
.github/workflows/release.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
@@ -166,7 +166,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -186,9 +186,9 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
||||
uses: docker/setup-qemu-action@06116385d9baf250c9f4dcb4858b16962ea869c3 #v4.1.0
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
||||
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 #v4.1.0
|
||||
- name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
@@ -221,7 +221,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
@@ -347,7 +347,7 @@ jobs:
|
||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
@@ -374,7 +374,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -420,7 +420,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
@@ -464,12 +464,12 @@ jobs:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -488,7 +488,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
@@ -522,7 +522,7 @@ jobs:
|
||||
downloadPath: '${{ github.workspace }}\temp'
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -534,13 +534,13 @@ jobs:
|
||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
- name: Download release artifacts
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
with:
|
||||
name: release
|
||||
path: release
|
||||
|
||||
- name: Download UI release artifacts
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: release-ui
|
||||
|
||||
12
.github/workflows/test-infrastructure-files.yml
vendored
12
.github/workflows/test-infrastructure-files.yml
vendored
@@ -68,12 +68,12 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
@@ -207,7 +207,7 @@ jobs:
|
||||
- name: Build management docker image
|
||||
working-directory: management
|
||||
run: |
|
||||
docker build -t netbirdio/management:latest .
|
||||
docker build -t netbirdio/management:latest --build-arg TARGETPLATFORM=. .
|
||||
|
||||
- name: Build signal binary
|
||||
working-directory: signal
|
||||
@@ -216,7 +216,7 @@ jobs:
|
||||
- name: Build signal docker image
|
||||
working-directory: signal
|
||||
run: |
|
||||
docker build -t netbirdio/signal:latest .
|
||||
docker build -t netbirdio/signal:latest --build-arg TARGETPLATFORM=. .
|
||||
|
||||
- name: Build relay binary
|
||||
working-directory: relay
|
||||
@@ -225,7 +225,7 @@ jobs:
|
||||
- name: Build relay docker image
|
||||
working-directory: relay
|
||||
run: |
|
||||
docker build -t netbirdio/relay:latest .
|
||||
docker build -t netbirdio/relay:latest --build-arg TARGETPLATFORM=. .
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files/artifacts
|
||||
@@ -256,7 +256,7 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
8
.github/workflows/wasm-build-validation.yml
vendored
8
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,11 +19,11 @@ jobs:
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
@@ -44,11 +44,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
|
||||
@@ -247,7 +247,7 @@ dockers_v2:
|
||||
- netbirdio/netbird
|
||||
- ghcr.io/netbirdio/netbird
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
@@ -295,7 +295,7 @@ dockers_v2:
|
||||
- netbirdio/relay
|
||||
- ghcr.io/netbirdio/relay
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: relay/Dockerfile
|
||||
platforms:
|
||||
@@ -317,7 +317,7 @@ dockers_v2:
|
||||
- netbirdio/signal
|
||||
- ghcr.io/netbirdio/signal
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: signal/Dockerfile
|
||||
platforms:
|
||||
@@ -339,7 +339,7 @@ dockers_v2:
|
||||
- netbirdio/management
|
||||
- ghcr.io/netbirdio/management
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: management/Dockerfile
|
||||
platforms:
|
||||
@@ -361,7 +361,7 @@ dockers_v2:
|
||||
- netbirdio/upload
|
||||
- ghcr.io/netbirdio/upload
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: upload-server/Dockerfile
|
||||
platforms:
|
||||
@@ -383,7 +383,7 @@ dockers_v2:
|
||||
- netbirdio/netbird-server
|
||||
- ghcr.io/netbirdio/netbird-server
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: combined/Dockerfile
|
||||
platforms:
|
||||
@@ -405,7 +405,7 @@ dockers_v2:
|
||||
- netbirdio/reverse-proxy
|
||||
- ghcr.io/netbirdio/reverse-proxy
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: proxy/Dockerfile
|
||||
platforms:
|
||||
@@ -462,9 +462,13 @@ checksum:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||
|
||||
release:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||
|
||||
@@ -151,9 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
connectClient := internal.NewConnectClient(ctx, c.recorder)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.setState(cfg, cacheDir, connectClient)
|
||||
return connectClient.RunOnAndroid(cfg, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
@@ -186,9 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
connectClient := internal.NewConnectClient(ctx, c.recorder)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.setState(cfg, cacheDir, connectClient)
|
||||
return connectClient.RunOnAndroid(cfg, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
|
||||
@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
ProfileName: string(activeProf.ID),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -227,7 +227,7 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("switch profile failed: %v", err)
|
||||
return "", fmt.Errorf("switch profile failed: %w", err)
|
||||
}
|
||||
|
||||
return profilemanager.ID(resp.Id), nil
|
||||
|
||||
@@ -138,26 +138,23 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to service CLI interface: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
profileName := args[0]
|
||||
|
||||
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
id, err := addProfileOnDaemon(cmd.Context(), daemonClient, profileName, currUser.Username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add profile request: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
|
||||
@@ -166,7 +163,6 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
|
||||
}
|
||||
|
||||
id := profilemanager.ID(resp.Id)
|
||||
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
|
||||
return nil
|
||||
|
||||
@@ -330,3 +326,19 @@ func wrapAmbiguityError(err error, handle string) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// addProfileOnDaemon issues the AddProfile RPC on an existing daemon client
|
||||
// and returns the new profile's ID. It is the single entry point for profile
|
||||
// creation, shared by `netbird profile add` and the `netbird up --profile
|
||||
// <name>` auto-create path.
|
||||
func addProfileOnDaemon(ctx context.Context, client proto.DaemonServiceClient, profileName, username string) (profilemanager.ID, error) {
|
||||
resp, err := client.AddProfile(ctx, &proto.AddProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: username,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("add profile failed: %w", err)
|
||||
}
|
||||
|
||||
return profilemanager.ID(resp.Id), nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||
@@ -262,46 +261,17 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
|
||||
return prefix + upper
|
||||
}
|
||||
|
||||
// DialClientGRPCServer returns client connection to the daemon server. It waits
|
||||
// (up to the timeout) for the daemon to become reachable so an `up` issued right
|
||||
// after `service start` tolerates the startup race. Instead of grpc's blocking
|
||||
// dial — whose raw "transport failed" retry warnings are silenced by the logger
|
||||
// config — we drive the wait ourselves and emit one clean line per failed attempt.
|
||||
// DialClientGRPCServer returns client connection to the daemon server.
|
||||
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
return grpc.DialContext(
|
||||
ctx,
|
||||
strings.TrimPrefix(addr, "tcp://"),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn.Connect()
|
||||
for {
|
||||
state := conn.GetState()
|
||||
if state == connectivity.Ready {
|
||||
return conn, nil
|
||||
}
|
||||
// Log only once the connection has actually failed — not during the
|
||||
// brief Idle/Connecting phase on a healthy daemon (avoids a spurious
|
||||
// line + wait when the daemon is already up).
|
||||
if state == connectivity.TransientFailure {
|
||||
log.Infof("waiting for the netbird daemon to become available at %s...", addr)
|
||||
}
|
||||
// Wake on the next state change, but at least every second so a stuck
|
||||
// TransientFailure re-logs at a steady cadence until the timeout.
|
||||
waitCtx, waitCancel := context.WithTimeout(ctx, time.Second)
|
||||
conn.WaitForStateChange(waitCtx, state)
|
||||
waitCancel()
|
||||
if ctx.Err() != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("daemon not reachable at %s: %w", addr, ctx.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithBackOff execute function in backoff cycle.
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@@ -111,11 +110,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
// Resolve the active profile's display name via the daemon, which runs
|
||||
// as root and can read the per-user profile files. The local profile
|
||||
// manager only knows the active profile ID, not its display name.
|
||||
profName := getActiveProfileName(ctx)
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
||||
Anonymize: anonymizeFlag,
|
||||
@@ -167,6 +165,25 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// getActiveProfileName asks the daemon for the active profile's display
|
||||
// name. The daemon runs as root and can read the per-user profile files to
|
||||
// resolve the ID to its human-readable name. Returns an empty string on any
|
||||
// error so status output degrades gracefully.
|
||||
func getActiveProfileName(ctx context.Context) string {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).GetActiveProfile(ctx, &proto.GetActiveProfileRequest{})
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return resp.GetProfileName()
|
||||
}
|
||||
|
||||
func parseFilters() error {
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "idle", "connecting", "connected":
|
||||
|
||||
@@ -128,15 +128,9 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
var profileSwitched bool
|
||||
// switch profile if provided
|
||||
if profileName != "" {
|
||||
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
|
||||
if err != nil {
|
||||
if err := switchOrCreateProfile(cmd.Context(), pm, profileName, username.Username); err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
profileSwitched = true
|
||||
}
|
||||
|
||||
@@ -151,6 +145,52 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
|
||||
}
|
||||
|
||||
// switchOrCreateProfile switches the active profile to the one identified by
|
||||
// handle, creating it first when it does not exist yet. This restores the
|
||||
// pre-0.73 behaviour where `netbird up --profile <name>` auto-creates a
|
||||
// missing profile instead of failing.
|
||||
func switchOrCreateProfile(ctx context.Context, pm *profilemanager.ProfileManager, handle, username string) error {
|
||||
resolvedID, err := switchProfile(ctx, handle, username)
|
||||
if err != nil {
|
||||
st, ok := gstatus.FromError(err)
|
||||
if !ok || st.Code() != codes.NotFound {
|
||||
return err
|
||||
}
|
||||
// Don't fail immediately on a create error: a concurrent run may
|
||||
// have created the profile between the NotFound above and this
|
||||
// call, in which case the retried switch still succeeds. Only
|
||||
// surface the create error if the switch also fails.
|
||||
_, createErr := createProfile(ctx, handle, username)
|
||||
if resolvedID, err = switchProfile(ctx, handle, username); err != nil {
|
||||
if createErr != nil {
|
||||
return fmt.Errorf("create profile: %w", createErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createProfile dials the daemon and creates a new profile with the given
|
||||
// display name, returning its generated ID. Use addProfileOnDaemon directly
|
||||
// when a daemon client is already available to reuse the connection.
|
||||
func createProfile(ctx context.Context, profileName, username string) (profilemanager.ID, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return addProfileOnDaemon(ctx, proto.NewDaemonServiceClient(conn), profileName, username)
|
||||
}
|
||||
|
||||
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
|
||||
// override the default profile filepath if provided
|
||||
if configPath != "" {
|
||||
@@ -201,10 +241,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, r)
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(config, nil, util.FindFirstLogPath(logFiles))
|
||||
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
|
||||
|
||||
@@ -264,24 +264,34 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
client := internal.NewConnectClient(ctx, c.recorder)
|
||||
client := internal.NewConnectClient(ctx, c.config, c.recorder)
|
||||
client.SetSyncResponsePersistence(true)
|
||||
|
||||
// The supervisor owns the run; we wait until it is established, ends with a
|
||||
// startup error (permanent backoff err), or startCtx expires.
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
client.RunAsync(c.config, nil)
|
||||
run := make(chan struct{})
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run, ""); err != nil {
|
||||
clientErr <- err
|
||||
}
|
||||
}()
|
||||
|
||||
if err := client.WaitEstablishedOrDone(startCtx); err != nil {
|
||||
// Either startCtx expired while connecting, or the run ended before it
|
||||
// established. Cancel the client context before stopping: Engine.Start
|
||||
// blocks on the signal stream while holding the engine mutex and only
|
||||
// unblocks on cancellation. Stopping first would deadlock on that mutex.
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
// ConnectClient.Stop now cancels its own run context and waits for the
|
||||
// run loop to tear the engine down, so this cancel() is no longer
|
||||
// required to break the deadlock and could be removed. It is kept as a
|
||||
// defensive belt-and-suspenders: cancelling the parent context first
|
||||
// guarantees the run loop is unblocked even if Stop's contract regresses.
|
||||
cancel()
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after startup failure. Stop error: %w. Startup: %w", stopErr, err)
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
}
|
||||
return startCtx.Err()
|
||||
case err := <-clientErr:
|
||||
return fmt.Errorf("startup: %w", err)
|
||||
case <-run:
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
@@ -18,7 +19,6 @@ import (
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
@@ -49,23 +49,17 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// androidMobileDep is set on Android to inject the MobileDependency for runs
|
||||
// started through the generic entry points (Run/RunAsync, e.g. embed.Client).
|
||||
// nil on other platforms, where the dependency is empty.
|
||||
var androidMobileDep func(config *profilemanager.Config) MobileDependency
|
||||
|
||||
// mobileDependency returns the MobileDependency for a run started via the
|
||||
// generic entry points. On Android the androidMobileDep provider supplies
|
||||
// platform stubs (or real implementations); elsewhere it is empty.
|
||||
func (c *ConnectClient) mobileDependency(config *profilemanager.Config) MobileDependency {
|
||||
if androidMobileDep != nil {
|
||||
return androidMobileDep(config)
|
||||
}
|
||||
return MobileDependency{}
|
||||
}
|
||||
// androidRunOverride is set on Android to inject mobile dependencies
|
||||
// when using embed.Client (which calls Run() with empty MobileDependency).
|
||||
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
|
||||
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
runCancel context.CancelFunc
|
||||
runExited chan struct{}
|
||||
runOnce sync.Once
|
||||
runStarted atomic.Bool
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
|
||||
engine *Engine
|
||||
@@ -74,62 +68,41 @@ type ConnectClient struct {
|
||||
updateManager *updater.Manager
|
||||
|
||||
persistSyncResponse bool
|
||||
|
||||
// sup serializes all start/stop requests so two lifecycle operations can
|
||||
// never overlap. See connect_lifecycle.go.
|
||||
sup *supervisor
|
||||
}
|
||||
|
||||
func NewConnectClient(
|
||||
ctx context.Context,
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
) *ConnectClient {
|
||||
c := &ConnectClient{
|
||||
ctx: ctx,
|
||||
// Derive the run context here so Stop owns the cancel that unblocks the run
|
||||
// loop. runCancel is set once at construction, so Stop can call it without
|
||||
// racing the run loop's startup. Callers therefore need not cancel before Stop.
|
||||
runCtx, runCancel := context.WithCancel(ctx)
|
||||
return &ConnectClient{
|
||||
ctx: runCtx,
|
||||
runCancel: runCancel,
|
||||
runExited: make(chan struct{}),
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
engineMutex: sync.Mutex{},
|
||||
}
|
||||
c.sup = newSupervisor(ctx, c.run)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
|
||||
c.updateManager = um
|
||||
}
|
||||
|
||||
// Run with main logic. md carries optional gRPC metadata (e.g. the UI
|
||||
// user-agent) to forward to the management/signal services; nil when none.
|
||||
func (c *ConnectClient) Run(config *profilemanager.Config, md metadata.MD, logPath string) error {
|
||||
return c.sup.start(config, md, c.mobileDependency(config), logPath)
|
||||
}
|
||||
|
||||
// RunAsync starts a client run without blocking. Used by the daemon and embed,
|
||||
// which drive the lifecycle through the supervisor rather than blocking on Run;
|
||||
// they then wait for the outcome via WaitEstablishedOrDone. The run's lifecycle
|
||||
// channels are created and owned by the supervisor — callers never hold them.
|
||||
func (c *ConnectClient) RunAsync(config *profilemanager.Config, md metadata.MD) {
|
||||
c.sup.startAsync(config, md, c.mobileDependency(config), "", nil)
|
||||
}
|
||||
|
||||
// Restart atomically stops any in-flight run and starts a fresh one with the
|
||||
// given config. The stop+start happens as a single supervisor operation, so no
|
||||
// other lifecycle request can interleave between them — used for explicit
|
||||
// restarts (e.g. an MDM policy change) that must not expose a "stopped" window.
|
||||
func (c *ConnectClient) Restart(config *profilemanager.Config, md metadata.MD) {
|
||||
c.sup.restartAsync(config, md, c.mobileDependency(config), "")
|
||||
}
|
||||
|
||||
// WaitEstablishedOrDone blocks until the in-flight run becomes established (nil),
|
||||
// ends before that (the run error, or a sentinel on a clean stop), or ctx is
|
||||
// cancelled. Returns errNoRunInFlight if no run is in flight. Wraps the wait on
|
||||
// the supervisor-owned channels so callers never touch them directly.
|
||||
func (c *ConnectClient) WaitEstablishedOrDone(ctx context.Context) error {
|
||||
return c.sup.waitEstablishedOrDone(ctx)
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||
if androidRunOverride != nil {
|
||||
return androidRunOverride(c, runningChan, logPath)
|
||||
}
|
||||
return c.run(MobileDependency{}, runningChan, logPath)
|
||||
}
|
||||
|
||||
// RunOnAndroid with main logic on mobile system
|
||||
func (c *ConnectClient) RunOnAndroid(
|
||||
config *profilemanager.Config,
|
||||
tunAdapter device.TunAdapter,
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
@@ -148,11 +121,10 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
StateFilePath: stateFilePath,
|
||||
TempDir: cacheDir,
|
||||
}
|
||||
return c.sup.start(config, nil, mobileDependency, "")
|
||||
return c.run(mobileDependency, nil, "")
|
||||
}
|
||||
|
||||
func (c *ConnectClient) RunOniOS(
|
||||
config *profilemanager.Config,
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
@@ -170,12 +142,15 @@ func (c *ConnectClient) RunOniOS(
|
||||
StateFilePath: stateFilePath,
|
||||
TempDir: cacheDir,
|
||||
}
|
||||
return c.sup.start(config, nil, mobileDependency, logFilePath)
|
||||
return c.run(mobileDependency, nil, logFilePath)
|
||||
}
|
||||
|
||||
// run executes a single client run. runCtx is owned by the supervisor: cancelling
|
||||
// it tears the run down (it is the parent of the per-attempt engine context).
|
||||
func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Config, mobileDependency MobileDependency, connEstablishedChan chan struct{}, logPath string) error {
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||
// Mark the loop as started and signal exit on return so Stop can wait for
|
||||
// the loop to finish (and skip the wait if the loop never ran).
|
||||
c.runStarted.Store(true)
|
||||
defer c.runOnce.Do(func() { close(c.runExited) })
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rec := c.statusRecorder
|
||||
@@ -239,18 +214,18 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
|
||||
}()
|
||||
|
||||
wrapErr := state.Wrap
|
||||
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
myPrivateKey, err := wgtypes.ParseKey(c.config.PrivateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", c.config.PrivateKey, err.Error())
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
if c.config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
publicSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
publicSSHKey, err := ssh.GeneratePublicKey([]byte(c.config.SSHKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -284,13 +259,13 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
|
||||
defer c.statusRecorder.ClientStop()
|
||||
operation := func() error {
|
||||
// if context cancelled we not start new backoff cycle
|
||||
if runCtx.Err() != nil {
|
||||
if c.ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
state.Set(StatusConnecting)
|
||||
|
||||
engineCtx, cancel := context.WithCancel(runCtx)
|
||||
engineCtx, cancel := context.WithCancel(c.ctx)
|
||||
defer func() {
|
||||
_, err := state.Status()
|
||||
c.statusRecorder.MarkManagementDisconnected(err)
|
||||
@@ -298,8 +273,8 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
|
||||
cancel()
|
||||
}()
|
||||
|
||||
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||
}
|
||||
@@ -316,7 +291,7 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
|
||||
}
|
||||
c.clientMetrics.UpdateAgentInfo(agentInfo, myPrivateKey.PublicKey().String())
|
||||
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
|
||||
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
|
||||
defer func() {
|
||||
if err = mgmClient.Close(); err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
@@ -325,14 +300,13 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
|
||||
|
||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
|
||||
loginStarted := time.Now()
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, config)
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
||||
if err != nil {
|
||||
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), false)
|
||||
log.Debug(err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
state.Set(StatusNeedsLogin)
|
||||
// No teardown needed: login fails before the engine is started
|
||||
// (engine.Start is below), so there is nothing running to stop.
|
||||
c.runCancel()
|
||||
return backoff.Permanent(wrapErr(err)) // unrecoverable error
|
||||
}
|
||||
return wrapErr(err)
|
||||
@@ -386,7 +360,7 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
|
||||
}
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, logPath)
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@@ -430,7 +404,7 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
|
||||
c.engine = engine
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if err := engine.Start(loginResp.GetNetbirdConfig(), config.ManagementURL); err != nil {
|
||||
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
@@ -438,13 +412,12 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
// The supervisor owns connEstablishedChan and it is always present. Guard
|
||||
// against a double close: operation re-runs on ErrResetConnection retries
|
||||
// within the same run, and the channel is closed only on the first connect.
|
||||
select {
|
||||
case <-connEstablishedChan:
|
||||
default:
|
||||
close(connEstablishedChan)
|
||||
if runningChan != nil {
|
||||
select {
|
||||
case <-runningChan:
|
||||
default:
|
||||
close(runningChan)
|
||||
}
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
@@ -453,10 +426,8 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
|
||||
c.engine = nil
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
// Always tear the engine down once its context is cancelled. engine.Stop
|
||||
// is nil-guarded per component, so calling it unconditionally is safe and
|
||||
// avoids both the data race on engine.wgInterface and skipping teardown
|
||||
// when the interface was never brought up (e.g. a mid-start failure).
|
||||
log.Infof("ensuring wg interface is removed, Netbird engine context cancelled")
|
||||
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
@@ -474,13 +445,12 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
|
||||
}
|
||||
|
||||
c.statusRecorder.ClientStart()
|
||||
err = backoff.Retry(operation, backOff)
|
||||
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
|
||||
if err != nil {
|
||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// Login failed permanently: the engine was never started, so there
|
||||
// is nothing to tear down — just record that a login is needed.
|
||||
state.Set(StatusNeedsLogin)
|
||||
c.runCancel()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -501,22 +471,6 @@ func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
|
||||
return relayCfg.GetUrls(), token
|
||||
}
|
||||
|
||||
// ConnectionRunning reports whether a connection run is currently in flight
|
||||
// (connecting, connected, or reconnecting). Answered by the supervisor via a
|
||||
// serialized query, so it settles behind an in-flight stop. Distinct from
|
||||
// ServiceRunning, which reports whether the service itself is alive.
|
||||
func (c *ConnectClient) ConnectionRunning() bool {
|
||||
return c.sup.isRunning()
|
||||
}
|
||||
|
||||
// ServiceRunning reports whether the client's lifecycle supervisor is alive and
|
||||
// able to accept start/stop commands — i.e. its context has not been cancelled
|
||||
// (the daemon is not shutting down). Independent of whether a connection run is
|
||||
// up (that is ConnectionRunning).
|
||||
func (c *ConnectClient) ServiceRunning() bool {
|
||||
return c.sup.ctx.Err() == nil
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Engine() *Engine {
|
||||
if c == nil {
|
||||
return nil
|
||||
@@ -573,10 +527,12 @@ func (c *ConnectClient) Status() StatusType {
|
||||
return status
|
||||
}
|
||||
|
||||
// Stop serializes a stop request through the lifecycle supervisor and blocks
|
||||
// until the in-flight run is fully torn down.
|
||||
func (c *ConnectClient) Stop() error {
|
||||
return c.sup.stop()
|
||||
c.runCancel()
|
||||
if c.runStarted.Load() {
|
||||
<-c.runExited
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetSyncResponsePersistence enables or disables sync response persistence.
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
@@ -60,17 +59,19 @@ var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
|
||||
var _ dns.ReadyListener = noopDnsReadyListener{}
|
||||
|
||||
func init() {
|
||||
// Wire up the default MobileDependency provider so embed.Client.Start() works
|
||||
// on Android with netstack mode. Provides complete no-op stubs for all mobile
|
||||
// Wire up the default override so embed.Client.Start() works on Android
|
||||
// with netstack mode. Provides complete no-op stubs for all mobile
|
||||
// dependencies so the engine's existing Android code paths work unchanged.
|
||||
// Applications that need P2P ICE or real DNS should replace this by setting
|
||||
// androidMobileDep before calling Start().
|
||||
androidMobileDep = func(config *profilemanager.Config) MobileDependency {
|
||||
return mobileDependencyForEmbed(
|
||||
// Applications that need P2P ICE or real DNS should replace this by
|
||||
// setting androidRunOverride before calling Start().
|
||||
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
|
||||
return c.runOnAndroidEmbed(
|
||||
noopIFaceDiscover{},
|
||||
noopNetworkChangeListener{},
|
||||
[]netip.AddrPort{},
|
||||
noopDnsReadyListener{},
|
||||
runningChan,
|
||||
logPath,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,18 +10,23 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// mobileDependencyForEmbed builds the MobileDependency used by embed.Client on
|
||||
// Android so the engine's existing Android code paths work unchanged.
|
||||
func mobileDependencyForEmbed(
|
||||
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
|
||||
// so embed.Client.Start() can detect when the engine is ready.
|
||||
// It provides complete MobileDependency so the engine's existing
|
||||
// Android code paths work unchanged.
|
||||
func (c *ConnectClient) runOnAndroidEmbed(
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
) MobileDependency {
|
||||
return MobileDependency{
|
||||
runningChan chan struct{},
|
||||
logPath string,
|
||||
) error {
|
||||
mobileDependency := MobileDependency{
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return c.run(mobileDependency, runningChan, logPath)
|
||||
}
|
||||
|
||||
@@ -1,362 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
// errAlreadyRunning is returned when a start is requested while a run is already
|
||||
// in flight.
|
||||
var errAlreadyRunning = errors.New("client is already running")
|
||||
|
||||
// errNoRunInFlight is returned by waitEstablishedOrDone when no run is active.
|
||||
var errNoRunInFlight = errors.New("no connection run in flight")
|
||||
|
||||
// errStoppedBeforeEstablished is returned when a run ended (cleanly) before the
|
||||
// connection was established.
|
||||
var errStoppedBeforeEstablished = errors.New("run stopped before the connection was established")
|
||||
|
||||
// lifecycleOp is a serialized lifecycle operation processed by the supervisor.
|
||||
type lifecycleOp int
|
||||
|
||||
const (
|
||||
opStart lifecycleOp = iota
|
||||
opStop
|
||||
opRestart
|
||||
opStatus
|
||||
opWaitEstablished
|
||||
)
|
||||
|
||||
// lifecycleCmd is a single lifecycle request handed to the supervisor goroutine.
|
||||
// They all flow through the same cmdCh so they are strictly ordered (FIFO) with
|
||||
// respect to each other.
|
||||
type lifecycleCmd struct {
|
||||
op lifecycleOp
|
||||
config *profilemanager.Config
|
||||
md metadata.MD
|
||||
mobileDep MobileDependency
|
||||
logPath string
|
||||
|
||||
// done is the caller's notification channel (nil for fire-and-forget). Its
|
||||
// meaning depends on op:
|
||||
// - opStart: receives the run's end result when the run terminates, or
|
||||
// errAlreadyRunning immediately if a run is already in flight.
|
||||
// - opStop: receives nil once the in-flight run has fully unwound.
|
||||
// - opWaitEstablished: receives the wait outcome (see waitEstablishedOrDone).
|
||||
done chan error
|
||||
|
||||
reply chan bool // opStatus only: receives whether a run is in flight
|
||||
waitCtx context.Context // opWaitEstablished only: the waiter's cancellation context
|
||||
}
|
||||
|
||||
// runState holds the lifecycle channels of a single in-flight run, owned by the
|
||||
// loop goroutine. It never escapes the supervisor as an API; the only readers
|
||||
// are the per-wait goroutines the loop spawns for opWaitEstablished.
|
||||
//
|
||||
// connEstablishedChan is closed by the run once the connection is established.
|
||||
// The supervisor creates and owns it — callers no longer supply it; they observe
|
||||
// it through waitEstablishedOrDone. ended is closed (broadcast) when the run
|
||||
// terminates, so any number of waiters can observe it; err is the run's end
|
||||
// result, valid only after ended is closed.
|
||||
type runState struct {
|
||||
connEstablishedChan chan struct{} // closed by the run on established
|
||||
ended chan struct{} // closed by finishRun when the run terminates
|
||||
err error // run end result, valid after ended is closed
|
||||
}
|
||||
|
||||
// runEndResult is sent by the run goroutine to the supervisor when a run ends,
|
||||
// whether on its own (error / external context cancellation) or because of a Stop.
|
||||
type runEndResult struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// runFunc executes a single client run bound to the supervisor-owned context,
|
||||
// with the config supplied by the start request.
|
||||
type runFunc func(ctx context.Context, config *profilemanager.Config, mobileDep MobileDependency, connEstablishedChan chan struct{}, logPath string) error
|
||||
|
||||
// supervisor serializes start/stop of a single client run. Every request goes
|
||||
// through cmdCh and is handled one at a time by the loop goroutine, so two
|
||||
// lifecycle operations can never overlap and their order is preserved (FIFO).
|
||||
// The loop goroutine is the sole owner of curStart/runCancel, so that state
|
||||
// needs no locking. The loop exits when the parent context is cancelled.
|
||||
type supervisor struct {
|
||||
ctx context.Context
|
||||
run runFunc
|
||||
cmdCh chan lifecycleCmd
|
||||
runEnded chan runEndResult
|
||||
|
||||
// owned exclusively by the loop goroutine. curStart is the in-flight start
|
||||
// command (nil = idle); its done channel is notified when the run ends.
|
||||
// curRun holds that run's lifecycle channels; runCancel cancels it.
|
||||
curStart *lifecycleCmd
|
||||
curRun *runState
|
||||
runCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newSupervisor(ctx context.Context, run runFunc) *supervisor {
|
||||
s := &supervisor{
|
||||
ctx: ctx,
|
||||
run: run,
|
||||
cmdCh: make(chan lifecycleCmd, 16),
|
||||
runEnded: make(chan runEndResult, 1),
|
||||
}
|
||||
go s.loop()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *supervisor) loop() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
s.shutdown()
|
||||
return
|
||||
case cmd := <-s.cmdCh:
|
||||
switch cmd.op {
|
||||
case opStart:
|
||||
s.handleStart(cmd)
|
||||
case opStop:
|
||||
s.handleStop(cmd)
|
||||
case opRestart:
|
||||
s.handleRestart(cmd)
|
||||
case opStatus:
|
||||
cmd.reply <- (s.isRunningInternal())
|
||||
case opWaitEstablished:
|
||||
s.handleWaitEstablished(cmd)
|
||||
}
|
||||
case res := <-s.runEnded:
|
||||
// Run ended on its own, without an explicit Stop.
|
||||
s.finishRun(res.err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *supervisor) handleStart(cmd lifecycleCmd) {
|
||||
if s.isRunningInternal() {
|
||||
notify(cmd.done, errAlreadyRunning)
|
||||
return
|
||||
}
|
||||
|
||||
runCtx, cancel := context.WithCancel(s.ctx)
|
||||
if cmd.md != nil {
|
||||
// Carry caller-supplied gRPC metadata (e.g. UI user-agent) into the run
|
||||
// context so the engine's management/signal calls forward it. The cancel
|
||||
// still drives runCtx (metadata wrapping preserves cancellation).
|
||||
runCtx = metadata.NewOutgoingContext(runCtx, cmd.md)
|
||||
}
|
||||
s.runCancel = cancel
|
||||
s.curStart = &cmd
|
||||
s.curRun = &runState{connEstablishedChan: make(chan struct{}), ended: make(chan struct{})}
|
||||
|
||||
go func(ctx context.Context, cfg *profilemanager.Config, m MobileDependency, established chan struct{}, lp string) {
|
||||
err := s.run(ctx, cfg, m, established, lp)
|
||||
s.runEnded <- runEndResult{err: err}
|
||||
}(runCtx, cmd.config, cmd.mobileDep, s.curRun.connEstablishedChan, cmd.logPath)
|
||||
}
|
||||
|
||||
func (s *supervisor) handleStop(cmd lifecycleCmd) {
|
||||
if !s.isRunningInternal() {
|
||||
notify(cmd.done, nil)
|
||||
return
|
||||
}
|
||||
s.stopCurrentRun()
|
||||
notify(cmd.done, nil)
|
||||
}
|
||||
|
||||
// handleRestart tears down any in-flight run and starts a fresh one in a single
|
||||
// loop turn. No other command can interleave between the stop and the start
|
||||
// (the loop is single-threaded), so the swap is atomic without relying on any
|
||||
// daemon-side lock — that is what an explicit restart (e.g. MDM config change)
|
||||
// needs to avoid a window where the client is observably stopped.
|
||||
func (s *supervisor) handleRestart(cmd lifecycleCmd) {
|
||||
if s.isRunningInternal() {
|
||||
s.stopCurrentRun()
|
||||
}
|
||||
s.handleStart(cmd)
|
||||
}
|
||||
|
||||
// stopCurrentRun cancels the in-flight run and blocks the supervisor until it
|
||||
// has fully unwound, so the next action starts from a clean slate. The run
|
||||
// goroutine reports completion via runEnded. Caller must hold an in-flight run
|
||||
// (curStart != nil).
|
||||
func (s *supervisor) stopCurrentRun() {
|
||||
s.runCancel()
|
||||
res := <-s.runEnded
|
||||
s.finishRun(res.err)
|
||||
}
|
||||
|
||||
// finishRun resets lifecycle state after a run terminates and hands the run
|
||||
// error back to whoever asked to be notified of the start.
|
||||
func (s *supervisor) finishRun(err error) {
|
||||
s.runCancel = nil
|
||||
if s.isRunningInternal() {
|
||||
// Publish the result to the broadcast channel before nil-ing curRun, so
|
||||
// any opWaitEstablished goroutines blocked on ended observe err.
|
||||
s.curRun.err = err
|
||||
close(s.curRun.ended)
|
||||
s.curRun = nil
|
||||
|
||||
notify(s.curStart.done, err)
|
||||
s.curStart = nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleWaitEstablished answers an opWaitEstablished request. The select itself
|
||||
// runs in a spawned goroutine on the run's channels so it never blocks the loop;
|
||||
// the loop only snapshots the in-flight run's channels (which it owns) here.
|
||||
func (s *supervisor) handleWaitEstablished(cmd lifecycleCmd) {
|
||||
caller := cmd.done
|
||||
if !s.isRunningInternal() {
|
||||
notify(caller, errNoRunInFlight)
|
||||
return
|
||||
}
|
||||
rs := s.curRun
|
||||
established := rs.connEstablishedChan
|
||||
ctx := cmd.waitCtx
|
||||
go func() {
|
||||
select {
|
||||
case <-established:
|
||||
notify(caller, nil)
|
||||
case <-rs.ended:
|
||||
if rs.err != nil {
|
||||
notify(caller, rs.err)
|
||||
return
|
||||
}
|
||||
notify(caller, errStoppedBeforeEstablished)
|
||||
case <-ctx.Done():
|
||||
notify(caller, ctx.Err())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// shutdown tears down the in-flight run when the parent context is cancelled,
|
||||
// then fails any still-queued commands so their callers never hang.
|
||||
func (s *supervisor) shutdown() {
|
||||
if s.runCancel != nil {
|
||||
s.runCancel()
|
||||
res := <-s.runEnded
|
||||
s.finishRun(res.err)
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case cmd := <-s.cmdCh:
|
||||
notify(cmd.done, s.ctx.Err())
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startAsync enqueues a start without blocking. If done is non-nil it receives
|
||||
// the run's end result (or errAlreadyRunning on rejection, or the context error
|
||||
// on shutdown).
|
||||
func (s *supervisor) startAsync(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string, done chan error) {
|
||||
cmd := lifecycleCmd{op: opStart, config: config, md: md, mobileDep: mobileDep, logPath: logPath, done: done}
|
||||
select {
|
||||
case s.cmdCh <- cmd:
|
||||
case <-s.ctx.Done():
|
||||
notify(done, s.ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// restartAsync enqueues an atomic stop+start without blocking. The supervisor
|
||||
// tears down any in-flight run and starts a fresh one with the supplied config
|
||||
// in a single loop turn (see handleRestart). Fire-and-forget: the new run owns
|
||||
// its lifecycle channels, observed via waitEstablishedOrDone.
|
||||
func (s *supervisor) restartAsync(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string) {
|
||||
cmd := lifecycleCmd{op: opRestart, config: config, md: md, mobileDep: mobileDep, logPath: logPath}
|
||||
select {
|
||||
case s.cmdCh <- cmd:
|
||||
case <-s.ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
// start enqueues a start and blocks until the run terminates, preserving the
|
||||
// blocking contract of the legacy Run entry points.
|
||||
func (s *supervisor) start(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string) error {
|
||||
done := make(chan error, 1)
|
||||
s.startAsync(config, md, mobileDep, logPath, done)
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// isRunning asks the loop whether a run is in flight. The query is serialized
|
||||
// with start/stop, so during a stop it waits for the teardown to settle and
|
||||
// then reports the final state — never a transient "half-stopped".
|
||||
func (s *supervisor) isRunning() bool {
|
||||
reply := make(chan bool, 1)
|
||||
select {
|
||||
case s.cmdCh <- lifecycleCmd{op: opStatus, reply: reply}:
|
||||
case <-s.ctx.Done():
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case r := <-reply:
|
||||
return r
|
||||
case <-s.ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *supervisor) isRunningInternal() bool {
|
||||
return s.curStart != nil
|
||||
}
|
||||
|
||||
// waitEstablishedOrDone blocks until the in-flight run becomes established
|
||||
// (returns nil) or ends before that (returns the run error, or
|
||||
// errStoppedBeforeEstablished on a clean stop), or ctx is cancelled. Returns
|
||||
// errNoRunInFlight if no run is in flight. The wait is performed by a goroutine
|
||||
// spawned inside the loop (see handleWaitEstablished); the run's channels never
|
||||
// leave the supervisor.
|
||||
func (s *supervisor) waitEstablishedOrDone(ctx context.Context) error {
|
||||
reply := make(chan error, 1)
|
||||
select {
|
||||
case s.cmdCh <- lifecycleCmd{op: opWaitEstablished, waitCtx: ctx, done: reply}:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
select {
|
||||
case err := <-reply:
|
||||
return err
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// stop enqueues a stop and blocks until the in-flight run is fully torn down.
|
||||
func (s *supervisor) stop() error {
|
||||
done := make(chan error, 1)
|
||||
select {
|
||||
case s.cmdCh <- lifecycleCmd{op: opStop, done: done}:
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// notify sends on a caller-supplied channel without blocking. The channel is
|
||||
// expected to be buffered (cap 1); a nil channel means the caller did not ask
|
||||
// to be notified.
|
||||
func notify(ch chan error, err error) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -51,13 +51,20 @@ type cachedRecord struct {
|
||||
}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
|
||||
// guarded by mutex.
|
||||
type Resolver struct {
|
||||
records map[dns.Question]*cachedRecord
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
|
||||
// failedResolves records the last failed initial resolve per domain so a
|
||||
// domain that never resolves isn't retried on every server-domains update
|
||||
// until refreshBackoff elapses. Entries are cleared on success and pruned
|
||||
// to the current server-domains set.
|
||||
failedResolves map[domain.Domain]time.Time
|
||||
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
@@ -76,9 +83,10 @@ type Resolver struct {
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
failedResolves: make(map[domain.Domain]time.Time),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,7 +181,9 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype.
|
||||
// entry for that qtype. When one family hard-errors while the other succeeds,
|
||||
// the resolved family is still cached but AddDomain returns an error so the
|
||||
// caller retries the incomplete resolve rather than treating it as complete.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
@@ -203,6 +213,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
if errA != nil || errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -462,6 +476,7 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
delete(m.records, qAAAA)
|
||||
delete(m.refreshing, qA)
|
||||
delete(m.refreshing, qAAAA)
|
||||
delete(m.failedResolves, d)
|
||||
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
@@ -505,6 +520,7 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
|
||||
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
|
||||
currentDomains := m.GetCachedDomains()
|
||||
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
||||
m.pruneFailedResolves(allDomains)
|
||||
}
|
||||
|
||||
m.addNewDomains(ctx, newDomains)
|
||||
@@ -577,13 +593,85 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
||||
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
||||
}
|
||||
|
||||
// addNewDomains resolves and caches all domains from the update
|
||||
// addNewDomains resolves and caches domains that are not yet in the cache,
|
||||
// running the lookups concurrently. Domains already cached are skipped and left
|
||||
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
|
||||
// synchronously: once NetBird owns the OS resolver the resolve runs through the
|
||||
// handler chain and would otherwise dial the managed upstreams under the engine
|
||||
// sync lock on every update.
|
||||
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
||||
var wg sync.WaitGroup
|
||||
seen := make(map[domain.Domain]struct{}, len(newDomains))
|
||||
for _, newDomain := range newDomains {
|
||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||
} else {
|
||||
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
||||
if _, dup := seen[newDomain]; dup {
|
||||
continue
|
||||
}
|
||||
seen[newDomain] = struct{}{}
|
||||
|
||||
if !m.needsResolve(newDomain) {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(d domain.Domain) {
|
||||
defer wg.Done()
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
m.markResolveFailed(d)
|
||||
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
|
||||
return
|
||||
}
|
||||
m.clearResolveFailed(d)
|
||||
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||
}(newDomain)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// needsResolve reports whether d should be resolved now. A recent failed or
|
||||
// incomplete resolve gates retries on the backoff even when one family is
|
||||
// already cached, so a transiently-failed family is retried instead of being
|
||||
// treated as fully resolved. Otherwise a domain with any cached record is left
|
||||
// to the stale-while-revalidate refresh path.
|
||||
func (m *Resolver) needsResolve(d domain.Domain) bool {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
if failedAt, ok := m.failedResolves[d]; ok {
|
||||
return time.Since(failedAt) >= refreshBackoff
|
||||
}
|
||||
|
||||
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
if _, ok := m.records[q]; ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Resolver) markResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
m.failedResolves[d] = time.Now()
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Resolver) clearResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
delete(m.failedResolves, d)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// pruneFailedResolves drops failure markers for domains no longer present in
|
||||
// the server-domains set, keeping the map bounded to the current set (a
|
||||
// failed-only domain has no cached record, so RemoveDomain never sees it).
|
||||
func (m *Resolver) pruneFailedResolves(domains domain.List) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
for d := range m.failedResolves {
|
||||
if !slices.Contains(domains, d) {
|
||||
delete(m.failedResolves, d)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ type fakeChain struct {
|
||||
mu sync.Mutex
|
||||
calls map[string]int
|
||||
answers map[string][]dns.RR
|
||||
qErr map[string]error
|
||||
err error
|
||||
hasRoot bool
|
||||
onLookup func()
|
||||
@@ -30,6 +31,7 @@ func newFakeChain() *fakeChain {
|
||||
return &fakeChain{
|
||||
calls: map[string]int{},
|
||||
answers: map[string][]dns.RR{},
|
||||
qErr: map[string]error{},
|
||||
hasRoot: true,
|
||||
}
|
||||
}
|
||||
@@ -47,6 +49,9 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
|
||||
f.calls[key]++
|
||||
answers := f.answers[key]
|
||||
err := f.err
|
||||
if err == nil {
|
||||
err = f.qErr[key]
|
||||
}
|
||||
onLookup := f.onLookup
|
||||
f.mu.Unlock()
|
||||
|
||||
@@ -75,6 +80,12 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
|
||||
}
|
||||
|
||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
183
client/internal/dns/mgmt/mgmt_resolve_test.go
Normal file
183
client/internal/dns/mgmt/mgmt_resolve_test.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// A domain already in the cache must not be re-resolved on a subsequent server
|
||||
// domains update; it is left to the stale-while-revalidate refresh path.
|
||||
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must resolve the domain")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"cached domain must not be re-resolved on a subsequent update")
|
||||
}
|
||||
|
||||
// New domains in a single update must resolve concurrently rather than serially.
|
||||
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
|
||||
var inflight, maxInflight atomic.Int32
|
||||
chain.onLookup = func() {
|
||||
n := inflight.Add(1)
|
||||
for {
|
||||
old := maxInflight.Load()
|
||||
if n <= old || maxInflight.CompareAndSwap(old, n) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
inflight.Add(-1)
|
||||
}
|
||||
|
||||
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
|
||||
for _, d := range relays {
|
||||
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
|
||||
}
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
start := time.Now()
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
|
||||
require.NoError(t, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
|
||||
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
|
||||
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
|
||||
}
|
||||
|
||||
// A domain that fails to resolve must not be retried on every update; the
|
||||
// failure backoff suppresses re-resolution until it expires.
|
||||
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must attempt the resolve")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"failed resolve must back off and not retry on the next update")
|
||||
}
|
||||
|
||||
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
|
||||
// the same host) must be resolved once per update, not once per occurrence.
|
||||
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{
|
||||
Stuns: []domain.Domain{"dup.example.com"},
|
||||
Turns: []domain.Domain{"dup.example.com"},
|
||||
}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
|
||||
"a domain appearing under multiple server-domain types must resolve once")
|
||||
}
|
||||
|
||||
// A failure marker must be dropped once its domain leaves the server-domains set
|
||||
// so the map stays bounded to the current set.
|
||||
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, marked, "failed resolve must be recorded")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
|
||||
}
|
||||
|
||||
// When one family hard-errors while the other resolves, the domain is cached
|
||||
// for the working family but recorded as incomplete so the failed family is
|
||||
// retried under backoff instead of being treated as fully resolved forever.
|
||||
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
|
||||
d := domain.Domain("relay.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, aCached, "the working family must still be cached")
|
||||
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
|
||||
|
||||
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
|
||||
|
||||
r.mutex.Lock()
|
||||
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
|
||||
r.mutex.Unlock()
|
||||
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
|
||||
}
|
||||
|
||||
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
|
||||
// not a failure: the domain must not be marked for retry, otherwise it would be
|
||||
// re-resolved on every sync.
|
||||
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
|
||||
d := domain.Domain("v4only.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
|
||||
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
|
||||
}
|
||||
@@ -207,3 +207,35 @@ func FormatAnswers(answers []dns.RR) string {
|
||||
}
|
||||
return "[" + strings.Join(parts, ", ") + "]"
|
||||
}
|
||||
|
||||
// StripOPT removes any OPT pseudo-RRs from the message's Extra section. Per
|
||||
// RFC 6891 a responder must not include an OPT RR toward a client that did not
|
||||
// advertise EDNS0.
|
||||
func StripOPT(msg *dns.Msg) {
|
||||
if len(msg.Extra) == 0 {
|
||||
return
|
||||
}
|
||||
out := msg.Extra[:0]
|
||||
for _, rr := range msg.Extra {
|
||||
if _, ok := rr.(*dns.OPT); ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
msg.Extra = out
|
||||
}
|
||||
|
||||
// ExtractEDE returns the first Extended DNS Error (RFC 8914) option carried in
|
||||
// the message, if present.
|
||||
func ExtractEDE(msg *dns.Msg) (*dns.EDNS0_EDE, bool) {
|
||||
opt := msg.IsEdns0()
|
||||
if opt == nil {
|
||||
return nil, false
|
||||
}
|
||||
for _, o := range opt.Option {
|
||||
if ede, ok := o.(*dns.EDNS0_EDE); ok {
|
||||
return ede, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -120,3 +120,42 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestStripOPT(t *testing.T) {
|
||||
rm := &dns.Msg{
|
||||
Extra: []dns.RR{
|
||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
}
|
||||
StripOPT(rm)
|
||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||
}
|
||||
|
||||
func TestExtractEDE(t *testing.T) {
|
||||
t.Run("no edns", func(t *testing.T) {
|
||||
_, ok := ExtractEDE(&dns.Msg{})
|
||||
assert.False(t, ok, "message without OPT has no EDE")
|
||||
})
|
||||
|
||||
t.Run("edns without ede", func(t *testing.T) {
|
||||
rm := &dns.Msg{}
|
||||
rm.SetEdns0(4096, false)
|
||||
_, ok := ExtractEDE(rm)
|
||||
assert.False(t, ok, "OPT without EDE option returns false")
|
||||
})
|
||||
|
||||
t.Run("with ede", func(t *testing.T) {
|
||||
rm := &dns.Msg{}
|
||||
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: 49152, ExtraText: "upstream timeout"})
|
||||
rm.Extra = append(rm.Extra, opt)
|
||||
|
||||
ede, ok := ExtractEDE(rm)
|
||||
assert.True(t, ok, "EDE option should be found")
|
||||
assert.Equal(t, uint16(49152), ede.InfoCode)
|
||||
assert.Equal(t, "upstream timeout", ede.ExtraText)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -38,11 +39,15 @@ const (
|
||||
// defaultWarningDelayBase is the starting grace window before a
|
||||
// "Nameserver group unreachable" event fires for a group that's
|
||||
// never been healthy and only has overlay upstreams with no
|
||||
// Connected peer. Per-server and overridable; see warningDelayFor.
|
||||
defaultWarningDelayBase = 30 * time.Second
|
||||
// Connected peer. Per-server and overridable via envWarningDelay;
|
||||
// see warningDelay.
|
||||
defaultWarningDelayBase = 60 * time.Second
|
||||
// warningDelayBonusCap caps the route-count bonus added to the
|
||||
// base grace window. See warningDelayFor.
|
||||
// base grace window. See warningDelay.
|
||||
warningDelayBonusCap = 30 * time.Second
|
||||
// envWarningDelay overrides defaultWarningDelayBase with a Go duration
|
||||
// string (e.g. "90s", "2m"). Invalid or non-positive values are ignored.
|
||||
envWarningDelay = "NB_DNS_HEALTH_WARNING_DELAY"
|
||||
)
|
||||
|
||||
// errNoUsableNameservers signals that a merged-domain group has no usable
|
||||
@@ -135,7 +140,7 @@ type DefaultServer struct {
|
||||
disableSys bool
|
||||
mux sync.Mutex
|
||||
service service
|
||||
dnsMuxMap registeredHandlerMap
|
||||
dnsMuxHandlers []handlerWrapper
|
||||
localResolver *local.Resolver
|
||||
wgInterface WGIface
|
||||
hostManager hostManager
|
||||
@@ -199,8 +204,6 @@ type handlerWrapper struct {
|
||||
priority int
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
||||
|
||||
// DefaultServerConfig holds configuration parameters for NewDefaultServer
|
||||
type DefaultServerConfig struct {
|
||||
WgInterface WGIface
|
||||
@@ -289,7 +292,6 @@ func newDefaultServer(
|
||||
service: dnsService,
|
||||
handlerChain: handlerChain,
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: local.NewResolver(),
|
||||
wgInterface: wgInterface,
|
||||
statusRecorder: statusRecorder,
|
||||
@@ -298,7 +300,7 @@ func newDefaultServer(
|
||||
hostManager: &noopHostConfigurator{},
|
||||
mgmtCacheResolver: mgmtCacheResolver,
|
||||
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
warningDelayBase: warningDelayBaseFromEnv(),
|
||||
healthRefresh: make(chan struct{}, 1),
|
||||
}
|
||||
// Wire the local resolver against the peer status recorder so it can
|
||||
@@ -328,7 +330,7 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
|
||||
type routeSettable interface {
|
||||
setSelectedRoutes(func() route.HAMap)
|
||||
}
|
||||
for _, entry := range s.dnsMuxMap {
|
||||
for _, entry := range s.dnsMuxHandlers {
|
||||
if h, ok := entry.handler.(routeSettable); ok {
|
||||
h.setSelectedRoutes(selected)
|
||||
}
|
||||
@@ -978,19 +980,23 @@ func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []neti
|
||||
|
||||
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||
for _, existing := range s.dnsMuxMap {
|
||||
for _, existing := range s.dnsMuxHandlers {
|
||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||
existing.handler.Stop()
|
||||
// The local resolver is a persistent singleton shared by every custom
|
||||
// zone and reused across config updates. Its chain registrations are
|
||||
// per-config and must be deregistered, but Stop() cancels its lookup
|
||||
// context (breaking external CNAME-target resolution) and clears its
|
||||
// records, so it must not be torn down here.
|
||||
if existing.handler != s.localResolver {
|
||||
existing.handler.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||
muxUpdateMap[update.handler.ID()] = update
|
||||
}
|
||||
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
s.dnsMuxHandlers = muxUpdates
|
||||
}
|
||||
|
||||
// updateNSGroupStates records the new group set and pokes the refresher.
|
||||
@@ -1154,6 +1160,26 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
|
||||
return false
|
||||
}
|
||||
|
||||
// warningDelayBaseFromEnv returns the base grace window, honoring
|
||||
// envWarningDelay when it holds a valid positive Go duration. Invalid or
|
||||
// non-positive values fall back to defaultWarningDelayBase.
|
||||
func warningDelayBaseFromEnv() time.Duration {
|
||||
val := os.Getenv(envWarningDelay)
|
||||
if val == "" {
|
||||
return defaultWarningDelayBase
|
||||
}
|
||||
d, err := time.ParseDuration(val)
|
||||
if err != nil {
|
||||
log.Warnf("invalid %s value %q, using default %v: %v", envWarningDelay, val, defaultWarningDelayBase, err)
|
||||
return defaultWarningDelayBase
|
||||
}
|
||||
if d <= 0 {
|
||||
log.Warnf("%s must be positive, got %v, using default %v", envWarningDelay, d, defaultWarningDelayBase)
|
||||
return defaultWarningDelayBase
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// warningDelay returns the grace window for the given selected-route
|
||||
// count. Scales gently: +1s per 100 routes, capped by
|
||||
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
|
||||
@@ -1204,7 +1230,7 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
|
||||
// in more than one handler.
|
||||
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||
merged := make(map[netip.AddrPort]UpstreamHealth)
|
||||
for _, entry := range s.dnsMuxMap {
|
||||
for _, entry := range s.dnsMuxHandlers {
|
||||
reporter, ok := entry.handler.(upstreamHealthReporter)
|
||||
if !ok {
|
||||
continue
|
||||
|
||||
@@ -104,19 +104,6 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []netip.AddrPort
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, srv.AddrPort())
|
||||
}
|
||||
u := &upstreamResolverBase{
|
||||
domain: domain.Domain(d),
|
||||
cancel: func() {},
|
||||
}
|
||||
u.addRace(srvs)
|
||||
return u
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
@@ -132,22 +119,20 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
dummyHandler := local.NewResolver()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap registeredHandlerMap
|
||||
initUpstreamMap []handlerWrapper
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap registeredHandlerMap
|
||||
expectedUpstreamMap []handlerWrapper
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -169,20 +154,17 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
dummyHandler.ID(): handlerWrapper{
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
{
|
||||
domain: nbdns.RootZone,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
@@ -191,10 +173,10 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
@@ -215,15 +197,13 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
"local-resolver": handlerWrapper{
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
},
|
||||
@@ -232,7 +212,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
@@ -240,7 +220,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -262,7 +242,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -284,7 +264,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -301,42 +281,41 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
expectedUpstreamMap: []handlerWrapper{{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedUpstreamMap: nil,
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedUpstreamMap: nil,
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
@@ -393,7 +372,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
@@ -405,14 +384,20 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
||||
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
||||
}
|
||||
|
||||
for key := range testCase.expectedUpstreamMap {
|
||||
_, found := dnsServer.dnsMuxMap[key]
|
||||
for _, expected := range testCase.expectedUpstreamMap {
|
||||
found := false
|
||||
for _, got := range dnsServer.dnsMuxHandlers {
|
||||
if got.domain == expected.domain && got.priority == expected.priority {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
||||
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -512,8 +497,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||
"id1": handlerWrapper{
|
||||
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityUpstream,
|
||||
@@ -1029,15 +1014,15 @@ func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||
func (m *mockService) DeregisterMux(string) {}
|
||||
|
||||
func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
baseMatchHandlers := registeredHandlerMap{
|
||||
"upstream-group1": {
|
||||
baseMatchHandlers := []handlerWrapper{
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
"upstream-group2": {
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
@@ -1046,15 +1031,15 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
baseRootHandlers := registeredHandlerMap{
|
||||
"upstream-root1": {
|
||||
baseRootHandlers := []handlerWrapper{
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root1",
|
||||
},
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
"upstream-root2": {
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root2",
|
||||
@@ -1063,22 +1048,22 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
baseMixedHandlers := registeredHandlerMap{
|
||||
"upstream-group1": {
|
||||
baseMixedHandlers := []handlerWrapper{
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
"upstream-group2": {
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityUpstream - 1,
|
||||
},
|
||||
"upstream-other": {
|
||||
{
|
||||
domain: "other.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-other",
|
||||
@@ -1089,7 +1074,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialHandlers registeredHandlerMap
|
||||
initialHandlers []handlerWrapper
|
||||
updates []handlerWrapper
|
||||
expectedHandlers map[string]string // map[HandlerID]domain
|
||||
description string
|
||||
@@ -1373,32 +1358,38 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &DefaultServer{
|
||||
dnsMuxMap: tt.initialHandlers,
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
dnsMuxHandlers: tt.initialHandlers,
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
}
|
||||
|
||||
// Perform the update
|
||||
server.updateMux(tt.updates)
|
||||
|
||||
// Verify the results
|
||||
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
|
||||
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxHandlers),
|
||||
"Number of handlers after update doesn't match expected")
|
||||
|
||||
// Check each expected handler
|
||||
for id, expectedDomain := range tt.expectedHandlers {
|
||||
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
|
||||
assert.True(t, exists, "Expected handler %s not found", id)
|
||||
if exists {
|
||||
assert.Equal(t, expectedDomain, handler.domain,
|
||||
var found *handlerWrapper
|
||||
for i := range server.dnsMuxHandlers {
|
||||
if server.dnsMuxHandlers[i].handler.ID() == types.HandlerID(id) {
|
||||
found = &server.dnsMuxHandlers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotNil(t, found, "Expected handler %s not found", id)
|
||||
if found != nil {
|
||||
assert.Equal(t, expectedDomain, found.domain,
|
||||
"Domain mismatch for handler %s", id)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no unexpected handlers exist
|
||||
for HandlerID := range server.dnsMuxMap {
|
||||
_, expected := tt.expectedHandlers[string(HandlerID)]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
|
||||
for _, entry := range server.dnsMuxHandlers {
|
||||
_, expected := tt.expectedHandlers[string(entry.handler.ID())]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", entry.handler.ID())
|
||||
}
|
||||
|
||||
// Verify the handlerChain state and order
|
||||
@@ -1413,7 +1404,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
|
||||
// Verify handler exists in mux
|
||||
foundInMux := false
|
||||
for _, muxEntry := range server.dnsMuxMap {
|
||||
for _, muxEntry := range server.dnsMuxHandlers {
|
||||
if chainEntry.Handler == muxEntry.handler &&
|
||||
chainEntry.Priority == muxEntry.priority &&
|
||||
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
||||
@@ -1422,12 +1413,108 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
}
|
||||
}
|
||||
assert.True(t, foundInMux,
|
||||
"Handler in chain not found in dnsMuxMap")
|
||||
"Handler in chain not found in dnsMuxHandlers")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// chainHasPattern reports whether the handler chain holds an entry registered
|
||||
// for the given fqdn pattern at the given priority.
|
||||
func chainHasPattern(s *DefaultServer, pattern string, priority int) bool {
|
||||
for _, h := range s.handlerChain.handlers {
|
||||
if h.OrigPattern == pattern && h.Priority == priority {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval verifies that updateMux
|
||||
// tracks each (handler, domain) registration independently when one handler
|
||||
// serves multiple zones. Every custom zone is served by the same handler
|
||||
// instance (the local resolver, whose ID is the constant "local-resolver"), so
|
||||
// removing one zone must deregister exactly that zone's chain entry and leave
|
||||
// the others in place. Tracking registrations by handler ID alone collapses all
|
||||
// zones onto one entry, leaving removed zones in the chain to answer
|
||||
// authoritatively with no records.
|
||||
func TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval(t *testing.T) {
|
||||
// One handler serves every custom zone, mirroring s.localResolver.
|
||||
shared := &mockHandler{Id: "local-resolver"}
|
||||
|
||||
server := &DefaultServer{
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
}
|
||||
|
||||
// Two custom zones under the same handler. The surviving zone is registered
|
||||
// last, mirroring the management emission order.
|
||||
server.updateMux([]handlerWrapper{
|
||||
{domain: "userzone.test", handler: shared, priority: PriorityLocal},
|
||||
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
|
||||
})
|
||||
|
||||
require.True(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
|
||||
"userzone.test should be registered after the first update")
|
||||
require.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
|
||||
"peerzone.test should be registered after the first update")
|
||||
|
||||
// Remove one zone, keep the other.
|
||||
server.updateMux([]handlerWrapper{
|
||||
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
|
||||
})
|
||||
|
||||
assert.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
|
||||
"peerzone.test should remain after removing userzone.test")
|
||||
assert.False(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
|
||||
"userzone.test handler must be deregistered, not leaked in the chain")
|
||||
}
|
||||
|
||||
// TestDefaultServer_UpdateMux_PreservesLocalResolver verifies that updateMux
|
||||
// does not tear down the shared local resolver during reconfiguration. The
|
||||
// resolver is a process-lifetime singleton reused across config updates;
|
||||
// Stop() cancels its lookup context (breaking external CNAME-target
|
||||
// resolution) and clears its records. updateMux must deregister its chain
|
||||
// entries without stopping it. Records surviving a teardown update is the
|
||||
// observable proxy: Stop() would have cleared them.
|
||||
func TestDefaultServer_UpdateMux_PreservesLocalResolver(t *testing.T) {
|
||||
resolver := local.NewResolver()
|
||||
require.NoError(t, resolver.RegisterRecord(nbdns.SimpleRecord{
|
||||
Name: "peer.netbird.cloud.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "10.0.0.1",
|
||||
}))
|
||||
|
||||
server := &DefaultServer{
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
localResolver: resolver,
|
||||
}
|
||||
|
||||
server.updateMux([]handlerWrapper{
|
||||
{domain: "netbird.cloud", handler: resolver, priority: PriorityLocal},
|
||||
})
|
||||
|
||||
// Remove the zone. The resolver must survive so its records and lookup
|
||||
// context stay intact for the next registration.
|
||||
server.updateMux(nil)
|
||||
|
||||
var response *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
response = m
|
||||
return nil
|
||||
},
|
||||
}, &dns.Msg{Question: []dns.Question{{Name: "peer.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}})
|
||||
|
||||
require.NotNil(t, response, "local resolver should answer after teardown")
|
||||
assert.Equal(t, dns.RcodeSuccess, response.Rcode,
|
||||
"local resolver records must survive teardown; updateMux must not Stop() the shared resolver")
|
||||
assert.NotEmpty(t, response.Answer, "answer should contain the surviving record")
|
||||
}
|
||||
|
||||
func TestExtraDomains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -2049,7 +2136,6 @@ func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
}
|
||||
|
||||
groups := []*nbdns.NameServerGroup{
|
||||
@@ -2207,7 +2293,7 @@ func TestEvaluateNSGroupHealth(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
|
||||
// healthStubHandler is a minimal dnsMuxHandlers entry that exposes a fixed
|
||||
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
|
||||
// without spinning up real handlers.
|
||||
type healthStubHandler struct {
|
||||
@@ -2283,12 +2369,11 @@ func newProjTestFixture(t *testing.T) *projTestFixture {
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return fx.selected },
|
||||
activeRoutes: func() route.HAMap { return fx.active },
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
}
|
||||
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
|
||||
fx.server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}}
|
||||
|
||||
fx.server.mux.Lock()
|
||||
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
|
||||
@@ -2395,7 +2480,6 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: 50 * time.Millisecond,
|
||||
@@ -2407,7 +2491,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
|
||||
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
@@ -2444,7 +2528,6 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
service: NewServiceViaMemory(wgIface),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
extraDomains: map[domain.Domain]int{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
statusRecorder: peer.NewRecorder("mgm"),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
@@ -2459,7 +2542,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||
}
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
@@ -2484,6 +2567,32 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
// rule 3: startup failures while the peer is handshaking, then the peer
|
||||
// comes up and a query succeeds before the grace window elapses. No
|
||||
// warning should ever have fired, and no recovery either.
|
||||
func TestWarningDelayBaseFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
set bool
|
||||
val string
|
||||
want time.Duration
|
||||
}{
|
||||
{name: "unset uses default", set: false, want: defaultWarningDelayBase},
|
||||
{name: "valid override", set: true, val: "90s", want: 90 * time.Second},
|
||||
{name: "valid minutes", set: true, val: "2m", want: 2 * time.Minute},
|
||||
{name: "invalid falls back", set: true, val: "notaduration", want: defaultWarningDelayBase},
|
||||
{name: "zero falls back", set: true, val: "0s", want: defaultWarningDelayBase},
|
||||
{name: "negative falls back", set: true, val: "-30s", want: defaultWarningDelayBase},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(envWarningDelay, tc.val)
|
||||
if !tc.set {
|
||||
os.Unsetenv(envWarningDelay)
|
||||
}
|
||||
assert.Equal(t, tc.want, warningDelayBaseFromEnv(), "grace window base")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
fx.server.warningDelayBase = 200 * time.Millisecond
|
||||
@@ -2595,7 +2704,6 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return overlayMap },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: time.Hour,
|
||||
@@ -2613,7 +2721,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
||||
overlay: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
},
|
||||
}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
@@ -2640,7 +2748,6 @@ func TestDNSLoopPrevention(t *testing.T) {
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
||||
@@ -443,29 +443,32 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||
}
|
||||
|
||||
// A valid response means the upstream is reachable, whatever the Rcode.
|
||||
u.markUpstreamOk(upstream)
|
||||
|
||||
proto := ""
|
||||
if upstreamProto != nil {
|
||||
proto = upstreamProto.protocol
|
||||
}
|
||||
|
||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||
// SERVFAIL and REFUSED are per-question outcomes (DNSSEC-bogus names,
|
||||
// refused zones, transient recursion errors), not reachability
|
||||
// problems: fail over for a better answer but keep the upstream healthy.
|
||||
if code, ok := nonRetryableEDE(rm); ok {
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
resutil.StripOPT(rm)
|
||||
}
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
||||
}
|
||||
reason := dns.RcodeToString[rm.Rcode]
|
||||
u.markUpstreamFail(upstream, reason)
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
resutil.StripOPT(rm)
|
||||
}
|
||||
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
||||
}
|
||||
|
||||
@@ -520,22 +523,6 @@ func upstreamUDPSize() uint16 {
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
||||
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
|
||||
// the response complies with RFC 6891 when the client did not advertise EDNS0.
|
||||
func stripOPT(rm *dns.Msg) {
|
||||
if len(rm.Extra) == 0 {
|
||||
return
|
||||
}
|
||||
out := rm.Extra[:0]
|
||||
for _, rr := range rm.Extra {
|
||||
if _, ok := rr.(*dns.OPT); ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
rm.Extra = out
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||
|
||||
@@ -517,6 +517,78 @@ func TestUpstreamResolver_HealthTracking(t *testing.T) {
|
||||
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
|
||||
}
|
||||
|
||||
// TestUpstreamResolver_HealthTracking_ResponseMeansReachable verifies that an
|
||||
// upstream which answers with SERVFAIL or REFUSED is recorded as healthy:
|
||||
// those are per-question outcomes from a reachable server and must not mark
|
||||
// the upstream unhealthy. Only transport failures (timeouts) do.
|
||||
func TestUpstreamResolver_HealthTracking_ResponseMeansReachable(t *testing.T) {
|
||||
a := netip.MustParseAddrPort("192.0.2.10:53")
|
||||
b := netip.MustParseAddrPort("192.0.2.11:53")
|
||||
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
respA mockUpstreamResponse
|
||||
respB mockUpstreamResponse
|
||||
wantHealthy bool
|
||||
}{
|
||||
{
|
||||
name: "both SERVFAIL are reachable",
|
||||
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
wantHealthy: true,
|
||||
},
|
||||
{
|
||||
name: "both REFUSED are reachable",
|
||||
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
wantHealthy: true,
|
||||
},
|
||||
{
|
||||
name: "timeout marks unhealthy",
|
||||
respA: mockUpstreamResponse{err: timeoutErr},
|
||||
respB: mockUpstreamResponse{err: timeoutErr},
|
||||
wantHealthy: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
a.String(): tc.respA,
|
||||
b.String(): tc.respB,
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{a, b})
|
||||
|
||||
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
|
||||
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||
|
||||
health := resolver.UpstreamHealth()
|
||||
require.Contains(t, health, a, "primary upstream should have a health record")
|
||||
if tc.wantHealthy {
|
||||
assert.False(t, health[a].LastOk.IsZero(), "responding upstream should have LastOk set")
|
||||
assert.True(t, health[a].LastFail.IsZero(), "responding upstream should not be marked failed")
|
||||
assert.Empty(t, health[a].LastErr, "responding upstream should have no error")
|
||||
} else {
|
||||
assert.False(t, health[a].LastFail.IsZero(), "timed-out upstream should be marked failed")
|
||||
assert.NotEmpty(t, health[a].LastErr, "timed-out upstream should record an error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatFailures(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -913,19 +985,6 @@ func TestEDEName(t *testing.T) {
|
||||
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
|
||||
}
|
||||
|
||||
func TestStripOPT(t *testing.T) {
|
||||
rm := &dns.Msg{
|
||||
Extra: []dns.RR{
|
||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
}
|
||||
stripOPT(rm)
|
||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
||||
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
@@ -26,6 +26,15 @@ import (
|
||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||
const upstreamTimeout = 15 * time.Second
|
||||
|
||||
// EDE info codes the forwarder emits on upstream failures so the querying
|
||||
// client can see the reason without inspecting this peer's logs. They live in
|
||||
// the RFC 8914 Private Use range (49152-65535); the Go resolver never exposes a
|
||||
// real upstream EDE here, so these cannot collide with a genuine code.
|
||||
const (
|
||||
edeNetbirdUpstreamTimeout uint16 = 49152
|
||||
edeNetbirdUpstreamFailure uint16 = 49153
|
||||
)
|
||||
|
||||
type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
@@ -220,7 +229,7 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
|
||||
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||
if result.Err != nil {
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -333,6 +342,7 @@ func (f *DNSForwarder) handleDNSError(
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
result resutil.LookupResult,
|
||||
reqHasEdns bool,
|
||||
startTime time.Time,
|
||||
) {
|
||||
qType := question.Qtype
|
||||
@@ -374,6 +384,10 @@ func (f *DNSForwarder) handleDNSError(
|
||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||
}
|
||||
|
||||
if reqHasEdns {
|
||||
attachEDE(resp, edeCodeFor(dnsErr), edeText(dnsErr))
|
||||
}
|
||||
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
}
|
||||
|
||||
@@ -414,3 +428,33 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
|
||||
|
||||
return selectedResId, matches
|
||||
}
|
||||
|
||||
// edeCodeFor maps an upstream lookup error to the NetBird EDE info code.
|
||||
func edeCodeFor(dnsErr *net.DNSError) uint16 {
|
||||
if dnsErr != nil && dnsErr.IsTimeout {
|
||||
return edeNetbirdUpstreamTimeout
|
||||
}
|
||||
return edeNetbirdUpstreamFailure
|
||||
}
|
||||
|
||||
// edeText builds the EDE extra-text describing the class of upstream failure.
|
||||
// It deliberately omits the upstream server address, which may be an internal
|
||||
// resolver and is exposed to any client permitted to use the route; the full
|
||||
// detail stays in the forwarder's local log.
|
||||
func edeText(dnsErr *net.DNSError) string {
|
||||
if dnsErr != nil && dnsErr.IsTimeout {
|
||||
return "netbird forwarder: upstream timeout"
|
||||
}
|
||||
return "netbird forwarder: upstream failure"
|
||||
}
|
||||
|
||||
// attachEDE adds an Extended DNS Error (RFC 8914) option to the response,
|
||||
// creating the OPT pseudo-record if the response does not already carry one.
|
||||
func attachEDE(resp *dns.Msg, code uint16, text string) {
|
||||
opt := resp.IsEdns0()
|
||||
if opt == nil {
|
||||
resp.SetEdns0(dns.DefaultMsgSize, false)
|
||||
opt = resp.IsEdns0()
|
||||
}
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: code, ExtraText: text})
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -617,6 +618,85 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lookupErr error
|
||||
reqEdns bool
|
||||
wantEDE bool
|
||||
wantCode uint16
|
||||
wantTextHas string
|
||||
}{
|
||||
{
|
||||
name: "timeout with edns0",
|
||||
lookupErr: &net.DNSError{Err: "i/o timeout", Server: "10.0.0.53:53", IsTimeout: true},
|
||||
reqEdns: true,
|
||||
wantEDE: true,
|
||||
wantCode: edeNetbirdUpstreamTimeout,
|
||||
wantTextHas: "netbird forwarder: upstream timeout",
|
||||
},
|
||||
{
|
||||
name: "server failure with edns0",
|
||||
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||
reqEdns: true,
|
||||
wantEDE: true,
|
||||
wantCode: edeNetbirdUpstreamFailure,
|
||||
wantTextHas: "netbird forwarder: upstream failure",
|
||||
},
|
||||
{
|
||||
name: "no edns0 in request omits ede",
|
||||
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||
reqEdns: false,
|
||||
wantEDE: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
require.NoError(t, err)
|
||||
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
|
||||
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||
Return([]netip.Addr(nil), tt.lookupErr).Once()
|
||||
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion("example.com.", dns.TypeA)
|
||||
if tt.reqEdns {
|
||||
query.SetEdns0(dns.DefaultMsgSize, false)
|
||||
}
|
||||
|
||||
var writtenResp *dns.Msg
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writtenResp = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
mockResolver.AssertExpectations(t)
|
||||
|
||||
require.NotNil(t, writtenResp, "expected a response")
|
||||
assert.Equal(t, dns.RcodeServerFailure, writtenResp.Rcode, "upstream failure must be SERVFAIL")
|
||||
|
||||
ede, ok := resutil.ExtractEDE(writtenResp)
|
||||
if !tt.wantEDE {
|
||||
assert.False(t, ok, "response must not carry EDE")
|
||||
return
|
||||
}
|
||||
require.True(t, ok, "response must carry EDE")
|
||||
assert.Equal(t, tt.wantCode, ede.InfoCode, "EDE info code")
|
||||
assert.Contains(t, ede.ExtraText, tt.wantTextHas, "EDE extra-text")
|
||||
assert.NotContains(t, ede.ExtraText, "10.0.0.53", "must not leak upstream server address")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
// Test that large UDP responses are truncated with TC bit set
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
@@ -22,8 +22,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
@@ -88,6 +86,8 @@ const (
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
|
||||
var ErrEngineAlreadyStarted = errors.New("engine already started")
|
||||
|
||||
type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIfaceName string
|
||||
@@ -201,6 +201,8 @@ type Engine struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
started bool
|
||||
|
||||
wgInterface WGIface
|
||||
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
@@ -281,9 +283,15 @@ func NewEngine(
|
||||
services EngineServices,
|
||||
mobileDep MobileDependency,
|
||||
) *Engine {
|
||||
// The engine is single-use: a fresh instance is built per connection
|
||||
// cycle (see Client.run), so the run context is created once here rather
|
||||
// than in Start.
|
||||
ctx, cancel := context.WithCancel(clientCtx)
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
@@ -316,8 +324,34 @@ func (e *Engine) Stop() error {
|
||||
log.Debugf("tried stopping engine that is nil")
|
||||
return nil
|
||||
}
|
||||
e.cancel()
|
||||
e.syncMsgMux.Lock()
|
||||
|
||||
e.stopLocked()
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopLocked tears down everything Start may have brought up, in the order
|
||||
// teardown requires (DNS before the interface goes down, flow manager after).
|
||||
// The caller must hold syncMsgMux. It is shared by Stop and by Start's failure
|
||||
// path, so a partially-initialized engine is cleaned up the same way; every
|
||||
// step is nil-guarded. It does not wait on shutdownWg — the caller does that
|
||||
// after releasing the lock, since the goroutines also take syncMsgMux.
|
||||
func (e *Engine) stopLocked() {
|
||||
if e.connMgr != nil {
|
||||
e.connMgr.Close()
|
||||
}
|
||||
@@ -368,10 +402,6 @@ func (e *Engine) Stop() error {
|
||||
// so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
e.jobExecutorWG.Wait() // block until job goroutines finish
|
||||
|
||||
e.close()
|
||||
@@ -390,21 +420,6 @@ func (e *Engine) Stop() error {
|
||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
||||
@@ -442,18 +457,38 @@ func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
|
||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||
// Connections to remote peers are not established here.
|
||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
|
||||
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) (err error) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if err := iface.ValidateMTU(e.config.MTU); err != nil {
|
||||
// The engine is single-use. Reject a duplicate start and a start on an
|
||||
// already-stopped engine (run context cancelled).
|
||||
if e.started {
|
||||
return ErrEngineAlreadyStarted
|
||||
}
|
||||
|
||||
if ctxErr := e.ctx.Err(); ctxErr != nil {
|
||||
return fmt.Errorf("engine already stopped: %w", ctxErr)
|
||||
}
|
||||
|
||||
e.started = true
|
||||
|
||||
// Tear down any partially-initialized state on a failed start. Cancel the
|
||||
// run context first so goroutines started before the failure (connMgr,
|
||||
// srWatcher, monitors) unwind, then stopLocked mirrors Stop's teardown (we
|
||||
// already hold syncMsgMux), cleaning up route/DNS/flow/state managers too,
|
||||
// not just what close() covers.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
e.cancel()
|
||||
e.stopLocked()
|
||||
}
|
||||
}()
|
||||
|
||||
if err = iface.ValidateMTU(e.config.MTU); err != nil {
|
||||
return fmt.Errorf("invalid MTU configuration: %w", err)
|
||||
}
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
@@ -487,13 +522,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("read initial settings: %w", err)
|
||||
}
|
||||
|
||||
dnsServer, err := e.newDnsServer(dnsConfig)
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("create dns server: %w", err)
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
@@ -528,7 +561,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
if err = e.wgInterfaceCreate(); err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
e.close()
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
}
|
||||
|
||||
@@ -537,7 +569,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
}
|
||||
|
||||
if err := e.createFirewall(); err != nil {
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -549,7 +580,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.udpMux, err = e.wgInterface.Up()
|
||||
if err != nil {
|
||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||
e.close()
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
@@ -574,9 +604,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.acl = acl.NewDefaultManager(e.firewall)
|
||||
}
|
||||
|
||||
err = e.dnsServer.Initialize()
|
||||
if err != nil {
|
||||
e.close()
|
||||
if err := e.dnsServer.Initialize(); err != nil {
|
||||
return fmt.Errorf("initialize dns server: %w", err)
|
||||
}
|
||||
|
||||
@@ -588,7 +616,9 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start(peer.IsForceRelayed())
|
||||
|
||||
e.receiveSignalEvents()
|
||||
if err = e.receiveSignalEvents(); err != nil {
|
||||
return err
|
||||
}
|
||||
e.receiveManagementEvents()
|
||||
e.receiveJobEvents()
|
||||
|
||||
@@ -640,7 +670,6 @@ func (e *Engine) createFirewall() error {
|
||||
|
||||
func (e *Engine) initFirewall() error {
|
||||
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("set firewall: %w", err)
|
||||
}
|
||||
|
||||
@@ -1127,20 +1156,6 @@ func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool {
|
||||
return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked()
|
||||
}
|
||||
|
||||
// wrapDisconnectError classifies a receive-loop failure before the run is torn
|
||||
// down. An auth rejection (PermissionDenied/Unauthenticated) means the session
|
||||
// needs re-login and retrying is futile, so mark it terminal (NeedsLogin) — run()
|
||||
// then exits on its own instead of spinning the backoff. Any other failure is a
|
||||
// recoverable connection reset that the backoff should retry.
|
||||
func (e *Engine) wrapDisconnectError(err error) {
|
||||
state := CtxGetState(e.ctx)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied || s.Code() == codes.Unauthenticated) {
|
||||
state.Set(StatusNeedsLogin)
|
||||
return
|
||||
}
|
||||
_ = state.Wrap(ErrResetConnection)
|
||||
}
|
||||
|
||||
func (e *Engine) receiveJobEvents() {
|
||||
e.jobExecutorWG.Add(1)
|
||||
go func() {
|
||||
@@ -1167,9 +1182,9 @@ func (e *Engine) receiveJobEvents() {
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time, or rejects
|
||||
// us (auth). wrapDisconnectError decides retry vs needs-login.
|
||||
e.wrapDisconnectError(err)
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
@@ -1251,9 +1266,9 @@ func (e *Engine) receiveManagementEvents() {
|
||||
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time, or rejects
|
||||
// us (auth). wrapDisconnectError decides retry vs needs-login.
|
||||
e.wrapDisconnectError(err)
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
@@ -1714,7 +1729,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
}
|
||||
|
||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||
func (e *Engine) receiveSignalEvents() {
|
||||
func (e *Engine) receiveSignalEvents() error {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
@@ -1777,15 +1792,20 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
// happens if signal is unavailable for a long time, or rejects us
|
||||
// (auth). wrapDisconnectError decides retry vs needs-login.
|
||||
e.wrapDisconnectError(err)
|
||||
// happens if signal is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
e.signal.WaitStreamConnected()
|
||||
// todo: consider to remove this blocker. I do not see benefit to block the Start operations
|
||||
e.signal.WaitStreamConnected(e.ctx)
|
||||
if err := e.ctx.Err(); err != nil {
|
||||
return fmt.Errorf("wait for signal stream: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
|
||||
@@ -247,7 +247,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
@@ -426,7 +426,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
@@ -638,7 +638,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
// feed updates to Engine via mocked Management client
|
||||
@@ -817,7 +817,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
@@ -1024,7 +1024,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
|
||||
@@ -251,6 +251,14 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
// Advertise EDNS0 to the forwarder so it may return an Extended DNS Error
|
||||
// describing why a lookup failed. The OPT is stripped from the reply when
|
||||
// the original client did not request EDNS0.
|
||||
hadEdns := r.IsEdns0() != nil
|
||||
if !hadEdns {
|
||||
r.SetEdns0(dns.DefaultMsgSize, false)
|
||||
}
|
||||
|
||||
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
@@ -260,6 +268,13 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
if ede, ok := resutil.ExtractEDE(reply); ok {
|
||||
resutil.SetMeta(w, "ede", fmt.Sprintf("%d %s", ede.InfoCode, ede.ExtraText))
|
||||
}
|
||||
if !hadEdns {
|
||||
resutil.StripOPT(reply)
|
||||
}
|
||||
|
||||
resutil.SetMeta(w, "peer", peerKey)
|
||||
|
||||
reply.Id = r.Id
|
||||
|
||||
@@ -171,13 +171,13 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
c.onHostDnsFn = func([]string) {}
|
||||
cfg.WgIface = interfaceName
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, c.recorder)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.setState(cfg, connectClient)
|
||||
// Persist the latest sync response so DebugBundle can include the network
|
||||
// map. On iOS this is backed by disk to keep it out of the constrained
|
||||
// process memory (see the syncstore package).
|
||||
connectClient.SetSyncResponsePersistence(true)
|
||||
return connectClient.RunOniOS(cfg, fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
|
||||
return connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
|
||||
@@ -36,6 +36,7 @@ type URLOpener interface {
|
||||
// Auth can register or login new client
|
||||
type Auth struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
config *profilemanager.Config
|
||||
cfgPath string
|
||||
}
|
||||
@@ -51,8 +52,19 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use a cancellable context so Stop() can abort an in-progress interactive
|
||||
// login. The PKCE flow's WaitToken blocks (and keeps its loopback HTTP server
|
||||
// bound to a port) until the OAuth callback arrives or the flow expires;
|
||||
// cancelling the context unblocks WaitToken, which then shuts that server down
|
||||
// and frees the port for the next login attempt. iOS runs login in the main-app
|
||||
// process (decoupled from the network extension), so without this the server
|
||||
// lingers after the user dismisses the browser and the next connect stalls
|
||||
// trying to bind the same port.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &Auth{
|
||||
ctx: context.Background(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
config: cfg,
|
||||
cfgPath: cfgPath,
|
||||
}, nil
|
||||
@@ -60,12 +72,24 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||
|
||||
// NewAuthWithConfig instantiate Auth based on existing config
|
||||
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Auth{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// Stop aborts an in-progress interactive login started via Login/LoginWithDeviceName.
|
||||
// It cancels the auth context, which unblocks the PKCE WaitToken and shuts down its
|
||||
// loopback HTTP server, freeing the redirect port. Safe to call multiple times and
|
||||
// safe to call when no login is running.
|
||||
func (a *Auth) Stop() {
|
||||
if a.cancel != nil {
|
||||
a.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
||||
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
||||
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||
|
||||
@@ -344,6 +344,9 @@ func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Eng
|
||||
}
|
||||
|
||||
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {
|
||||
if s.connectClient == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "client not connected")
|
||||
}
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "engine not initialized")
|
||||
|
||||
@@ -5,6 +5,7 @@ package server
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime/pprof"
|
||||
|
||||
@@ -27,9 +28,11 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
}
|
||||
|
||||
var clientMetrics debug.MetricsExporter
|
||||
if engine := s.connectClient.Engine(); engine != nil {
|
||||
if cm := engine.GetClientMetrics(); cm != nil {
|
||||
clientMetrics = cm
|
||||
if s.connectClient != nil {
|
||||
if engine := s.connectClient.Engine(); engine != nil {
|
||||
if cm := engine.GetClientMetrics(); cm != nil {
|
||||
clientMetrics = cm
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,10 +48,13 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
defer s.cleanupBundleCapture()
|
||||
|
||||
var refreshStatus func()
|
||||
if engine := s.connectClient.Engine(); engine != nil {
|
||||
refreshStatus = func() {
|
||||
log.Debug("refreshing system health status for debug bundle")
|
||||
engine.RunHealthProbes(true)
|
||||
if s.connectClient != nil {
|
||||
engine := s.connectClient.Engine()
|
||||
if engine != nil {
|
||||
refreshStatus = func() {
|
||||
log.Debug("refreshing system health status for debug bundle")
|
||||
engine.RunHealthProbes(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,7 +118,9 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
|
||||
|
||||
log.SetLevel(level)
|
||||
|
||||
s.connectClient.SetLogLevel(level)
|
||||
if s.connectClient != nil {
|
||||
s.connectClient.SetLogLevel(level)
|
||||
}
|
||||
|
||||
log.Infof("Log level set to %s", level.String())
|
||||
|
||||
@@ -126,13 +134,20 @@ func (s *Server) SetSyncResponsePersistence(_ context.Context, req *proto.SetSyn
|
||||
|
||||
enabled := req.GetEnabled()
|
||||
s.persistSyncResponse = enabled
|
||||
s.connectClient.SetSyncResponsePersistence(enabled)
|
||||
if s.connectClient != nil {
|
||||
s.connectClient.SetSyncResponsePersistence(enabled)
|
||||
}
|
||||
|
||||
return &proto.SetSyncResponsePersistenceResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) getLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
return s.connectClient.GetLatestSyncResponse()
|
||||
cClient := s.connectClient
|
||||
if cClient == nil {
|
||||
return nil, errors.New("connect client is not initialized")
|
||||
}
|
||||
|
||||
return cClient.GetLatestSyncResponse()
|
||||
}
|
||||
|
||||
// StartCPUProfile starts CPU profiling in the daemon.
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -38,11 +39,12 @@ type conflictCheck struct {
|
||||
// OS-native managed-config store reports a diff vs the last observation.
|
||||
//
|
||||
// Restart sequence:
|
||||
// 1. Stop the in-flight run via the supervisor (blocks until fully torn down).
|
||||
// 2. Re-resolve Config from disk + MDM policy (Config.apply re-runs
|
||||
// 1. Cancel the active engine context (terminates connectWithRetryRuns).
|
||||
// 2. Wait briefly for that goroutine to exit (giveUpChan is closed on exit).
|
||||
// 3. Re-resolve Config from disk + MDM policy (Config.apply re-runs
|
||||
// applyMDMPolicy with the freshly loaded Policy).
|
||||
// 3. Start a fresh run with the new config.
|
||||
// 4. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
|
||||
// 4. Spawn a fresh connectWithRetryRuns with the new context and config.
|
||||
// 5. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
|
||||
// RPC) can refresh its cached config view without polling.
|
||||
//
|
||||
// The callback runs in the ticker's own goroutine. Ticker has already
|
||||
@@ -50,24 +52,39 @@ type conflictCheck struct {
|
||||
func (s *Server) onMDMPolicyChange(_, _ *mdm.Policy) error {
|
||||
log.Warn("MDM policy changed; restarting engine to apply new configuration")
|
||||
|
||||
// Hold s.mutex for the entire restart sequence (stop + re-start). Any
|
||||
// concurrent Up/Down/Status arriving while MDM is restarting blocks on the
|
||||
// Lock until we are done — they then observe the post-restart state coherently.
|
||||
// Hold s.mutex for the entire restart sequence (cancel + quiescence
|
||||
// wait + re-spawn). Any concurrent Up/Down/Status arriving while
|
||||
// MDM is restarting blocks on the Lock until we are done — they
|
||||
// then observe the post-restart state coherently. This is safe
|
||||
// because the connectWithRetryRuns goroutine no longer acquires
|
||||
// s.mutex in its defer (intent vs. goroutine-alive concerns are
|
||||
// fully separated; see the connectionGoroutineRunning helper).
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if !s.connectClient.ConnectionRunning() {
|
||||
// No run in flight, so there's no engine to restart.
|
||||
if !s.clientRunning {
|
||||
// The client is not running, so there's no engine to restart.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel daemon-side login/status activities tied to the old run; the run
|
||||
// itself is torn down atomically by the supervisor inside Restart (see
|
||||
// restartEngineForMDMLocked), which stops and re-starts in one operation.
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
}
|
||||
|
||||
// Wait for previous connectWithRetryRuns to exit so we don't end up
|
||||
// with two goroutines fighting over the same status recorder + engine.
|
||||
// The teardown engages a fan-out of engine goroutines (peer workers,
|
||||
// signal handler, route manager, ...). close(clientGiveUpChan)
|
||||
// happens in the function-scope defer of connectWithRetryRuns, on
|
||||
// every exit path (ctx cancel, backoff exhausted, panic) — see the
|
||||
// defer in server.go.
|
||||
if s.clientGiveUpChan != nil {
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
case <-time.After(10 * time.Second):
|
||||
return fmt.Errorf("failed to restart the engine due to timeout")
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.restartEngineForMDMLocked(); err != nil {
|
||||
log.Errorf("MDM restart failed: %v", err)
|
||||
return err
|
||||
@@ -114,13 +131,14 @@ func (s *Server) publishConfigChangedEvent(source string) {
|
||||
}
|
||||
|
||||
// restartEngineForMDMLocked re-resolves the active profile config
|
||||
// (re-running applyMDMPolicy via Config.apply) and starts a fresh run.
|
||||
// Mirrors the tail of Server.Start so a runtime MDM change behaves
|
||||
// identically to a fresh boot under the new policy.
|
||||
// (re-running applyMDMPolicy via Config.apply) and re-spawns
|
||||
// connectWithRetryRuns. Mirrors the tail of Server.Start so a runtime
|
||||
// MDM change behaves identically to a fresh boot under the new policy.
|
||||
//
|
||||
// MUST be called with s.mutex held — onMDMPolicyChange holds the lock
|
||||
// for the entire restart sequence so concurrent Up/Down/Status RPCs
|
||||
// observe a coherent post-restart state.
|
||||
// for the entire restart sequence (cancel + quiescence wait + re-spawn)
|
||||
// so concurrent Up/Down/Status RPCs observe a coherent post-restart
|
||||
// state.
|
||||
func (s *Server) restartEngineForMDMLocked() error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
@@ -136,13 +154,13 @@ func (s *Server) restartEngineForMDMLocked() error {
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||
|
||||
_, cancel := context.WithCancel(s.rootCtx)
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
log.Info("MDM restart: atomically restarting the run with re-resolved config")
|
||||
// MDM restart has no incoming RPC metadata; fire and forget. Restart is a
|
||||
// single supervisor op (atomic stop+start), so there is no observable
|
||||
// "stopped" window between tearing down the old run and starting the new.
|
||||
s.connectClient.Restart(config, nil)
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
log.Info("MDM restart: spawning connectWithRetryRuns with re-resolved config")
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("mdm")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -34,6 +34,10 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
@@ -143,6 +147,10 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
@@ -191,6 +199,10 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
|
||||
@@ -8,10 +8,12 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -37,7 +39,15 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
probeThreshold = time.Second * 5
|
||||
probeThreshold = time.Second * 5
|
||||
retryInitialIntervalVar = "NB_CONN_RETRY_INTERVAL_TIME"
|
||||
maxRetryIntervalVar = "NB_CONN_MAX_RETRY_INTERVAL_TIME"
|
||||
maxRetryTimeVar = "NB_CONN_MAX_RETRY_TIME_TIME"
|
||||
retryMultiplierVar = "NB_CONN_RETRY_MULTIPLIER"
|
||||
defaultInitialRetryTime = 30 * time.Minute
|
||||
defaultMaxRetryInterval = 60 * time.Minute
|
||||
defaultMaxRetryTime = 14 * 24 * time.Hour
|
||||
defaultRetryMultiplier = 1.7
|
||||
|
||||
// JWT token cache TTL for the client daemon (disabled by default)
|
||||
defaultJWTCacheTTL = 0
|
||||
@@ -62,8 +72,15 @@ type Server struct {
|
||||
mutex sync.Mutex
|
||||
config *profilemanager.Config
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
// Run state (in-flight? established/done channels?) is owned entirely by the
|
||||
// supervisor inside connectClient — the daemon keeps no per-run fields.
|
||||
// clientRunning tracks "the daemon wants to be connected" — set true by
|
||||
// Start / Up, cleared by Down / Logout. Persists across retry
|
||||
// loops, signal disconnects, and ErrResetConnection cycles. NOT
|
||||
// changed by connectWithRetryRuns goroutine exit — for that
|
||||
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
|
||||
// derives from clientGiveUpChan close state. Protected by s.mutex.
|
||||
clientRunning bool
|
||||
clientRunningChan chan struct{}
|
||||
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
|
||||
|
||||
connectClient *internal.ConnectClient
|
||||
|
||||
@@ -119,13 +136,6 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
networksDisabled: networksDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
// The ConnectClient is daemon-lifetime: build it exactly once, here. Its
|
||||
// supervisor lives as long as the daemon; Up/Down/MDM and reconnects all
|
||||
// drive this same instance. updateManager isn't ready yet (created in
|
||||
// Start) and is injected there via SetUpdateManager.
|
||||
s.connectClient = internal.NewConnectClient(ctx, s.statusRecorder)
|
||||
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
|
||||
|
||||
agent := &serverAgent{s}
|
||||
s.sleepHandler = sleephandler.New(agent)
|
||||
s.startSleepDetector()
|
||||
@@ -137,7 +147,7 @@ func (s *Server) Start() error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.connectClient.ConnectionRunning() {
|
||||
if s.clientRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -155,7 +165,6 @@ func (s *Server) Start() error {
|
||||
stateMgr := statemanager.New(s.profileManager.GetStatePath())
|
||||
s.updateManager = updater.NewManager(s.statusRecorder, stateMgr)
|
||||
s.updateManager.CheckUpdateSuccess(s.rootCtx)
|
||||
s.connectClient.SetUpdateManager(s.updateManager)
|
||||
}
|
||||
|
||||
// MDM policy reload ticker: every minute the desktop daemon re-reads
|
||||
@@ -181,9 +190,7 @@ func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// actCancel cancels in-flight foreground operations (login/status); the run
|
||||
// itself is owned by the supervisor and stopped via Stop, not this cancel.
|
||||
_, cancel := context.WithCancel(s.rootCtx)
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
|
||||
// copy old default config
|
||||
@@ -225,14 +232,99 @@ func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Boot autoconnect: no incoming RPC metadata. The supervisor runs the
|
||||
// client and reconnects internally; we just fire and forget (the run owns
|
||||
// its established/done channels).
|
||||
s.connectClient.RunAsync(config, nil)
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("startup")
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
|
||||
// mechanism to keep the client connected even when the connection is lost.
|
||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||
//
|
||||
// The goroutine's exit is signalled to the daemon via close(giveUpChan)
|
||||
// — placed in the function-scope defer so every return path (panic,
|
||||
// DisableAutoConnect early-exit, backoff exhausted, ctx cancel) closes
|
||||
// it. Callers that need to observe "is the goroutine still alive?" use
|
||||
// Server.connectionGoroutineRunning() which non-blockingly checks the close state
|
||||
// of clientGiveUpChan. The defer does NOT touch s.mutex; the daemon's
|
||||
// "intent" (clientRunning) is maintained by the RPC handlers, not by this
|
||||
// goroutine.
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
|
||||
defer func() {
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
}
|
||||
}()
|
||||
|
||||
if s.config.DisableAutoConnect {
|
||||
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
|
||||
log.Debugf("run client connection exited with error: %v", err)
|
||||
}
|
||||
log.Tracef("client connection exited")
|
||||
return
|
||||
}
|
||||
|
||||
backOff := getConnectWithBackoff(ctx)
|
||||
go func() {
|
||||
t := time.NewTicker(24 * time.Hour)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
return
|
||||
case <-t.C:
|
||||
mgmtState := statusRecorder.GetManagementState()
|
||||
signalState := statusRecorder.GetSignalState()
|
||||
if mgmtState.Connected && signalState.Connected {
|
||||
log.Tracef("resetting status")
|
||||
backOff.Reset()
|
||||
} else {
|
||||
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
runOperation := func() error {
|
||||
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
|
||||
if err != nil {
|
||||
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Tracef("client connection exited gracefully, do not need to retry")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := backoff.Retry(runOperation, backOff); err != nil {
|
||||
log.Errorf("operation failed: %v", err)
|
||||
}
|
||||
// giveUpChan is closed by the function-scope defer.
|
||||
}
|
||||
|
||||
// connectionGoroutineRunning reports whether the connectWithRetryRuns goroutine is
|
||||
// still running. Returns false when no goroutine has ever been started
|
||||
// AND when the most recent one has already closed clientGiveUpChan on
|
||||
// exit (whether due to ctx cancel, DisableAutoConnect single-shot
|
||||
// completion, or backoff retry exhaustion).
|
||||
//
|
||||
// MUST be called with s.mutex held — accesses s.clientGiveUpChan which
|
||||
// is written by Start/Up under the same lock.
|
||||
func (s *Server) connectionGoroutineRunning() bool {
|
||||
if s.clientGiveUpChan == nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
||||
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
|
||||
@@ -628,22 +720,13 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
// Up starts engine work in the daemon.
|
||||
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
|
||||
s.mutex.Lock()
|
||||
|
||||
// The client (and its supervisor) is built once in New(), so a nil here
|
||||
// never happens in production — Up is only reachable after New() has run and
|
||||
// the gRPC server is serving. The real case this guards is the daemon
|
||||
// SHUTTING DOWN: rootCtx is cancelled, the supervisor is no longer accepting
|
||||
// commands, so ServiceRunning() is false even though the client exists. Bail
|
||||
// loud instead of enqueuing a run that will never start. (nil only happens in
|
||||
// tests that build a Server without New(); ServiceRunning is nil-safe.)
|
||||
if !s.connectClient.ServiceRunning() {
|
||||
s.mutex.Unlock()
|
||||
return nil, fmt.Errorf("service is not running, start the netbird service for 'up' to take effect")
|
||||
}
|
||||
|
||||
// If a connection run is already in flight, the existing engine is on the
|
||||
// job — just wait for it. Otherwise fall through to start a fresh run.
|
||||
if s.connectClient.ConnectionRunning() {
|
||||
// clientRunning is the daemon-intent flag (set by previous Up/Start, cleared
|
||||
// by Down). connectionGoroutineRunning() reports whether the previous retry-loop
|
||||
// goroutine is still trying. When intent is up AND goroutine is alive,
|
||||
// the existing engine is on the job — just wait for it. When intent
|
||||
// is up but the goroutine has given up (backoff exhausted) OR when
|
||||
// intent is down, fall through to spawn a fresh retry loop.
|
||||
if s.clientRunning && s.connectionGoroutineRunning() {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
@@ -681,13 +764,13 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
}
|
||||
// actCancel cancels in-flight foreground ops (login/status); the run is
|
||||
// owned by the supervisor and stopped via Stop, not this cancel.
|
||||
_, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
md, ok := metadata.FromIncomingContext(callerCtx)
|
||||
if ok {
|
||||
ctx = metadata.NewOutgoingContext(ctx, md)
|
||||
}
|
||||
|
||||
// Forward the caller's gRPC metadata (e.g. UI user-agent) into the run.
|
||||
md, _ := metadata.FromIncomingContext(callerCtx)
|
||||
s.actCancel = cancel
|
||||
|
||||
if s.config == nil {
|
||||
s.mutex.Unlock()
|
||||
@@ -729,26 +812,35 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
s.connectClient.RunAsync(s.config, md)
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("up_rpc")
|
||||
|
||||
s.mutex.Unlock()
|
||||
return s.waitForUp(callerCtx)
|
||||
}
|
||||
|
||||
// waitForUp blocks until the in-flight run becomes established (success) or ends
|
||||
// before that (failure). The wait is owned by the supervisor (via the client) —
|
||||
// the daemon holds no per-run state here.
|
||||
// todo: handle potential race conditions
|
||||
func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) {
|
||||
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.connectClient.WaitEstablishedOrDone(timeoutCtx); err != nil {
|
||||
log.Debugf("waiting for the connection to be established failed: %v", err)
|
||||
return nil, fmt.Errorf("connection not established: %w", err)
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
return nil, fmt.Errorf("client gave up to connect")
|
||||
case <-s.clientRunningChan:
|
||||
s.isSessionActive.Store(true)
|
||||
return &proto.UpResponse{}, nil
|
||||
case <-callerCtx.Done():
|
||||
log.Debug("context done, stopping the wait for engine to become ready")
|
||||
return nil, callerCtx.Err()
|
||||
case <-timeoutCtx.Done():
|
||||
log.Debug("up is timed out, stopping the wait for engine to become ready")
|
||||
return nil, timeoutCtx.Err()
|
||||
}
|
||||
s.isSessionActive.Store(true)
|
||||
return &proto.UpResponse{}, nil
|
||||
}
|
||||
|
||||
// resolveProfileHandle resolves a wire-level profile handle (display
|
||||
@@ -843,11 +935,11 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
|
||||
// Down engine work in the daemon.
|
||||
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
// cleanupConnection stops the run through the supervisor, which blocks until
|
||||
// the run has fully unwound — no separate goroutine-quiescence wait needed.
|
||||
giveUpChan := s.clientGiveUpChan
|
||||
|
||||
if err := s.cleanupConnection(); err != nil {
|
||||
s.mutex.Unlock()
|
||||
// todo review to update the status in case any type of error
|
||||
log.Errorf("failed to shut down properly: %v", err)
|
||||
return nil, err
|
||||
@@ -856,6 +948,20 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusIdle)
|
||||
|
||||
s.mutex.Unlock()
|
||||
|
||||
// Wait for the connectWithRetryRuns goroutine to finish with a short timeout.
|
||||
// This prevents the goroutine from setting ErrResetConnection after Down() returns.
|
||||
// The giveUpChan is closed at the end of connectWithRetryRuns.
|
||||
if giveUpChan != nil {
|
||||
select {
|
||||
case <-giveUpChan:
|
||||
log.Debugf("client goroutine finished successfully")
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Warnf("timeout waiting for client goroutine to finish, proceeding anyway")
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.DownResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -866,19 +972,38 @@ func (s *Server) cleanupConnection() error {
|
||||
return ErrServiceNotUp
|
||||
}
|
||||
|
||||
// Tear the client down through the lifecycle supervisor BEFORE cancelling
|
||||
// the retry context. Stop serializes on the supervisor queue and blocks
|
||||
// until the in-flight run has fully unwound (a clean, synchronous teardown).
|
||||
// It must run before actCancel: cancelling the context first would make
|
||||
// Stop observe a dead context and return early without waiting.
|
||||
if err := s.connectClient.Stop(); err != nil {
|
||||
return err
|
||||
// Daemon intent flips to "down" — all callers (Down RPC,
|
||||
// Logout RPC handlers) tear down the connection because the user
|
||||
// explicitly asked for it. MDM restart does NOT go through this
|
||||
// path, so its clientRunning stays true.
|
||||
s.clientRunning = false
|
||||
|
||||
// Capture the engine reference before cancelling the context.
|
||||
// After actCancel(), the connectWithRetryRuns goroutine wakes up
|
||||
// and sets connectClient.engine = nil, causing connectClient.Stop()
|
||||
// to skip the engine shutdown entirely.
|
||||
var engine *internal.Engine
|
||||
if s.connectClient != nil {
|
||||
engine = s.connectClient.Engine()
|
||||
}
|
||||
|
||||
// Stop the retry goroutine so it does not start a fresh run. The client
|
||||
// itself is daemon-lifetime and intentionally kept (a later Up reuses it).
|
||||
s.actCancel()
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: consider calling s.connectClient.Stop() instead of engine.Stop().
|
||||
// actCancel() lets the run loop stop the engine too, so both stop it
|
||||
// concurrently; ConnectClient.Stop cancels and waits for the run loop,
|
||||
// making the run loop the sole owner of engine shutdown.
|
||||
if engine != nil {
|
||||
if err := engine.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.connectClient = nil
|
||||
s.isSessionActive.Store(false)
|
||||
|
||||
log.Infof("service is down")
|
||||
@@ -1013,7 +1138,7 @@ func (s *Server) validateProfileOperation(id profilemanager.ID, allowActiveProfi
|
||||
|
||||
func (s *Server) logoutFromProfile(ctx context.Context, profile *profilemanager.Profile) error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err == nil && activeProf.ID == profile.ID && s.connectClient.ConnectionRunning() {
|
||||
if err == nil && activeProf.ID == profile.ID && s.connectClient != nil {
|
||||
return s.sendLogoutRequest(ctx)
|
||||
}
|
||||
|
||||
@@ -1059,13 +1184,48 @@ func (s *Server) Status(
|
||||
ctx context.Context,
|
||||
msg *proto.StatusRequest,
|
||||
) (*proto.StatusResponse, error) {
|
||||
// A run that hits a terminal auth failure now exits on its own (engine marks
|
||||
// NeedsLogin), so we no longer poll-and-cancel: we wait for the in-flight run
|
||||
// to become established or to end. With no run in flight this returns
|
||||
// immediately (errNoRunInFlight); either way we then report the status below.
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady {
|
||||
if err := s.connectClient.WaitEstablishedOrDone(ctx); err != nil && ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
s.mutex.Lock()
|
||||
// Only wait if the retry-loop goroutine is alive and making
|
||||
// progress. clientRunning=true with connectionGoroutineRunning=false means the
|
||||
// backoff has given up — there is nothing to wait for; let the
|
||||
// caller observe the failed status directly.
|
||||
alive := s.connectionGoroutineRunning()
|
||||
s.mutex.Unlock()
|
||||
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady && alive {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
|
||||
s.actCancel()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
ticker.Stop()
|
||||
break loop
|
||||
case <-s.clientRunningChan:
|
||||
ticker.Stop()
|
||||
break loop
|
||||
case <-ticker.C:
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
|
||||
s.actCancel()
|
||||
}
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1103,6 +1263,10 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil
|
||||
@@ -1140,6 +1304,10 @@ func (s *Server) GetPeerSSHHostKey(
|
||||
statusRecorder := s.statusRecorder
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil, errors.New("client not initialized")
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
@@ -1306,13 +1474,17 @@ func (s *Server) WaitJWTToken(
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy.
|
||||
func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.DaemonService_ExposeServiceServer) error {
|
||||
s.mutex.Lock()
|
||||
if !s.connectClient.ConnectionRunning() {
|
||||
if !s.clientRunning {
|
||||
s.mutex.Unlock()
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "client is not running, run 'netbird up' first")
|
||||
}
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
|
||||
@@ -1366,6 +1538,10 @@ func isUnixRunningDesktop() bool {
|
||||
}
|
||||
|
||||
func (s *Server) runProbes(waitForProbeResult bool) {
|
||||
if s.connectClient == nil {
|
||||
return
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return
|
||||
@@ -1644,6 +1820,22 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
|
||||
return features, nil
|
||||
}
|
||||
|
||||
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
|
||||
log.Tracef("running client connection")
|
||||
client := internal.NewConnectClient(ctx, config, statusRecorder)
|
||||
client.SetUpdateManager(s.updateManager)
|
||||
client.SetSyncResponsePersistence(s.persistSyncResponse)
|
||||
|
||||
s.mutex.Lock()
|
||||
s.connectClient = client
|
||||
s.mutex.Unlock()
|
||||
|
||||
if err := client.Run(runningChan, s.logFile); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MDM authority: when the platform-native MDM source sets a kill switch
|
||||
// key (regardless of true/false value), that value wins. The CLI flag
|
||||
// supplied at service install time is the fallback used only when the
|
||||
@@ -1705,6 +1897,45 @@ func (s *Server) onSessionExpire() {
|
||||
}
|
||||
}
|
||||
|
||||
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
|
||||
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
|
||||
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
|
||||
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
|
||||
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
|
||||
multiplier := defaultRetryMultiplier
|
||||
|
||||
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
|
||||
// parse the multiplier from the environment variable string value to float64
|
||||
value, err := strconv.ParseFloat(envValue, 64)
|
||||
if err != nil {
|
||||
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
|
||||
} else {
|
||||
multiplier = value
|
||||
}
|
||||
}
|
||||
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: initialInterval,
|
||||
RandomizationFactor: 1,
|
||||
Multiplier: multiplier,
|
||||
MaxInterval: maxInterval,
|
||||
MaxElapsedTime: maxElapsedTime, // 14 days
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
// parseEnvDuration parses the environment variable and returns the duration
|
||||
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
|
||||
if envValue := os.Getenv(envVar); envValue != "" {
|
||||
if duration, err := time.ParseDuration(envValue); err == nil {
|
||||
return duration
|
||||
}
|
||||
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
|
||||
}
|
||||
return defaultDuration
|
||||
}
|
||||
|
||||
// sendTerminalNotification sends a terminal notification message
|
||||
// to inform the user that the NetBird connection session has expired.
|
||||
func sendTerminalNotification() error {
|
||||
|
||||
@@ -15,19 +15,14 @@ import (
|
||||
)
|
||||
|
||||
func newTestServer() *Server {
|
||||
ctx := context.Background()
|
||||
s := &Server{
|
||||
rootCtx: ctx,
|
||||
return &Server{
|
||||
rootCtx: context.Background(),
|
||||
statusRecorder: peer.NewRecorder(""),
|
||||
}
|
||||
// Honor the production invariant: the daemon-lifetime client always exists
|
||||
// (built in New). Server methods rely on s.connectClient being non-nil.
|
||||
s.connectClient = internal.NewConnectClient(ctx, s.statusRecorder)
|
||||
return s
|
||||
}
|
||||
|
||||
func newDummyConnectClient(ctx context.Context) *internal.ConnectClient {
|
||||
return internal.NewConnectClient(ctx, nil)
|
||||
return internal.NewConnectClient(ctx, nil, nil)
|
||||
}
|
||||
|
||||
// TestConnectSetsClientWithMutex validates that connect() sets s.connectClient
|
||||
@@ -92,36 +87,41 @@ func TestConcurrentConnectClientAccess(t *testing.T) {
|
||||
assert.Equal(t, 50, nilCount+setCount, "all goroutines should complete without panic")
|
||||
}
|
||||
|
||||
// TestCleanupConnection_KeepsClientStopsRunning validates that cleanupConnection
|
||||
// clears the daemon "up" intent but KEEPS the daemon-lifetime ConnectClient
|
||||
// (it is reused across Up/Down; only the run is stopped).
|
||||
func TestCleanupConnection_KeepsClientStopsRunning(t *testing.T) {
|
||||
// TestCleanupConnection_ClearsConnectClient validates that cleanupConnection
|
||||
// properly nils out connectClient.
|
||||
func TestCleanupConnection_ClearsConnectClient(t *testing.T) {
|
||||
s := newTestServer()
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
s.actCancel = cancel
|
||||
|
||||
s.connectClient = newDummyConnectClient(context.Background())
|
||||
s.clientRunning = true
|
||||
|
||||
err := s.cleanupConnection()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, s.connectClient, "connectClient is daemon-lifetime and must persist after cleanup")
|
||||
assert.False(t, s.connectClient.ConnectionRunning(), "no run should be in flight after cleanup")
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared after cleanup (intent = down)")
|
||||
}
|
||||
|
||||
// TestCleanState_NotConnected validates that CleanState doesn't panic when no
|
||||
// connection run is in flight.
|
||||
func TestCleanState_NotConnected(t *testing.T) {
|
||||
// TestCleanState_NilConnectClient validates that CleanState doesn't panic
|
||||
// when connectClient is nil.
|
||||
func TestCleanState_NilConnectClient(t *testing.T) {
|
||||
s := newTestServer()
|
||||
s.profileManager = nil // will cause error if it tries to proceed
|
||||
s.connectClient = nil
|
||||
s.profileManager = nil // will cause error if it tries to proceed past the nil check
|
||||
|
||||
// Should not panic — the nil check should prevent calling Status() on nil
|
||||
assert.NotPanics(t, func() {
|
||||
_, _ = s.CleanState(context.Background(), &proto.CleanStateRequest{All: true})
|
||||
})
|
||||
}
|
||||
|
||||
// TestDeleteState_NotConnected validates that DeleteState doesn't panic when no
|
||||
// connection run is in flight.
|
||||
func TestDeleteState_NotConnected(t *testing.T) {
|
||||
// TestDeleteState_NilConnectClient validates that DeleteState doesn't panic
|
||||
// when connectClient is nil.
|
||||
func TestDeleteState_NilConnectClient(t *testing.T) {
|
||||
s := newTestServer()
|
||||
s.connectClient = nil
|
||||
s.profileManager = nil
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
@@ -129,6 +129,60 @@ func TestDeleteState_NotConnected(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestDownThenUp_StaleRunningChan documents the known state issue where
|
||||
// clientRunningChan from a previous connection is already closed, causing
|
||||
// waitForUp() to return immediately on reconnect.
|
||||
func TestDownThenUp_StaleRunningChan(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
// Simulate state after a successful connection
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
close(s.clientRunningChan) // closed when engine started
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
s.connectClient = newDummyConnectClient(context.Background())
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
s.actCancel = cancel
|
||||
|
||||
// Simulate Down(): cleanupConnection sets connectClient = nil and
|
||||
// flips clientRunning to false (intent = down). The connectionGoroutineRunning state
|
||||
// remains independent of intent — derived from clientGiveUpChan.
|
||||
s.mutex.Lock()
|
||||
err := s.cleanupConnection()
|
||||
s.mutex.Unlock()
|
||||
require.NoError(t, err)
|
||||
|
||||
// After cleanup: connectClient is nil, clientRunning is false (intent
|
||||
// cleared by cleanupConnection), connectionGoroutineRunning may still be true
|
||||
// (goroutine teardown is independent of the intent flag).
|
||||
s.mutex.Lock()
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared by cleanupConnection (intent = down)")
|
||||
s.mutex.Unlock()
|
||||
|
||||
// waitForUp() returns immediately due to stale closed clientRunningChan
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer ctxCancel()
|
||||
|
||||
waitDone := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := s.waitForUp(ctx)
|
||||
waitDone <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-waitDone:
|
||||
assert.NoError(t, err, "waitForUp returns success on stale channel")
|
||||
// But connectClient is still nil — this is the stale state issue
|
||||
s.mutex.Lock()
|
||||
assert.Nil(t, s.connectClient, "connectClient is nil despite waitForUp success")
|
||||
s.mutex.Unlock()
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("waitForUp should have returned immediately due to stale closed channel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectClient_EngineNilOnFreshClient validates that a newly created
|
||||
// ConnectClient has nil Engine (before Run is called).
|
||||
func TestConnectClient_EngineNilOnFreshClient(t *testing.T) {
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@@ -60,6 +61,65 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
|
||||
// we will use a management server started via to simulate the server and capture the number of retries
|
||||
func TestConnectWithRetryRuns(t *testing.T) {
|
||||
// start the signal server
|
||||
_, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start signal server: %v", err)
|
||||
}
|
||||
|
||||
counter := 0
|
||||
// start the management server
|
||||
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start management server: %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
|
||||
defer cancel()
|
||||
// create new server
|
||||
ic := profilemanager.ConfigInput{
|
||||
ManagementURL: "http://" + mgmtAddr,
|
||||
ConfigPath: t.TempDir() + "/test-profile.json",
|
||||
}
|
||||
|
||||
config, err := profilemanager.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create config: %v", err)
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: "test-profile",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "debug", "", false, false, false, false)
|
||||
|
||||
s.config = config
|
||||
|
||||
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||
t.Setenv(retryInitialIntervalVar, "1s")
|
||||
t.Setenv(maxRetryIntervalVar, "2s")
|
||||
t.Setenv(maxRetryTimeVar, "5s")
|
||||
t.Setenv(retryMultiplierVar, "1")
|
||||
|
||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
||||
if counter < 3 {
|
||||
t.Fatalf("expected counter > 2, got %d", counter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Up(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -37,7 +38,7 @@ func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*pro
|
||||
|
||||
// CleanState handles cleaning of states (performing cleanup operations)
|
||||
func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) {
|
||||
if s.connectClient.ConnectionRunning() {
|
||||
if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
|
||||
}
|
||||
|
||||
@@ -80,7 +81,7 @@ func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (
|
||||
|
||||
// DeleteState handles deletion of states without cleanup
|
||||
func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) {
|
||||
if s.connectClient.ConnectionRunning() {
|
||||
if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
|
||||
}
|
||||
|
||||
|
||||
@@ -62,6 +62,10 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
}
|
||||
|
||||
func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) {
|
||||
if s.connectClient == nil {
|
||||
return nil, nil, fmt.Errorf("connect client not initialized")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, nil, fmt.Errorf("engine not initialized")
|
||||
|
||||
@@ -418,7 +418,14 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
||||
case args.showProfiles:
|
||||
s.showProfilesUI()
|
||||
case args.showQuickActions:
|
||||
s.showQuickActionsUI()
|
||||
// Suppress the on-boot Quick Actions popup when the daemon
|
||||
// reports DisableAutoConnect=true — that flag carries both the
|
||||
// user's "Connect on Startup = off" preference AND any MDM-
|
||||
// enforced override (applyMDMPolicy writes the policy value
|
||||
// into the same Config field). See netbirdio/netbird#5744.
|
||||
if !s.disableAutoConnectFromDaemon() {
|
||||
s.showQuickActionsUI()
|
||||
}
|
||||
case args.showUpdate:
|
||||
s.showUpdateProgress(ctx, args.showUpdateVersion)
|
||||
}
|
||||
@@ -1338,6 +1345,40 @@ func (s *serviceClient) getFeatures() (*proto.GetFeaturesResponse, error) {
|
||||
return features, nil
|
||||
}
|
||||
|
||||
// disableAutoConnectFromDaemon returns true when the daemon reports
|
||||
// the active profile has DisableAutoConnect=true. Used by the
|
||||
// --quick-actions startup path to suppress the on-boot popup when the
|
||||
// user (or an MDM admin) opted out of auto-connecting; both cases
|
||||
// converge on the same Config field because applyMDMPolicy writes the
|
||||
// policy value into it. Returns false on any RPC / lookup failure so a
|
||||
// daemon hiccup does not silently swallow the popup.
|
||||
func (s *serviceClient) disableAutoConnectFromDaemon() bool {
|
||||
activeProf, err := s.profileManager.GetActiveProfile()
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get active profile: %v", err)
|
||||
return false
|
||||
}
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get current user: %v", err)
|
||||
return false
|
||||
}
|
||||
conn, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get daemon client: %v", err)
|
||||
return false
|
||||
}
|
||||
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: GetConfig RPC: %v", err)
|
||||
return false
|
||||
}
|
||||
return srvCfg.GetDisableAutoConnect()
|
||||
}
|
||||
|
||||
// getSrvConfig from the service to show it in the settings window.
|
||||
func (s *serviceClient) getSrvConfig() {
|
||||
s.managementURL = profilemanager.DefaultManagementURL
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
# Build environments
|
||||
|
||||
Dockerfiles that pin the same toolchain CI uses, so a developer can
|
||||
reproduce a CI build locally without installing platform SDKs on their
|
||||
workstation. The version pins in each `Dockerfile` must stay in lockstep
|
||||
with `.github/workflows/`.
|
||||
|
||||
## `android/`
|
||||
|
||||
Mirrors `.github/workflows/mobile-build-validation.yml` (`android_build`
|
||||
job). Carries Go 1.25.5, Adopt JDK 11, Android cmdline-tools 8512546,
|
||||
NDK 23.1.7779620 and gomobile pinned at the CI commit. Use it to
|
||||
produce `netbird.aar` from `./client/android`:
|
||||
|
||||
```bash
|
||||
docker build -t netbird/build-android docker/build-env/android
|
||||
docker run --rm -v "$PWD:/src" -w /src netbird/build-android \
|
||||
gomobile bind \
|
||||
-o netbird.aar \
|
||||
-javapkg=io.netbird.gomobile \
|
||||
-ldflags="-checklinkname=0 \
|
||||
-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard \
|
||||
-X github.com/netbirdio/netbird/version.version=local" \
|
||||
./client/android
|
||||
```
|
||||
|
||||
To build the full Android APK, bind-mount the `android-client` repo as
|
||||
well and run its own `./gradlew assembleDebug` from inside the
|
||||
container (the gradle wrapper ships with `android-client`).
|
||||
|
||||
## `windows-cross/`
|
||||
|
||||
Cross-compiles Windows binaries from Linux using `mingw-w64`. Lets you
|
||||
verify that `GOOS=windows go build ./...` compiles cleanly without
|
||||
needing a Windows VM. Cannot run Windows tests — the `golang-test-windows`
|
||||
CI job executes on a native `windows-latest` runner with wintun.dll
|
||||
and PsExec, neither of which lives under Linux containers.
|
||||
|
||||
```bash
|
||||
docker build -t netbird/build-windows docker/build-env/windows-cross
|
||||
docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
|
||||
bash -c 'GOOS=windows GOARCH=amd64 go build ./...'
|
||||
```
|
||||
|
||||
## What is NOT here
|
||||
|
||||
- **iOS / macOS**: cannot legally run macOS in Docker (Apple EULA),
|
||||
and Xcode is not redistributable. The `ios_build` CI job uses a
|
||||
`macos-latest` GitHub runner; locally you need a real Mac.
|
||||
|
||||
- **Native Windows tests**: see note above. The Linux+mingw image
|
||||
builds, it does not execute Windows-host code paths
|
||||
(registry, wintun, services, PsExec workflows).
|
||||
|
||||
When CI version pins change, update the corresponding `ARG` lines in
|
||||
the Dockerfiles and the README's table of versions.
|
||||
@@ -1,86 +0,0 @@
|
||||
# Android build environment.
|
||||
#
|
||||
# Mirrors the toolchain pinned by .github/workflows/mobile-build-validation.yml
|
||||
# so a `gomobile bind` against ./client/android in this image produces the
|
||||
# same netbird.aar that CI builds.
|
||||
#
|
||||
# Tooling versions (must stay in sync with the CI workflow):
|
||||
# - Ubuntu 22.04 (matches the ubuntu-latest GitHub runner)
|
||||
# - Go 1.25.5 (matches go.mod)
|
||||
# - Adopt JDK 11 (matches actions/setup-java@v3 java-version: 11, distribution: adopt)
|
||||
# - Android SDK cmdline-tools 8512546
|
||||
# - Android NDK 23.1.7779620
|
||||
# - gomobile commit v0.0.0-20251113184115-a159579294ab
|
||||
#
|
||||
# Usage (from the netbird repo root):
|
||||
#
|
||||
# docker build -t netbird/build-android docker/build-env/android
|
||||
#
|
||||
# # bind the netbird checkout in and run the same gomobile command CI runs
|
||||
# docker run --rm -v "$PWD:/src" -w /src netbird/build-android \
|
||||
# gomobile bind \
|
||||
# -o netbird.aar \
|
||||
# -javapkg=io.netbird.gomobile \
|
||||
# -ldflags="-checklinkname=0 \
|
||||
# -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard \
|
||||
# -X github.com/netbirdio/netbird/version.version=local" \
|
||||
# ./client/android
|
||||
#
|
||||
# To build the full APK, mount the android-client repo too and run
|
||||
# `./gradlew assembleDebug` from /android-client (this image carries
|
||||
# gradle's prerequisites JDK + Android SDK but not the gradle wrapper —
|
||||
# that ships with android-client).
|
||||
|
||||
FROM ubuntu:22.04
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Versions — bump in lockstep with .github/workflows/mobile-build-validation.yml.
|
||||
ARG GO_VERSION=1.25.5
|
||||
ARG ANDROID_CMDLINE_TOOLS_VERSION=8512546
|
||||
ARG ANDROID_NDK_VERSION=23.1.7779620
|
||||
ARG GOMOBILE_VERSION=v0.0.0-20251113184115-a159579294ab
|
||||
|
||||
ENV ANDROID_HOME=/opt/android-sdk
|
||||
ENV ANDROID_NDK_HOME=${ANDROID_HOME}/ndk/${ANDROID_NDK_VERSION}
|
||||
ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64
|
||||
ENV GOPATH=/go
|
||||
ENV GOTOOLCHAIN=local
|
||||
ENV CGO_ENABLED=0
|
||||
ENV PATH=${GOPATH}/bin:/usr/local/go/bin:${ANDROID_HOME}/cmdline-tools/latest/bin:${ANDROID_HOME}/platform-tools:${JAVA_HOME}/bin:${PATH}
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
unzip \
|
||||
git \
|
||||
openjdk-11-jdk-headless \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Go (matches go.mod). actions/setup-go fetches the same tarball.
|
||||
RUN curl -fsSL "https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz" \
|
||||
| tar -C /usr/local -xz \
|
||||
&& go version
|
||||
|
||||
# Install Android SDK command-line tools, accept licenses, install NDK.
|
||||
RUN mkdir -p "${ANDROID_HOME}/cmdline-tools" \
|
||||
&& curl -fsSL -o /tmp/cmdline.zip \
|
||||
"https://dl.google.com/android/repository/commandlinetools-linux-${ANDROID_CMDLINE_TOOLS_VERSION}_latest.zip" \
|
||||
&& unzip -q /tmp/cmdline.zip -d "${ANDROID_HOME}/cmdline-tools" \
|
||||
&& mv "${ANDROID_HOME}/cmdline-tools/cmdline-tools" "${ANDROID_HOME}/cmdline-tools/latest" \
|
||||
&& rm /tmp/cmdline.zip \
|
||||
&& yes | sdkmanager --licenses > /dev/null \
|
||||
&& sdkmanager --install "ndk;${ANDROID_NDK_VERSION}" "platform-tools" > /dev/null
|
||||
|
||||
# Install gomobile at the same commit CI pins. Don't run `gomobile init` here:
|
||||
# `init` resolves the NDK at runtime, do it on the first bind in the mounted
|
||||
# workspace so the cache lands on the host volume.
|
||||
RUN GOBIN=/usr/local/bin go install "golang.org/x/mobile/cmd/gomobile@${GOMOBILE_VERSION}" \
|
||||
&& gomobile version
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
# Default entrypoint is a plain shell so the image is composable: callers pass
|
||||
# the full gomobile / gradle command they want to run.
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,63 +0,0 @@
|
||||
# Windows-cross build environment.
|
||||
#
|
||||
# Cross-compiles Windows .exe targets from a Linux container using
|
||||
# mingw-w64. Mirrors the toolchain set used by
|
||||
# .github/workflows/golang-test-windows.yml insofar as that is possible
|
||||
# without a Windows kernel.
|
||||
#
|
||||
# IMPORTANT — what this image CAN do:
|
||||
# - `GOOS=windows go build ./...` to validate that Windows builds compile
|
||||
# - CGO Windows cross-compile via x86_64-w64-mingw32-gcc when CGO_ENABLED=1
|
||||
# (matches CI's choco-installed mingw-w64)
|
||||
#
|
||||
# IMPORTANT — what this image CANNOT do:
|
||||
# - Run Windows binaries (no Windows kernel under Docker on Linux).
|
||||
# - Replicate the CI's `go test` runs which execute on a real
|
||||
# windows-latest runner (wintun.dll, PsExec, registry, etc.).
|
||||
# Use the CI for that or a native Windows VM.
|
||||
#
|
||||
# Usage (from the netbird repo root):
|
||||
#
|
||||
# docker build -t netbird/build-windows docker/build-env/windows-cross
|
||||
#
|
||||
# # Cross-compile a static client (.exe) from Linux:
|
||||
# docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
|
||||
# bash -c 'CGO_ENABLED=1 GOOS=windows GOARCH=amd64 \
|
||||
# CC=x86_64-w64-mingw32-gcc CXX=x86_64-w64-mingw32-g++ \
|
||||
# go build -o netbird.exe ./client'
|
||||
#
|
||||
# # Just validate that everything *compiles* on Windows (no CGO):
|
||||
# docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
|
||||
# bash -c 'GOOS=windows GOARCH=amd64 go build ./...'
|
||||
#
|
||||
# Tooling versions (keep in sync with go.mod and any future explicit pin
|
||||
# documented in golang-test-windows.yml):
|
||||
# - Ubuntu 22.04
|
||||
# - Go 1.25.5 (matches go.mod)
|
||||
# - mingw-w64 (Ubuntu package — pin further if drift becomes a problem)
|
||||
|
||||
FROM ubuntu:22.04
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ARG GO_VERSION=1.25.5
|
||||
|
||||
ENV GOPATH=/go
|
||||
ENV GOTOOLCHAIN=local
|
||||
ENV PATH=${GOPATH}/bin:/usr/local/go/bin:${PATH}
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
git \
|
||||
build-essential \
|
||||
mingw-w64 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Go (matches go.mod).
|
||||
RUN curl -fsSL "https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz" \
|
||||
| tar -C /usr/local -xz \
|
||||
&& go version
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
109
docs/agent-networks/00-overview.md
Normal file
109
docs/agent-networks/00-overview.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# Agent Networks — overview
|
||||
|
||||
Single-entry point. Feature scope, the module map, and the cross-cutting
|
||||
topics worth keeping in mind, with links into every per-module guide.
|
||||
|
||||
## TL;DR
|
||||
|
||||
Agent Networks introduces an **LLM-aware reverse-proxy middleware system**
|
||||
plus **account-level controls** (budget rules, log collection toggles,
|
||||
PII redaction). The management server synthesises a per-peer middleware
|
||||
chain that the proxy executes on every LLM request; the chain enforces
|
||||
quotas, injects identity, redacts PII, parses tokens/cost, and emits
|
||||
access-log entries. The dashboard exposes the surface as a single **AI
|
||||
Observability** page with four tabs.
|
||||
|
||||
- **Backend** lives in this repo, primarily under
|
||||
`management/server/agentnetwork`, `proxy/internal/middleware`, and
|
||||
`proxy/internal/llm`, with wire contracts in `shared/management`.
|
||||
- **Dashboard** lives in the dashboard repo under
|
||||
`src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`.
|
||||
|
||||
## Reading order
|
||||
|
||||
| # | Doc | Why |
|
||||
|---|-----|-----|
|
||||
| 1 | [01-end-to-end-flows.md](01-end-to-end-flows.md) | Get the three big diagrams in your head first. |
|
||||
| 2 | [modules/10-shared-api.md](modules/10-shared-api.md) | Wire contracts — every other module either produces or consumes these. |
|
||||
| 3 | [modules/21-management-agentnetwork.md](modules/21-management-agentnetwork.md) | The largest module; everything the proxy executes originates here. |
|
||||
| 4 | [modules/30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md) | The generic plugin system on the proxy side. |
|
||||
| 5 | [modules/31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md) | The 8 LLM middlewares that ride on the framework. |
|
||||
| 6 | Everything else in any order. | |
|
||||
|
||||
## Module map
|
||||
|
||||
11 modules. Each is described in detail in its own file under
|
||||
[`modules/`](modules/).
|
||||
|
||||
| # | Module | Risk | BC impact |
|
||||
|---|--------|------|-----------|
|
||||
| 10 | [shared/api](modules/10-shared-api.md) — proto + OpenAPI | Low | Additive only |
|
||||
| 20 | [management/store](modules/20-management-store.md) — SQL persistence | Medium | Auto-migrate (additive) |
|
||||
| 21 | [management/agentnetwork](modules/21-management-agentnetwork.md) — domain layer + synthesizer | **High** | Additive |
|
||||
| 22 | [management/handlers + wiring](modules/22-management-handlers-wiring.md) — HTTP API + gRPC delivery | Medium | Additive |
|
||||
| 30 | [proxy/middleware-framework](modules/30-proxy-middleware-framework.md) — generic plugin system | High | Additive |
|
||||
| 31 | [proxy/middleware-builtin](modules/31-proxy-middleware-builtin.md) — 8 LLM middlewares | High | Additive |
|
||||
| 32 | [proxy/llm-parsers](modules/32-proxy-llm-parsers.md) — SDK adapters + pricing | Medium | Additive |
|
||||
| 33 | [proxy/runtime](modules/33-proxy-runtime.md) — translate + serve + access-log | High | Additive (touches hot path) |
|
||||
| 40 | [dashboard](modules/40-dashboard.md) — UI for everything above | Medium | Sidebar reshape |
|
||||
| 50 | [path-routed-providers](modules/50-path-routed-providers.md) — Vertex AI + Bedrock | Medium | Additive (new catalog entries) |
|
||||
|
||||
The largest and highest-risk module is `management/agentnetwork`: it is
|
||||
the single writer of the middleware chain the proxy executes.
|
||||
|
||||
## Cross-cutting topics
|
||||
|
||||
These are the items most likely to bite production. Each is fully
|
||||
documented in the linked module guide.
|
||||
|
||||
1. **Capture-pointer semantics** (`*bool` for `capture_prompt` and
|
||||
`capture_completion`): nil = legacy emit, false = suppress, true =
|
||||
emit. nil-vs-false must be handled at every JSON hop. See
|
||||
[21-management-agentnetwork.md](modules/21-management-agentnetwork.md)
|
||||
and [31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md).
|
||||
2. **`ProxyMapping.Private` preservation** on per-proxy live updates.
|
||||
Failure mode: `auth` skips `ValidateTunnelPeer` →
|
||||
`CapturedData.UserGroups` empty → `llm_router` denies. See
|
||||
[33-proxy-runtime.md](modules/33-proxy-runtime.md).
|
||||
3. **respInput carrying `UserEmail`/`UserGroups`/`UserGroupNames` onto
|
||||
the response leg** in `reverseproxy.go`. Load-bearing wire that lets
|
||||
`llm_limit_record` ship non-empty `group_ids` on `RecordLLMUsage`. See
|
||||
[33-proxy-runtime.md](modules/33-proxy-runtime.md).
|
||||
4. **Min-wins all-must-pass budget rule semantics**. Every matching
|
||||
rule's remaining quota must be > 0 for the request to proceed; one
|
||||
exhausted rule blocks the whole call. Documented in
|
||||
[21-management-agentnetwork.md](modules/21-management-agentnetwork.md)
|
||||
and the `llm_limit_check` middleware in
|
||||
[31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md).
|
||||
5. **body-tap memory bounds**: per-direction 1 MiB cap, shared 256 MiB
|
||||
budget, `LimitReader(r.Body, limit+1)` for truncation detection with
|
||||
`replayReadCloser` fallback so upstream still sees the full body.
|
||||
`cloneInputFor` deep-copies the body up to 16 times per chain — a
|
||||
perf hot-spot. See
|
||||
[30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md).
|
||||
6. **UpstreamRewrite.AuthHeader bypasses the header denylist**
|
||||
deliberately. The runtime consumer only unpacks it via the
|
||||
trusted upstream-build path. See
|
||||
[30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md).
|
||||
7. **`disable_access_log` default-false semantics**: the synth target
|
||||
sets it true, all other targets leave it false. See
|
||||
[10-shared-api.md](modules/10-shared-api.md).
|
||||
8. **String-typed `decision` / `deny_code`** on
|
||||
`CheckLLMPolicyLimitsResponse` — would benefit from enum pinning
|
||||
before external consumers integrate. See
|
||||
[10-shared-api.md](modules/10-shared-api.md).
|
||||
|
||||
## Explicit non-goals
|
||||
|
||||
- **Reaper / GC pass over stale synth services** — designed but cut from
|
||||
scope.
|
||||
- **URL-sync for tab state on AI Observability** — read path is wired
|
||||
(`?tab=`) but write path isn't. Future work.
|
||||
- **CI golden-file regen-and-diff for `types.gen.go` /
|
||||
`proxy_service.pb.go`** — would catch codegen drift; not yet in place.
|
||||
|
||||
## Where to read the code
|
||||
|
||||
Per-module file scopes are listed in each module guide. Behaviour is
|
||||
covered by Go tests co-located with each package (and an end-to-end
|
||||
chain integration test under `proxy/internal/proxy`).
|
||||
217
docs/agent-networks/01-end-to-end-flows.md
Normal file
217
docs/agent-networks/01-end-to-end-flows.md
Normal file
@@ -0,0 +1,217 @@
|
||||
# End-to-end flows
|
||||
|
||||
Three cross-module mermaid diagrams. Each per-module guide repeats the
|
||||
slice that's relevant to its own scope — these are the canonical
|
||||
top-down views.
|
||||
|
||||
- [Flow A — Config → runtime (synth + deliver)](#flow-a--config--runtime-synth--deliver)
|
||||
- [Flow B — Request lifecycle through the LLM chain](#flow-b--request-lifecycle-through-the-llm-chain)
|
||||
- [Flow C — Budget rule feedback loop](#flow-c--budget-rule-feedback-loop)
|
||||
|
||||
---
|
||||
|
||||
## Flow A — Config → runtime (synth + deliver)
|
||||
|
||||
How an operator's change to a Provider, Policy, Guardrail, Budget Rule,
|
||||
or Settings record ends up as live middleware on a peer's proxy.
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
actor Op as Operator
|
||||
participant UI as Dashboard
|
||||
participant HTTP as management/handlers
|
||||
participant Mgr as agentnetwork.Manager
|
||||
participant Store as management/store (SQL)
|
||||
participant Ctl as network_map.Controller
|
||||
participant Synth as agentnetwork.SynthesizeServices
|
||||
participant Grpc as management gRPC
|
||||
participant Proxy as netbird-proxy
|
||||
participant Xlate as middleware_translate
|
||||
participant Chain as middleware.Chain
|
||||
|
||||
Op->>UI: edit provider/policy/budget/settings
|
||||
UI->>HTTP: REST PUT/POST /api/agent-network/*
|
||||
HTTP->>Mgr: SaveProvider / SavePolicy / SaveBudgetRule / SaveSettings
|
||||
Mgr->>Store: persist (gorm)
|
||||
Mgr-->>Ctl: account change event (Network-Map dirty)
|
||||
loop per connected peer
|
||||
Ctl->>Synth: SynthesizeServices(ctx, store, accountID)
|
||||
Synth->>Store: load providers, policies, guardrails, budget rules, settings
|
||||
Synth-->>Synth: build per-peer Service list
|
||||
Note over Synth: each Service has a middleware<br/>chain with capture_prompt /<br/>capture_completion / redact_pii<br/>baked from account settings
|
||||
Synth-->>Ctl: []rpservice.Service
|
||||
Ctl->>Grpc: NetworkMap push (services + middleware configs)
|
||||
end
|
||||
Grpc-->>Proxy: NetworkMap stream
|
||||
Proxy->>Xlate: translate proto MiddlewareConfig → runtime Spec
|
||||
Xlate->>Chain: register / replace per-service chain
|
||||
Note over Chain: chain replacement is live<br/>(no proxy restart, in-flight<br/>requests unaffected)
|
||||
```
|
||||
|
||||
**Notes on the diagram**
|
||||
|
||||
- The `network_map.Controller` synthesises on every push, not on a
|
||||
timer. A single config change costs O(connected peers × policies ×
|
||||
providers) per push. See [`modules/22-management-handlers-wiring.md`](modules/22-management-handlers-wiring.md).
|
||||
- `SynthesizeServices` is the single source of truth for the wire
|
||||
format the proxy executes. Anything the proxy does that the
|
||||
synthesiser didn't request is a bug. See
|
||||
[`modules/21-management-agentnetwork.md`](modules/21-management-agentnetwork.md).
|
||||
- The translate step (step 13) is the only place that knows the
|
||||
middleware-ID strings on the proxy side. It must reject unknown IDs;
|
||||
silently dropping middlewares would create a security gap (e.g.
|
||||
missing `llm_limit_check` ⇒ unbounded spend). See
|
||||
[`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md).
|
||||
|
||||
---
|
||||
|
||||
## Flow B — Request lifecycle through the LLM chain
|
||||
|
||||
What happens when an agent on the client peer sends a chat-completion /
|
||||
messages request through the synthesised reverse-proxy.
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
actor Agent as Agent (local)
|
||||
participant Px as netbird-proxy
|
||||
participant Auth as auth middleware
|
||||
participant Map as service-mapping
|
||||
participant Req as llm_request_parser
|
||||
participant Rt as llm_router
|
||||
participant Chk as llm_limit_check
|
||||
participant Inj as llm_identity_inject
|
||||
participant Grd as llm_guardrail
|
||||
participant Up as upstream LLM
|
||||
participant Resp as llm_response_parser
|
||||
participant Cost as cost_meter
|
||||
participant Rec as llm_limit_record
|
||||
participant Log as access-log
|
||||
participant MgmtGrpc as management gRPC
|
||||
|
||||
Agent->>Px: POST /v1/chat/completions (OpenAI / Anthropic)
|
||||
Px->>Auth: identify peer (user, groups)
|
||||
Auth->>Map: resolve service from Host + path
|
||||
Map-->>Req: dispatch chain in slot order
|
||||
|
||||
Req->>Req: parse body → provider, model, prompt, token estimate
|
||||
Note over Req: capture_prompt gates raw_prompt<br/>capture (nil = legacy emit,<br/>false = drop, true = emit)
|
||||
Req->>Rt: pass metadata
|
||||
Rt->>Chk: route to upstream candidate
|
||||
|
||||
Chk->>MgmtGrpc: CheckLLMPolicyLimits(provider, model, est_tokens, groups, user)
|
||||
MgmtGrpc-->>Chk: decision = allow / deny + deny_code
|
||||
alt decision == deny
|
||||
Chk-->>Log: emit access-log with deny_code<br/>(if EnableLogCollection)
|
||||
Chk-->>Agent: 429 (or 403 per deny_code)
|
||||
else decision == allow
|
||||
Chk->>Inj: continue
|
||||
Inj->>Inj: inject NetBird identity headers per provider config
|
||||
Inj->>Grd: continue
|
||||
Grd->>Grd: enforce model allowlist
|
||||
Grd->>Up: forward (over WireGuard)
|
||||
Up-->>Resp: response (JSON or SSE stream)
|
||||
Resp->>Resp: parse usage tokens, completion
|
||||
Note over Resp: capture_completion gates raw<br/>completion capture
|
||||
Resp->>Cost: tokens
|
||||
Cost->>Cost: lookup pricing.yaml + compute cost
|
||||
Cost->>Rec: tokens + cost
|
||||
Rec->>MgmtGrpc: RecordLLMUsage(provider, model, prompt_t, completion_t, cost, groups, user)
|
||||
Rec-->>Log: emit access-log entry<br/>(if EnableLogCollection)
|
||||
Log-->>Agent: 200 + body (streamed if SSE)
|
||||
end
|
||||
```
|
||||
|
||||
**Notes on the diagram**
|
||||
|
||||
- The chain runs in synth-defined order. Re-ordering middlewares
|
||||
changes invariants — `llm_limit_check` must precede `llm_router` so
|
||||
a denied request never hits upstream, and `llm_limit_record` must
|
||||
pair with `llm_limit_check` so a successful check is always recorded
|
||||
(or the rate-limit semantics break). See
|
||||
[`modules/31-proxy-middleware-builtin.md`](modules/31-proxy-middleware-builtin.md).
|
||||
- `llm_guardrail` is also where PII redaction happens
|
||||
(`redact_pii = settings.RedactPii`). Phones, emails, credit cards,
|
||||
PII names — see `redact.go` for the full set. See
|
||||
[`modules/31-proxy-middleware-builtin.md`](modules/31-proxy-middleware-builtin.md).
|
||||
- SSE streaming requires special handling on the response side; the
|
||||
parser must handle partial chunks without buffering the whole
|
||||
stream. See [`modules/32-proxy-llm-parsers.md`](modules/32-proxy-llm-parsers.md).
|
||||
- Access-log emission is gated on `settings.EnableLogCollection`. With
|
||||
it OFF, neither the deny nor the allow leg writes an entry — the
|
||||
chain still runs (budget rules are still enforced) but no audit trail
|
||||
is kept. See
|
||||
[`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md).
|
||||
|
||||
---
|
||||
|
||||
## Flow C — Budget rule feedback loop
|
||||
|
||||
How an account's budget rules tighten ceilings on every request and how
|
||||
consumption flows back into the dashboard.
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph Operator
|
||||
DashBud[Dashboard Budget Settings tab]
|
||||
end
|
||||
subgraph Mgmt[Management]
|
||||
Save[POST/PUT /api/agent-network/budget-rules]
|
||||
Store[(SQL store)]
|
||||
Synth[SynthesizeServices]
|
||||
Check[CheckLLMPolicyLimits RPC]
|
||||
Rec[RecordLLMUsage RPC]
|
||||
Cons[/api/agent-network/consumption]
|
||||
end
|
||||
subgraph Proxy[Proxy]
|
||||
Chk[llm_limit_check]
|
||||
RecMw[llm_limit_record]
|
||||
end
|
||||
subgraph DashView[Dashboard Budget Dashboard tab]
|
||||
Panel[AgentConsumptionPanel]
|
||||
end
|
||||
|
||||
DashBud -->|create / update rules| Save
|
||||
Save --> Store
|
||||
Store --> Synth
|
||||
Synth -->|push synth-services to peer| Proxy
|
||||
|
||||
Chk -->|per request| Check
|
||||
Check -->|aggregate matching rules<br/>min-wins all-must-pass| Store
|
||||
Check -->|allow / deny| Chk
|
||||
|
||||
RecMw -->|post-response| Rec
|
||||
Rec -->|tokens + cost + groups + user| Store
|
||||
|
||||
Store -->|read counters| Cons
|
||||
Cons --> Panel
|
||||
```
|
||||
|
||||
**Notes on the diagram**
|
||||
|
||||
- **min-wins all-must-pass** is the core semantic. A budget rule binds
|
||||
to (group set, user set) with a (window, ceiling). At check time,
|
||||
every rule that matches the caller is evaluated; if ANY rule has
|
||||
zero remaining quota the request is denied. This is the most
|
||||
surprising semantic for operators — see the invariants section of
|
||||
[`modules/21-management-agentnetwork.md`](modules/21-management-agentnetwork.md).
|
||||
- The proxy never makes its own budget decisions. It always asks
|
||||
management via `CheckLLMPolicyLimits` and reports back via
|
||||
`RecordLLMUsage`. This keeps account-wide accounting in one place
|
||||
and avoids per-proxy drift.
|
||||
- `RecordLLMUsage` must carry `group_ids` and `user_id` so the
|
||||
decrement hits the right rule(s). The wire that carries those
|
||||
fields onto the response leg is `respInput` in `reverseproxy.go`. See
|
||||
[`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md).
|
||||
- The dashboard's Budget Dashboard tab polls
|
||||
`/api/agent-network/consumption` — not gRPC, not WebSocket. Poll
|
||||
interval lives in `AgentConsumptionPanel.tsx`. See
|
||||
[`modules/40-dashboard.md`](modules/40-dashboard.md).
|
||||
|
||||
---
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Per-module guides: [`modules/`](modules/)
|
||||
- Overview + module map: [`00-overview.md`](00-overview.md)
|
||||
66
docs/agent-networks/README.md
Normal file
66
docs/agent-networks/README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# Agent Networks — architecture documentation
|
||||
|
||||
A self-contained set of documents describing the agent-networks feature:
|
||||
an LLM-aware reverse-proxy middleware system plus account-level controls
|
||||
(budget rules, log collection toggles, PII redaction). The management
|
||||
server synthesises a per-peer middleware chain that the proxy executes on
|
||||
every LLM request.
|
||||
|
||||
## What to read first
|
||||
|
||||
1. **[00-overview.md](00-overview.md)** — the single entry point. Feature
|
||||
scope, the module map, and the cross-cutting topics worth keeping in
|
||||
mind, with links to every per-module guide.
|
||||
2. **[01-end-to-end-flows.md](01-end-to-end-flows.md)** — three
|
||||
high-level mermaid diagrams: config-to-runtime synth/delivery,
|
||||
per-request lifecycle through the LLM chain, and the budget-rule
|
||||
feedback loop.
|
||||
3. **Per-module guides** under `modules/` — one file per package. Each
|
||||
describes the module boundary, the file-level layout, its own flow
|
||||
diagrams, the public contracts, the invariants it relies on, and the
|
||||
areas worth the closest attention.
|
||||
|
||||
## Directory layout
|
||||
|
||||
```
|
||||
docs/agent-networks/
|
||||
├── README.md # you are here
|
||||
├── 00-overview.md # feature summary + module map
|
||||
├── 01-end-to-end-flows.md # cross-module mermaid diagrams
|
||||
└── modules/
|
||||
├── 10-shared-api.md # proto + OpenAPI wire contracts
|
||||
├── 20-management-store.md # SQL persistence layer
|
||||
├── 21-management-agentnetwork.md # domain layer + synthesizer (largest)
|
||||
├── 22-management-handlers-wiring.md # HTTP API + gRPC delivery
|
||||
├── 30-proxy-middleware-framework.md # generic plugin system
|
||||
├── 31-proxy-middleware-builtin.md # 8 LLM-aware middlewares
|
||||
├── 32-proxy-llm-parsers.md # OpenAI/Anthropic/Bedrock SDKs + pricing
|
||||
├── 33-proxy-runtime.md # translate + serve + access-log
|
||||
├── 40-dashboard.md # UI for everything above (lives in the dashboard repo)
|
||||
└── 50-path-routed-providers.md # Vertex AI + Bedrock (path-routed, keyfile:: creds, /bedrock prefix)
|
||||
```
|
||||
|
||||
The `40-dashboard.md` module documents code that lives in the **dashboard
|
||||
repo**, not in this repo. The guide is co-located here so backend readers
|
||||
see the full picture in one place.
|
||||
|
||||
## How the per-module guides are structured
|
||||
|
||||
Every `modules/*.md` follows the same template so the docs are easy to
|
||||
scan:
|
||||
|
||||
- **Module boundary** — what this package owns; where it sits in the stack.
|
||||
- **Files** — path / role.
|
||||
- **Architecture & flow** — one or more mermaid diagrams.
|
||||
- **Public contracts** — function signatures, gRPC messages, JSON shapes.
|
||||
- **Invariants** — semantic guarantees the module relies on or enforces.
|
||||
- **Things to scrutinize** — split by correctness / security /
|
||||
concurrency / backward-compat / performance / observability.
|
||||
- **Test coverage** — the test files that lock down behaviour in this
|
||||
module.
|
||||
- **Known limitations / non-goals** — what is intentionally out of scope.
|
||||
- **Cross-references** — upstream/downstream module links + the
|
||||
end-to-end flow + the overview.
|
||||
|
||||
See [00-overview.md](00-overview.md) for the module map and the
|
||||
cross-cutting topics.
|
||||
105
docs/agent-networks/modules/10-shared-api.md
Normal file
105
docs/agent-networks/modules/10-shared-api.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# shared/api — wire contracts (proto + OpenAPI)
|
||||
|
||||
> **Risk level:** Medium — wire-format surface that every other module pins against; backward-compat hinges on field-number discipline more than on logic correctness.
|
||||
> **Backward-compat impact:** Additive only (new proto fields use unallocated numbers, new RPCs default to `Unimplemented`, new OpenAPI schemas/paths are append-only; no existing field/RPC/schema removed or renumbered).
|
||||
|
||||
## Module boundary
|
||||
This module owns the cross-process contract surface between management, proxy, and dashboard. Two artefacts: `shared/management/proto/proxy_service.proto` (management↔proxy gRPC) and `shared/management/http/api/openapi.yml` (dashboard/CLI↔management REST). Both have generated companions checked in (`proxy_service.pb.go`, `proxy_service_grpc.pb.go`, `types.gen.go`) which must travel in lockstep with their sources. `shared/management/status/error.go` is in scope only for the four new typed `NotFound` constructors that the new HTTP handlers return.
|
||||
|
||||
Everything downstream — `management/agentnetwork`, `management/server/http/handlers/*`, `proxy/internal/*`, the dashboard SDK — consumes these types verbatim. The concern here is wire stability and codegen reproducibility, not behaviour: behaviour is covered in the management and proxy module guides.
|
||||
|
||||
`management.proto` and `signalexchange.proto` are unchanged. `status/error.go` only receives four additive constructors (lines 208-227); no existing error types are reshaped.
|
||||
|
||||
## Files
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `shared/management/proto/proxy_service.proto` | Source of truth: 2 new RPCs, 1 new message group (`MiddlewareConfig` + slot enum), additive fields on `PathTargetOptions`, `AccessLog`, `RecordLLMUsageRequest` |
|
||||
| `shared/management/proto/proxy_service.pb.go` | Generated (protoc-gen-go) |
|
||||
| `shared/management/proto/proxy_service_grpc.pb.go` | Generated; adds `CheckLLMPolicyLimits` + `RecordLLMUsage` client/server stubs and `UnimplementedProxyServiceServer` defaults |
|
||||
| `shared/management/http/api/openapi.yml` | 15 new `AgentNetwork*` schemas, 9 new path groups under `/api/agent-network/*` |
|
||||
| `shared/management/http/api/types.gen.go` | Generated (oapi-codegen; see codegen note below) |
|
||||
| `shared/management/status/error.go` | Four `NotFound` constructors for the new resource kinds (lines 208-227) |
|
||||
|
||||
## Architecture & flow
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant Dash as Dashboard / CLI
|
||||
participant Mgmt as management (HTTP+gRPC)
|
||||
participant Px as proxy
|
||||
|
||||
Note over Dash,Mgmt: REST (OpenAPI / types.gen.go)
|
||||
Dash->>Mgmt: PUT /api/agent-network/providers (AgentNetworkProviderRequest)
|
||||
Dash->>Mgmt: PUT /api/agent-network/settings (AgentNetworkSettingsRequest)
|
||||
Dash->>Mgmt: GET /api/agent-network/consumption -> [AgentNetworkConsumption]
|
||||
|
||||
Note over Mgmt,Px: gRPC ProxyService (proxy_service.proto)
|
||||
Mgmt-->>Px: SyncMappingsResponse{ ProxyMapping.path[*].options.middlewares,<br/>agent_network, disable_access_log, capture_* }
|
||||
Px->>Mgmt: CheckLLMPolicyLimits(account, user, groups, provider, model)
|
||||
Mgmt-->>Px: decision=allow|deny + selected_policy_id + attribution_group_id + window_seconds
|
||||
Px->>Mgmt: RecordLLMUsage(account, user, group_id, group_ids, window_seconds, tokens, cost)
|
||||
Px->>Mgmt: SendAccessLog(AccessLog{ agent_network=true })
|
||||
```
|
||||
|
||||
The proto changes split into three independent slices: (1) **mapping enrichment** — `PathTargetOptions` grows fields 8-13 so management can ship middleware configs, capture limits, and the agent-network / log-suppression flags down to the proxy without a second RPC; (2) **two new request/response RPCs** (`CheckLLMPolicyLimits`, `RecordLLMUsage`) for per-LLM-request budget arbitration; (3) **observability tag** — `AccessLog.agent_network` so management can route logs to the right surface.
|
||||
|
||||
The OpenAPI side is a thin CRUD surface — every resource (`Provider`, `Policy`, `Guardrail`, `BudgetRule`, `Settings`) follows the same `GET-list / POST / GET / PUT / DELETE` pattern, plus a read-only `/consumption` listing and a catalog endpoint. The `*Request` variants drop server-controlled fields (id, timestamps). `AgentNetworkBudgetRule` deliberately reuses `AgentNetworkPolicyLimits` to keep wire-shape parity with policies.
|
||||
|
||||
## Public contracts added
|
||||
- gRPC RPCs (`proxy_service.proto:52-57`): `CheckLLMPolicyLimits(CheckLLMPolicyLimitsRequest) → CheckLLMPolicyLimitsResponse`, `RecordLLMUsage(RecordLLMUsageRequest) → RecordLLMUsageResponse`. Both unary; default `UnimplementedProxyServiceServer` returns `codes.Unimplemented` (`proxy_service_grpc.pb.go:283-289`).
|
||||
- New messages (`proxy_service.proto:145-175,448-502`): `MiddlewareConfig`, `MiddlewareSlot` enum, `CheckLLMPolicyLimitsRequest`/`Response`, `RecordLLMUsageRequest`/`Response`.
|
||||
- New `PathTargetOptions` fields 8-13 (`proxy_service.proto:124-140`): `capture_max_request_bytes`, `capture_max_response_bytes`, `capture_content_types`, `middlewares`, `agent_network`, `disable_access_log`. All default-false / zero; pre-existing fields 1-7 byte-for-byte unchanged.
|
||||
- `AccessLog.agent_network = 18` (`proxy_service.proto:258-261`).
|
||||
- `RecordLLMUsageRequest.group_ids = 8` (`proxy_service.proto:496-498`) — so the record path can fan out to every applicable budget rule's window without a re-lookup.
|
||||
- 15 new OpenAPI component schemas (`openapi.yml:5072-5829`): `AgentNetworkProvider[Request|Model]`, `AgentNetworkCatalog{Model,Provider,IdentityInjection,HeaderPairInjection,JSONMetadataInjection,ExtraHeader}`, `AgentNetworkPolicy[Request|TokenLimit|BudgetLimit|Limits]`, `AgentNetworkGuardrail[Checks|Request]`, `AgentNetworkConsumption`, `AgentNetworkSettings[Request]`, `AgentNetworkBudgetRule[Request]`.
|
||||
- 9 new path groups (`openapi.yml:12797-13460`): `/api/agent-network/{consumption,settings,budget-rules,budget-rules/{ruleId},catalog/providers,providers,providers/{providerId},policies,policies/{policyId},guardrails,guardrails/{guardrailId}}`.
|
||||
- Four typed NotFound errors (`shared/management/status/error.go:208-227`).
|
||||
|
||||
## Invariants
|
||||
- **Field-number monotonicity.** Every new proto field uses a previously-unallocated number in its message: `PathTargetOptions` 8-13 (was 1-7), `AccessLog` 18 (was 1-17), `RecordLLMUsageRequest` 8. `SendStatusUpdateRequest.inbound_listener = 50` (pre-existing) reserves 50+ for observability extensions, so 8 on `RecordLLMUsageRequest` doesn't conflict.
|
||||
- **Old proxies stay compatible.** Old management never sends `disable_access_log`/`middlewares`/`agent_network` (zero value → existing behaviour); old proxies that don't decode these fields just drop them silently (proto3 unknown-field semantics) — log emission stays on. No pre-existing field number changed: the proto change is insertions only.
|
||||
- **Old management stays compatible.** The two new RPCs are registered on the same `management.ProxyService` descriptor; old proxies hitting them get `codes.Unimplemented` from the unimplemented embed (`proxy_service_grpc.pb.go:283-289`), which is the same fallback pattern `SyncMappings` already documents (`proxy_service.proto:20-21`).
|
||||
- **OpenAPI shapes are append-only.** New schemas are placed at the end of `components.schemas` (line 5072+); new paths at the end of `paths` (line 12797+). No existing schema's `required` list, enum, or property type was changed.
|
||||
- **`*Request` vs response asymmetry.** Read shapes (`AgentNetworkProvider`, `AgentNetworkPolicy`, `AgentNetworkGuardrail`, `AgentNetworkSettings`, `AgentNetworkBudgetRule`) require `created_at`/`updated_at`; the matching `*Request` shapes do not — server fills them. `AgentNetworkProviderRequest.api_key` is write-only (`openapi.yml:5158-5161` "never returned in responses"); reviewers should confirm the response schema (5072-5138) actually omits `api_key`.
|
||||
|
||||
## Things to scrutinize
|
||||
### Correctness
|
||||
- `RecordLLMUsageRequest` carries both `group_id` (singular, the attribution group — field 3) and `group_ids` (plural, full membership — field 8). `b22d5a181` adds field 8 to drive account-budget fan-out; double-check that consumers can't accidentally key counters on the wrong one. Field comments at `proxy_service.proto:489-491` and `496-498` distinguish them but it's the kind of subtle thing a follow-up commit might collapse.
|
||||
- `PathTargetOptions.disable_access_log` is the only field whose default-false meaning **changes semantics** on the proxy side: false → log (status quo), true → suppress. Synthesizer sets `DisableAccessLog = !settings.EnableLogCollection`, so a missing/default settings row yields `EnableLogCollection=false → DisableAccessLog=true → suppressed`. Worth confirming downstream (`agentnetwork.synthesizer`) that operator-defined private services never inherit this flag — the proto field default protects them, but only if synth code is explicit.
|
||||
- `CheckLLMPolicyLimitsResponse.decision` is a free-form `string` (`proxy_service.proto:471`) rather than an enum. Only documented values are "allow" / "deny". An enum would prevent typo drift; consider before this RPC ships to external consumers.
|
||||
- `deny_code` (`proxy_service.proto:478-481`) is documented as "a stable label" but is also a free string. Pin the allowed set somewhere observable to the proxy.
|
||||
|
||||
### Security
|
||||
- `AgentNetworkProvider.api_key` MUST be write-only. Schema split (request has it at line 5158; response omits it) looks correct, but a regression here leaks the upstream provider credential to every dashboard reader. Check that the handler explicitly zeros it on the response path.
|
||||
- `extra_values` / `identity_header_*` headers on `AgentNetworkProvider` get stamped onto upstream requests. Description at `openapi.yml:5099` says "values not declared by the catalog are ignored at synth time" — a contract this module documents but the synthesizer must enforce. Confirm the synth module honours it.
|
||||
- Cluster + subdomain on `AgentNetworkSettings` are documented immutable (`openapi.yml:5686-5694`) and the `AgentNetworkSettingsRequest` (lines 5733-5752) doesn't accept them. Verify the `PUT /api/agent-network/settings` handler can't be tricked by extra JSON keys (oapi-codegen's `additionalProperties: false` is not declared here; spec defaults to permissive).
|
||||
|
||||
### Backward compatibility
|
||||
- The proto change is field-number additive: every previously numbered field keeps the same name + type, and the change is insertions only (no deletions in `proxy_service.proto`), so this holds at the source-text level.
|
||||
- `proxy_service_grpc.pb.go` adds two RPC handlers and registers them in `ProxyService_ServiceDesc.Methods` (lines 543-552). The existing entries are unchanged and order-preserving — gRPC method dispatch is name-keyed, so order doesn't matter, but reviewing the diff (no method renamed/dropped) is still worth a glance.
|
||||
- OpenAPI 3.0 doesn't have a built-in deprecation flow for paths; if any client tooling iterates `paths.*`, the additive routes shouldn't break it, but generated SDKs (especially the dashboard's) need a regen to gain access to `AgentNetwork*`.
|
||||
|
||||
### Codegen pinning
|
||||
- `generate.sh` (`shared/management/http/api/generate.sh:14`) installs `oapi-codegen@latest` rather than a pinned version. **This is a reproducibility gap** — re-running the script later may produce a different `types.gen.go`. Either pin the version in `generate.sh` (e.g. `@v2.7.0`) or document the pin in a `tools.go`.
|
||||
- proto codegen has the protoc / protoc-gen-go version stamped in the generated file header (`proxy_service.pb.go:3-4`).
|
||||
- Regenerate locally and confirm zero diff against the committed `types.gen.go` / `proxy_service.pb.go`.
|
||||
|
||||
## Test coverage
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| None in this scope | The proto and OpenAPI sources are tested transitively by the handler tests (`shared/management/http/handlers/agentnetwork/...`) and by the synthesizer/manager tests (`management/server/agentnetwork/...`). No round-trip serialisation test exists in the `proto/` or `api/` packages themselves. |
|
||||
| `shared/management/proto/*_test.go` | (absent) |
|
||||
| `shared/management/http/api/*_test.go` | (absent) |
|
||||
|
||||
Acceptable for codegen artefacts, but a single golden-file test that re-runs `oapi-codegen` and `protoc` in CI and diffs against the checked-in files would close the reproducibility gap noted above.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
- **No deprecation surface.** Old fields/RPCs are kept silently; there is no `[deprecated = true]` annotation on anything. Acceptable here because nothing is being removed.
|
||||
- **No proto-side validation.** Numeric ranges (e.g. `window_seconds >= 60`, `cost_usd >= 0`, capture-byte clamps) are enforced in the OpenAPI schema via `minimum:` and inside Go code by the proxy/management, but `proto3` itself can't express them; downstream is expected to validate every message.
|
||||
- **`MiddlewareConfig.config_json` is `bytes`** (`proxy_service.proto:163`) — opaque to the proto layer. Schema validity is the middleware factory's problem. This is a deliberate tradeoff (per the comment at 161-162) but worth flagging: a corrupted/malicious config_json can only fail at proxy apply time, not at the wire-decode step.
|
||||
- **No catalog endpoint schema for the catalog itself** — the catalog data ships as a `GET /api/agent-network/catalog/providers` returning `[AgentNetworkCatalogProvider]` (`openapi.yml:13024`), but the catalog source-of-truth lives in `management/server/agentnetwork/catalog`, not here.
|
||||
- The reaper / GC design was cut from scope; no reaper-related types appear here.
|
||||
|
||||
## Cross-references
|
||||
- Downstream: [management/store](20-management-store.md), [management/agentnetwork](21-management-agentnetwork.md), [management/handlers + wiring](22-management-handlers-wiring.md), [proxy/runtime](33-proxy-runtime.md)
|
||||
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level: [../00-overview.md](../00-overview.md)
|
||||
112
docs/agent-networks/modules/20-management-store.md
Normal file
112
docs/agent-networks/modules/20-management-store.md
Normal file
@@ -0,0 +1,112 @@
|
||||
# management/store — persistence for agent-network entities
|
||||
|
||||
> **Risk level:** Medium — six brand-new tables behind AutoMigrate, one upsert-counter table that runs on the request hot path, and one column carrying an encrypted secret.
|
||||
> **Backward-compat impact:** Additive (six new tables created by AutoMigrate; the `Store` interface gains 23 methods, but no existing column/index is touched).
|
||||
|
||||
## Module boundary
|
||||
|
||||
This module is the persistence layer for the Agent Network feature. Everything the management server stores about LLM proxying — providers, policies, guardrails, the per-account settings row, a usage-counter table written on every proxied LLM request, and the account-budget rules — flows through the methods added to `store.Store`. The module owns six tables, six entity types from `management/server/agentnetwork/types`, and a single hot-path upsert (`IncrementAgentNetworkConsumption`) consumed by the proxy fleet.
|
||||
|
||||
Out of scope here: the catalog of provider definitions (compiled-in, no DB), the synthesizer/manager built on top of these CRUDs (covered in [21-management-agentnetwork.md](21-management-agentnetwork.md)), and the HTTP handlers that translate API requests into Save/Delete calls.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `management/server/store/sql_store_agentnetwork.go` | gorm implementations of all 23 store methods |
|
||||
| `management/server/store/sql_store_agentnetwork_budgetrule_test.go` | round-trip + account-scoping coverage against a real sqlite store |
|
||||
| `management/server/store/sql_store.go` | one import, six entities appended to the `AutoMigrate` slice (sql_store.go:40, sql_store.go:141-142) |
|
||||
| `management/server/store/store.go` | 23 methods added to the `Store` interface (store.go:328-354) |
|
||||
| `management/server/store/store_mock_agentnetwork.go` | mockgen output for the new interface surface |
|
||||
|
||||
## Tables added / migrations
|
||||
|
||||
All six tables are created by `db.AutoMigrate` invoked from `NewSqlStore` at sql_store.go:133-143. There is no hand-rolled SQL migration script — the schema is whatever GORM derives from the struct tags.
|
||||
|
||||
- `agent_network_providers` — `Provider.TableName()` at provider.go:76. PK `id`, index on `account_id`, named index `idx_agent_network_provider` on `provider_id`. Carries an at-rest-encrypted `api_key` and ed25519 `session_private_key` (provider.go:35,56). `extra_values` and `models` are JSON blobs (`serializer:json`).
|
||||
- `agent_network_policies` — `Policy.TableName()` at policy.go:70. PK `id`, index on `account_id`. JSON columns: `source_groups`, `destination_provider_ids`, `guardrail_ids`, `limits`.
|
||||
- `agent_network_guardrails` — `Guardrail.TableName()` at guardrail.go:41. PK `id`, index on `account_id`. JSON `checks`.
|
||||
- `agent_network_settings` — `Settings.TableName()` at settings.go:33. PK `account_id` (one row per account), named index `idx_agent_network_settings_cluster_subdomain` on `subdomain` only — the index name implies a composite, but only one column is tagged.
|
||||
- `agent_network_consumption` — `Consumption.TableName()` at consumption.go:46. Composite PK across `(account_id, dim_kind, dim_id, window_seconds, window_start_utc)` — the same tuple the upsert keys on.
|
||||
- `agent_network_budget_rules` — `AccountBudgetRule.TableName()` at budgetrule.go:35. PK `id`, index on `account_id`. JSON `target_groups`, `target_users`, `limits`.
|
||||
|
||||
## CRUD surface added
|
||||
|
||||
Provider, Policy, Guardrail, BudgetRule follow the same pattern: `Get<Kind>ByID`, `GetAccount<Kind>` (list), `Save<Kind>` (upsert), `Delete<Kind>`, with account-scoping enforced by the existing `accountAndIDQueryCondition` / `accountIDCondition` constants (sql_store.go:59-62). Provider additionally exposes `GetAllAgentNetworkProviders` (cross-account, used by the synthesizer). Settings exposes `Get`/`GetByCluster`/`Save` (no delete — one row per account, created on first save). Consumption exposes the upsert `Increment`, a point `Get`, and a cross-window `List`.
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
handlers["HTTP handlers<br/>(management/server/agentnetwork)"] -->|Save/Delete| iface["Store interface<br/>store.go:328-354"]
|
||||
manager["agentnetwork.Manager"] -->|Get*| iface
|
||||
synth["synthesizer<br/>(global)"] -->|GetAllAgentNetworkProviders| iface
|
||||
proxy["proxy fleet<br/>(hot path)"] -->|IncrementAgentNetworkConsumption| iface
|
||||
iface --> sql["SqlStore methods<br/>sql_store_agentnetwork.go"]
|
||||
iface -.gomock.-> mock["MockStore<br/>store_mock_agentnetwork.go"]
|
||||
sql --> gorm["gorm.DB"]
|
||||
gorm --> tables[("6 tables<br/>agent_network_*")]
|
||||
sql --> enc["crypt.FieldEncrypt<br/>(provider only)"]
|
||||
```
|
||||
|
||||
Reads decrypt provider secrets in-place; writes do `provider.Copy().EncryptSensitiveData(...)` before `db.Save` so the caller's in-memory object keeps the plaintext `api_key` (sql_store_agentnetwork.go:88-102). Every list/get takes a `LockingStrength` and applies `clause.Locking{Strength: ...}` when non-`None` — matching the rest of the store. The upsert path uses `clause.OnConflict` with `gorm.Expr` server-side increments so concurrent proxy nodes converge without read-modify-write races (sql_store_agentnetwork.go:321-335).
|
||||
|
||||
## Invariants enforced at the store layer
|
||||
|
||||
- **Account scoping.** Every entity-by-ID method keys on `account_id = ? and id = ?`; no cross-tenant leak path through the API is reachable as long as callers always pass the auth'd `accountID` (sql_store_agentnetwork.go:70,141,201,429).
|
||||
- **NotFound mapping.** `gorm.ErrRecordNotFound` is translated to typed `status.NewAgentNetwork*NotFoundError`; `Delete*` returns NotFound when `RowsAffected == 0` (sql_store_agentnetwork.go:111-113,171-173,231-233,461-463).
|
||||
- **Provider secret encryption at rest.** `SaveAgentNetworkProvider` always encrypts before persist; `Get*` always decrypts after read. The plaintext `api_key` never reaches the DB through this layer (sql_store_agentnetwork.go:31,54,80,90).
|
||||
- **Consumption monotonicity.** The upsert only ever issues `col = col + ?` for the three counter columns — no decrement path exists (sql_store_agentnetwork.go:330-332).
|
||||
- **Window alignment is the caller's responsibility.** The store stamps `WindowStartUTC` as-passed; alignment to epoch happens in `types.WindowStart` at consumption.go:51-58.
|
||||
- **Settings has no Delete.** Intentional — one row per account, created on first save; the row sticks around for the account lifetime.
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
- `SaveAgentNetworkProvider` saves the copy (sql_store_agentnetwork.go:95). The caller's in-memory pointer therefore keeps plaintext `api_key` and any `CreatedAt`/`UpdatedAt` gorm autofills land on the copy, not the original. Callers that need synced timestamps must re-fetch.
|
||||
- `IncrementAgentNetworkConsumption`'s `Create` provides initial counter values (`TokensInput: tokensIn`, etc.) in the row, and on conflict the assignments add the same deltas to the existing values. The insert-vs-update arithmetic is consistent. Cross-check that no engine in use (sqlite, postgres, mysql) silently rejects the `OnConflict` clause — GORM emits engine-specific SQL but `ON DUPLICATE KEY UPDATE` (mysql) vs `ON CONFLICT (...)` (sqlite/postgres) need their unique constraint to match the composite PK on `agent_network_consumption`; it does, by construction.
|
||||
- `IncrementAgentNetworkConsumption` writes `updated_at: time.Now().UTC()` literally inside the assignments map (sql_store_agentnetwork.go:333) — fine, but it's a Go-side timestamp captured at call time, not a DB-side `now()`. Acceptable for an audit field.
|
||||
- `GetAgentNetworkConsumption` returns a zero-valued non-nil row on `ErrRecordNotFound` (sql_store_agentnetwork.go:364-371). Document or rename — a typed sentinel error would be more orthodox; callers must know not to error-check.
|
||||
|
||||
### Concurrency / transactions
|
||||
- Hot-path `IncrementAgentNetworkConsumption` runs outside any explicit transaction; concurrency safety relies entirely on the DB serialising the `ON CONFLICT` upsert against the composite PK. This is correct for postgres and mysql; for sqlite it serialises behind the single writer.
|
||||
- `SaveAgentNetworkSettings` is a blind upsert with no version/etag — concurrent writes from two operators last-write-wins on the collection-toggle flags (settings.go:23-25). Acceptable for admin-curated state but worth flagging.
|
||||
- `Save*Provider` uses `db.Save` on a struct with a PK already set — GORM emits UPDATE or INSERT based on row existence. No upsert clause is attached, so a race between two creates with the same generated `xid` (vanishingly unlikely) would surface as a PK violation.
|
||||
|
||||
### Migration safety
|
||||
- All six tables ride `AutoMigrate` (sql_store.go:141-142). AutoMigrate is additive: new columns get added, but it never drops columns nor narrows types. Three `bool` columns on `agent_network_settings` (`EnableLogCollection`, `EnablePromptCollection`, `RedactPii`) default to false at the GORM/DDL layer for existing rows; the test at sql_store_agentnetwork_budgetrule_test.go:83-112 locks that down on a fresh sqlite. Verify postgres/mysql produce the same default.
|
||||
- The named index `idx_agent_network_settings_cluster_subdomain` on settings.go:15 is declared on only `subdomain`. Either the cluster column also needs `gorm:"index:idx_agent_network_settings_cluster_subdomain"` to make it composite, or the name is misleading.
|
||||
- The named index `idx_agent_network_provider` on `Provider.ProviderID` (provider.go:30) is *not* unique and not scoped to account — two providers in the same account with the same `provider_id` are permitted at the DB layer; uniqueness, if any, must live above the store.
|
||||
|
||||
### Backward compatibility
|
||||
- Net additive. No removed methods, no renamed columns, no schema change to existing tables. Existing deployments running a prior binary continue to work; the first boot of the new binary creates the six tables.
|
||||
- The `Store` interface grows by 23 methods (store.go:330-354); any non-mock external implementer of `store.Store` will fail to compile. The repo only has `SqlStore` + `MockStore`, both updated.
|
||||
|
||||
### Performance (indexes, N+1)
|
||||
- All by-account list queries hit the `idx_account_id` per-table index. No N+1: list methods return the full slice in one query.
|
||||
- `GetAgentNetworkSettingsByCluster` (sql_store_agentnetwork.go:263-277) does a tablescan on `cluster` — no index. Tolerable for the bootstrap label generator (one-shot at provisioning) but worth noting if the call moves onto a hot path.
|
||||
- `ListAgentNetworkConsumption` returns every row ever recorded for the account (sql_store_agentnetwork.go:382-400) — unbounded growth, no `LIMIT`, no time filter. With one row per (dim, window) per request burst, this table grows fastest of the six; a retention job + a paginated list method are obvious follow-ups.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkBudgetRule_RealStore_RoundTrip` | full save → reload of `AccountBudgetRule` including the JSON-serialised `PolicyLimits`, target slices, double-delete returns NotFound (lines 18-59) |
|
||||
| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkBudgetRule_RealStore_ScopedByAccount` | cross-account isolation for budget rules (lines 63-78) |
|
||||
| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkSettings_RealStore_CollectionTogglesRoundTrip` | collection toggles default off, survive save/reload at the set values (lines 83-112) |
|
||||
|
||||
Gap: there is no store-level test for providers (encryption round-trip), policies, guardrails, or `IncrementAgentNetworkConsumption` (concurrent upsert, window-key uniqueness). The consumption upsert is the most performance-sensitive method in this module and the only one without a real-sqlite test.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- No retention / GC for `agent_network_consumption`.
|
||||
- No `Delete` for `Settings` (one row per account, cleared with the account).
|
||||
- No DB-engine-specific tuning — the same struct tags drive sqlite, mysql, postgres.
|
||||
- Provider `extra_values` and `models` are JSON blobs; querying inside them is not supported by design.
|
||||
- `GetAgentNetworkConsumption` "not-found = zero row" contract is convenient but unconventional.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Upstream: [shared/api](10-shared-api.md), [management/agentnetwork](21-management-agentnetwork.md)
|
||||
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level: [../00-overview.md](../00-overview.md)
|
||||
225
docs/agent-networks/modules/21-management-agentnetwork.md
Normal file
225
docs/agent-networks/modules/21-management-agentnetwork.md
Normal file
@@ -0,0 +1,225 @@
|
||||
# management/agentnetwork — domain layer + synth pipeline
|
||||
|
||||
> **Risk level:** High — central business logic + budget enforcement + the source of every middleware-chain change the proxy executes.
|
||||
> **Backward-compat impact:** Additive within the agent-network surface; one **behavioural difference for opted-out accounts** in parser capture (the capture flag is stamped explicitly false instead of being absent — see capture-pointer semantics below). Non-agent-network proxy services are untouched (the synth chain only ships on `agent-net-svc-*` targets).
|
||||
|
||||
## Module boundary
|
||||
|
||||
`management/server/agentnetwork` owns every agent-network entity (providers, policies, guardrails, account budget rules, per-account settings, consumption rows) and **translates them into the in-memory `*rpservice.Service` that the reverse-proxy controller turns into `proto.ProxyMapping`s and pushes to clusters**. It is the *only* writer of the agent-network middleware chain.
|
||||
|
||||
Inside the package: `manager.go` is the CRUD + permissions-gated facade; `synthesizer.go` walks settings + providers + policies + guardrails and emits the per-account service plus every middleware's JSON config; `policyselect.go` runs per-request attribution (min-wins account ceiling, then "drain bigger pool first"); `reconcile.go` diffs successive synth outputs and emits precise Create/Update/Delete proxy-mapping updates plus a peer-map refresh. `labelgen/` mints DNS-safe subdomain labels; `catalog/` is the static provider catalogue; `types/` carries gorm entity structs. The `_realstack_test.go` files in the parent `management/server/` directory exercise the manager + network-map controller end-to-end with no mocks.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `agentnetwork/manager.go` | Manager interface + CRUD + permission gates + bootstrap-settings + reconcile trigger |
|
||||
| `agentnetwork/synthesizer.go` | Settings/policy → wire-format synthesis; sole writer of the proxy middleware chain |
|
||||
| `agentnetwork/policyselect.go` | Per-request policy attribution + account-budget ceiling (min-wins) |
|
||||
| `agentnetwork/reconcile.go` | Per-account synth diff vs in-memory cache → Create/Update/Delete |
|
||||
| `agentnetwork/catalog/catalog.go` | Static provider catalogue (auth headers, identity-injection shapes) |
|
||||
| `agentnetwork/labelgen/{labelgen,words}.go` | DNS-safe subdomain picker + curated wordlist |
|
||||
| `agentnetwork/types/provider.go` | Provider entity + APIKey + Models + ExtraValues + SessionKeys |
|
||||
| `agentnetwork/types/policy.go` | Policy entity + `PolicyLimits` (token + budget) |
|
||||
| `agentnetwork/types/guardrail.go` | Guardrail entity (`ModelAllowlist`, `PromptCapture`) |
|
||||
| `agentnetwork/types/budgetrule.go` | `AccountBudgetRule` (reuses `PolicyLimits`) |
|
||||
| `agentnetwork/types/settings.go` | Per-account `Settings` (Cluster, Subdomain, 3 toggles) |
|
||||
| `agentnetwork/types/consumption.go` | `Consumption` row + `WindowStart` aligner |
|
||||
| `agentnetwork/{synthesizer,policyselect,reconcile,wire_shape}_*test.go` | See test coverage table |
|
||||
| `agentnetwork/types/consumption_test.go` | `WindowStart` alignment proofs |
|
||||
| `agentnetwork/labelgen/labelgen_test.go` | Deterministic picks + exhaustion + fallback |
|
||||
| `management/server/agentnetwork_realstack_test.go` | No-mock provider CRUD → network-map fan-out |
|
||||
| `management/server/agentnetwork_budgetrule_realstack_test.go` | No-mock budget-rule CRUD + settings preserve-immutable |
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### Synthesis (settings/policy → wire format)
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Mutation: provider/policy/guardrail/settings] --> B[managerImpl.reconcile accountID]
|
||||
B --> C{proxyController nil?}
|
||||
C -- yes --> D[accountManager.UpdateAccountPeers only]
|
||||
C -- no --> E[SynthesizeServices]
|
||||
E --> F[loadSettings — NotFound returns ok=false, no synth]
|
||||
F --> G[filterEnabledProviders sorted by CreatedAt]
|
||||
G --> H[filterEnabledPolicies]
|
||||
H --> I[backfillProviderSessionKeys if missing]
|
||||
I --> J[indexProviderGroups: providerID -> sorted source groups]
|
||||
J --> K[buildRouterConfigJSON drops orphan providers]
|
||||
J --> L[buildIdentityInjectConfigJSON per catalog entry]
|
||||
H --> M[mergeGuardrails: union allowlist, OR redact]
|
||||
M --> N[applyAccountCollectionControls account toggle = SOLE capture control]
|
||||
N --> O[marshalGuardrailConfig]
|
||||
K --> P[buildMiddlewareChain 8 middleware entries]
|
||||
L --> P
|
||||
O --> P
|
||||
P --> Q[buildAccountService: AccessGroups=union source groups, noop.invalid target]
|
||||
Q --> R[reconcile.diffMappings vs cache]
|
||||
R --> S[SendServiceUpdateToCluster CREATE/MODIFY/REMOVE]
|
||||
R --> T[accountManager.UpdateAccountPeers — fans synth ACLs into network map]
|
||||
```
|
||||
|
||||
### Budget rule resolution (min-wins, group+user bound)
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[SelectPolicyForRequest in] --> B[checkAccountBudget — runs FIRST, independent of policies]
|
||||
B --> C[GetAccountAgentNetworkBudgetRules]
|
||||
C --> D{for each enabled rule}
|
||||
D --> E{budgetRuleApplies?}
|
||||
E -- no --> D
|
||||
E -- yes --> F[attrGroup = lowestIntersect TargetGroups, in.GroupIDs]
|
||||
F --> G{Token cap enabled?}
|
||||
G -- yes --> H[evalTokenCap user dim + group dim]
|
||||
H --> I{exhausted?}
|
||||
I -- yes --> J[DENY: llm_account.token_cap_exceeded - STOP]
|
||||
I -- no --> K{Budget cap enabled?}
|
||||
G -- no --> K
|
||||
K -- yes --> L[evalBudgetCap user dim + group dim]
|
||||
L --> M{exhausted?}
|
||||
M -- yes --> N[DENY: llm_account.budget_cap_exceeded - STOP]
|
||||
M -- no --> D
|
||||
K -- no --> D
|
||||
D --> O[All rules passed -> fall through to per-policy selection]
|
||||
```
|
||||
|
||||
Key invariant: **rules are checked sequentially and ANY exhausted rule denies (all-must-pass / min-wins).** Untargeted rules (`len(TargetGroups)==0 && len(TargetUsers)==0`) apply to every caller (`policyselect.go:393`).
|
||||
|
||||
### Policy selection (per-peer, per-request)
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Account-budget gate passed] --> B[GetAccountAgentNetworkPolicies]
|
||||
B --> C[filterApplicablePolicies enabled + provider match + group intersect]
|
||||
C --> D{candidates empty?}
|
||||
D -- yes --> E[Allow, empty SelectedPolicyID]
|
||||
D -- no --> F[scoreCandidates -> scoreOne per policy]
|
||||
F --> G[scoreOne: attrGroup + window]
|
||||
G --> H{any cap exhausted?}
|
||||
H -- yes --> I[Drop policy; record last deny code]
|
||||
H -- no --> K[Keep as live candidate]
|
||||
F --> L{live candidates exist?}
|
||||
L -- no --> M[Deny with last exhaustion code]
|
||||
L -- yes --> N[Sort: uncapped wins -> larger group token -> group budget -> user token -> user budget -> oldest CreatedAt]
|
||||
N --> O[winner = scored 0]
|
||||
O --> P[Allow + SelectedPolicyID + AttributionGroupID + WindowSeconds]
|
||||
```
|
||||
|
||||
End-to-end: a mutation calls `managerImpl.reconcile(ctx, accountID)` (`manager.go:205,239,...`). Reconcile defers an `accountManager.UpdateAccountPeers` so the network-map controller re-runs and `injectAllProxyPolicies` picks up the new access groups; with a `proxyController` wired, it re-synthesizes the service, diffs against `reconcileCache[accountID]` (guarded by `reconcileMu`), and emits proto mappings to the cluster derived from the mapping's domain (`reconcile.go:120`). Synthesis is stateless and idempotent. Sole persistent side effect: `backfillProviderSessionKeys` (`synthesizer.go:249`) mints ed25519 keys on legacy provider rows and writes them back.
|
||||
|
||||
At request time the path is independent: the proxy calls `SelectPolicyForRequest` (`policyselect.go:56`); account-budget ceiling first, then per-policy scoring. Token + budget caps share `evalTokenCap` / `evalBudgetCap` — same primitive for account rules and policy limits, `label` differentiates the deny reason. After a served request, `RecordAccountBudgetUsage` (`policyselect.go:415`) fans deltas to every applicable rule's distinct `(dim_kind, dim_id, window)` tuple, deduplicating to prevent double-count when two rules share target+window.
|
||||
|
||||
## Public contracts
|
||||
|
||||
- **Manager interface** (`manager.go:48-80`): CRUD for `Providers/Policies/Guardrails/BudgetRules`; `GetSettings/UpdateSettings` (cluster + subdomain immutable, only the three toggles mutate); `ListConsumption/RecordConsumption(account, kind, dimID, windowSec, in, out, USD)`; `RecordAccountBudgetUsage(account, user, groups, in, out, USD)`; `SelectPolicyForRequest(ctx, PolicySelectionInput) → *PolicySelectionResult{Allow, SelectedPolicyID, AttributionGroupID, WindowSeconds, DenyCode, DenyReason}`.
|
||||
- **`PolicySelectionInput`** (`manager.go:85-90`): `{AccountID, UserID, GroupIDs, ProviderID}` — populated by the proxy from CapturedData + `llm_router` resolution.
|
||||
- **Synthesized middleware chain** (`synthesizer.go:576-657`), order load-bearing — response slot runs reverse-of-slice:
|
||||
|
||||
| Slot | Idx | ID | ConfigJSON shape | CanMutate |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| on_request | 0 | `llm_request_parser` | `{"capture_prompt": <bool>, "redact_pii"?: true}` | – |
|
||||
| on_request | 1 | `llm_router` | `{"providers":[{id, models[], upstream_*, auth_header_*, allowed_group_ids[]}]}` | **true** |
|
||||
| on_request | 2 | `llm_limit_check` | `{}` | – |
|
||||
| on_request | 3 | `llm_identity_inject` | `{"providers":[{provider_id, header_pair?, json_metadata?, extra_headers?}]}` | **true** |
|
||||
| on_request | 4 | `llm_guardrail` | `{"model_allowlist"?, "prompt_capture":{enabled,redact_pii}}` | – |
|
||||
| on_response | 5 | `llm_limit_record` | `{}` (runs LAST at runtime) | – |
|
||||
| on_response | 6 | `cost_meter` | `{}` | – |
|
||||
| on_response | 7 | `llm_response_parser` | `{"capture_completion": <bool>, "redact_pii"?: true}` | – |
|
||||
- **Synthesized service shape** (`synthesizer.go:739`): `Mode=HTTP`, `Private=true`, `Domain=<subdomain>.<cluster>`, `AccessGroups=unionSourceGroups(enabledPolicies)`, one `TargetTypeCluster` target with `Host=noop.invalid:443` (router rewrites per request), `Options.{DirectUpstream,AgentNetwork}=true`, `DisableAccessLog=!settings.EnableLogCollection`, `CaptureMax{Req,Resp}Bytes=1<<20`, `CaptureContentTypes=["application/json","text/event-stream"]`.
|
||||
|
||||
## Invariants
|
||||
|
||||
- **Min-wins / all-must-pass for account budget rules** (`checkAccountBudget`, `policyselect.go:353`): every applicable enabled rule is checked; first exhausted cap denies. Untargeted rules bind every caller.
|
||||
- **Account toggle is the SOLE control for capture enablement.** `applyAccountCollectionControls` (`synthesizer.go:701`) sets `merged.PromptCapture.Enabled = settings.EnablePromptCollection` *unconditionally*.
|
||||
- **Capture-pointer semantics on parser configs** — see "Things to scrutinize" below.
|
||||
- **`EnableLogCollection` ↔ `DisableAccessLog` is the only access-log toggle** (`synthesizer.go:770`). Default off ⇒ access log suppressed.
|
||||
- **`RedactPii` flows verbatim to BOTH parsers** (`synthesizer.go:584-585`) and is OR'd into the merged guardrail (`synthesizer.go:706`).
|
||||
- **Cluster and Subdomain are immutable on Settings.** `UpdateSettings` reloads existing row and overlays only the three toggles (`manager.go:558-561`).
|
||||
- **Orphan providers (no enabled policy authorises them) NEVER reach the router** (`synthesizer.go:351-357`); skipped from `identity_inject` for symmetry.
|
||||
- **Provider creation refuses empty `api_key`** (`manager.go:175`); **deletion refuses while any policy still references it** (`manager.go:265-273`).
|
||||
- **Session keypair stability across provider edits** (`manager.go:226-228`) — server-managed, copied through every `UpdateProvider`, never API-surfaced.
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
|
||||
- **Capture-pointer semantics — `*bool` vs `bool`.** Three states, owned by separate sides:
|
||||
- **Wire JSON this module emits:** `buildParserConfigJSON` (`synthesizer.go:678-693`) *always* stamps the capture field. Agent-network targets ship `"capture_prompt": false` or `"capture_prompt": true` — never absent. Same for `"capture_completion"`. The happy-path test pins `{"capture_prompt":false}` (`synthesizer_test.go:174`).
|
||||
- **Proxy-side parser config (consumer):** parsers decode into `*bool`. Matrix:
|
||||
- `nil` (field absent) → **legacy default = emit**. Preserved for non-agent-network callers and pre-existing tests (the backward-compat hook).
|
||||
- `false` (field present, value false) → **suppress emission entirely**. The behaviour for opted-out agent-network accounts. Without this, `enable_log_collection=true` + `enable_prompt_collection=false` would leak raw user input AND raw model output to the access log.
|
||||
- `true` → emit normally.
|
||||
- **Why the synth always stamps a value:** an agent-network mapping omitting the field would hit legacy "always emit" and re-introduce the leak. The `json.Marshal` error fallback at `synthesizer.go:687` degrades to `{}` — comment-claimed unreachable, but if ever fired re-introduces the leak. Consider fail-closed (return literal `{"capture_prompt":false}`) instead.
|
||||
- **`scoreCandidates` non-cumulative deny code.** Only the *last* exhausted policy's deny code survives (`policyselect.go:188-190`). Iteration order is store's natural order. Auth signal is `len(scored)==0`, so this is informational only — verify no UI depends on "first exhausted policy" semantics.
|
||||
- **`effectiveWindowSeconds` token-wins tiebreak.** When both halves are enabled with different windows, token's window wins (`policyselect.go:482`). Verify `RecordLLMUsage` increments against the winning window only.
|
||||
- **`RecordAccountBudgetUsage` dedup.** Two rules with the same `(kind, dim_id, window)` would double-count without the `tuples` map (`policyselect.go:434-449`). Key includes all three dimensions — correct.
|
||||
- **Fail-closed on bad provider:** unknown catalog id (`synthesizer.go:794-796`) or empty API key (`synthesizer.go:801-803`) drops the **entire** account's synth, not just the bad provider. Confirm matches operator UX.
|
||||
|
||||
### Security
|
||||
|
||||
- **Redact OR-merge:** merged `RedactPii` = account OR guardrail (`synthesizer.go:706`). **Parser-side flag is `settings.RedactPii` only, NOT the OR** — a guardrail-only opt-in does not propagate to parsers. Correct because the account toggle gates capture, but worth noting on the proxy side.
|
||||
- **Group resolution must not leak across accounts.** Every store call carries `accountID` (`policyselect.go:73, 286, 298, 322, 334, 354`); `lowestIntersect` uses caller's claimed groups only (`policyselect.go:494`). Risk surface is upstream (handler populates `in.GroupIDs`).
|
||||
- **`UpdateSettings` preserves immutable Cluster + Subdomain** (`manager.go:558`). A client can't rebind the cluster.
|
||||
- **Provider session keypair backfill writes through `SaveAgentNetworkProvider`** (`synthesizer.go:256`) from a read-shaped call. Idempotent → worst case is a wasted write under concurrent reconcile + snapshot.
|
||||
|
||||
### Concurrency
|
||||
|
||||
- **`reconcileMu`** guards `reconcileCache`. Lock window is narrow — compute diff inside, send outside (`reconcile.go:56-68`).
|
||||
- **`labelRngMu`** guards `labelRng` because `math/rand.Source` is unsafe for concurrent use (`manager.go:638-640`).
|
||||
- **Real-store tests** use `store.NewTestStoreFromSQL` with `t.TempDir()` per test — no shared state, no `t.Parallel()`.
|
||||
- **`RecordAccountBudgetUsage` dedup `tuples` map is per-call;** concurrent calls fan out fully — correct (each request's tokens book once per applicable rule).
|
||||
- **Deferred `UpdateAccountPeers` runs inline after the proxy push** (`reconcile.go:28-35`); a slow call stretches CRUD response time.
|
||||
|
||||
### Backward compatibility
|
||||
|
||||
- **Capture-pointer semantics (restated):** non-agent-network callers see no field → legacy nil-default emit, identical to pre-PR. Agent-network targets always carry an explicit `capture_*` value.
|
||||
- **`TestSynthesizeServices_HappyPath` was updated:** request-parser config moved from `{}` to `{"capture_prompt":false}` (`synthesizer_test.go:174`). External snapshot tests against synth output need updating.
|
||||
- **`MergedGuardrails` retains zeroed `TokenLimits`/`Budget`/`Retention`** even though `Policy.Limits` carries the real values now; `llm_limit_check` is the authoritative enforcement. Comment at `synthesizer.go:940-948` calls this out.
|
||||
|
||||
### Performance
|
||||
|
||||
- **`SynthesizeServices` runs on every controller tick / mutation reconcile.** Cost: 4 store reads + optional per-provider keypair backfill. Sort + index + merge are O(N log N) / O(P × G); dominant cost is JSON marshalling. No nested loops escape these dimensions.
|
||||
- **`reconcile.diffMappings` is O(N + M)** with N=M=1 per account today — effectively constant.
|
||||
- **`SynthesizeServicesForCluster`** (`synthesizer.go:71`) walks every account on a cluster; per-account failures are **swallowed** (`synthesizer.go:91-93`) so a single misconfigured account doesn't drop the cluster. Runs per proxy reconnect.
|
||||
|
||||
### Observability
|
||||
|
||||
- **Activity codes:** `AgentNetwork{Provider,Policy,Guardrail,BudgetRule}{Created,Updated,Deleted}`; `AgentNetworkSettingsUpdated` with `log_collection/prompt_collection/redact_pii` payload (`manager.go:567-571`). **No activity code for `SelectPolicyForRequest` denies** — surfaced via proxy access log only (likely intentional given volume).
|
||||
- **Deny codes** namespaced: `llm_policy.{token,budget}_cap_exceeded`, `llm_account.{token,budget}_cap_exceeded` (`policyselect.go:18-26`).
|
||||
- **Reconcile failures are logged at warn and swallowed** (`reconcile.go:42-44`). Persistent synth failures (e.g. unknown catalog id) silently keep the proxy out of sync — consider a manager-level synth-health surface if this becomes a support burden.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| `synthesizer_test.go` | Mock-store: `HappyPath` (8-mw chain ordering, `{"capture_prompt":false}` baseline); `No{Settings,Providers}`; `Disabled{Provider,Policy}_NoService`; `RouterConfigOrdering`; `PolicyCheckConfig_UnionsSourceGroups`; `OrphanProvider_HasEmptyAllowedGroups`; identity-inject for LiteLLM / Bifrost (overrides + partial disable) / Cloudflare / Portkey / Vercel / OpenRouter / generic non-customizable; `GuardrailMerge_AllowlistUnion_LimitsRestrictive`; `BackfillsMissingSessionKeys`; `HTTPUpstream_KeepsExplicitPort`; `UpstreamURLPath_FlowsToRouter`; `UnknownProviderID_FailsClosed`; `EmptyAPIKey_FailsClosed`. |
|
||||
| `synthesizer_realstore_test.go` | Real-sqlite: `SurvivesStatusToggle` reproduces the disable/re-enable 403 regression; `Reconcile_RealStore_PushesPrivateAfterStatusToggle` extends through reconcile push. |
|
||||
| `synthesizer_guardrail_realstore_test.go` | `PromptCaptureAccountIsSoleControl`; `PromptCaptureFlowsWhenAccountOptsIn`; `AccountRedactWithoutGuardrailRedact`; `NoGuardrail_CaptureOff`. |
|
||||
| `synthesizer_log_collection_realstore_test.go` | `LogCollection{Off_SuppressesAccessLog,On_PermitsAccessLog}` — verifies `DisableAccessLog` propagation through `ToProtoMapping`. |
|
||||
| `synthesizer_parser_redact_realstore_test.go` | **Capture-pointer regression suite:** `ParserConfigsCarryRedactPii`; `ParserConfigsSuppressCaptureWhenLogCollectionOnly` (log=on/prompt=off ⇒ both capture flags false); `ParserConfigsOmitRedactPiiWhenOff`. |
|
||||
| `policyselect_test.go` | Mock-store: `NoApplicablePolicies`; `AllowWithLowestGroupAttribution`; `LargerPoolWinsAcrossUsageLevels`; `StaysOnLargerPoolAfterPartialDrain`; `FallsThroughToSmallerPoolWhenLargerExhausted`; `TiebreakBy{LargerGroupPool,CreatedAt}`; `DeniesWhenAllExhausted`; `UncappedPolicyAlwaysWinsAgainstCapped`; `DisabledPolicyIgnored`; `StoreErrorPropagates`; `RejectsEmptyAccount`; `SharesGroupCounterAcrossPolicies`; `AntiFallThroughOnLowestGroup`; `BudgetOnlyExhaustionDenies`; `BudgetTighterThanTokenWins`. |
|
||||
| `policyselect_realstore_test.go` | Real-sqlite regression guard: `NoApplicablePolicies`; `AllowAndLowestGroupAttribution`; `LargerPoolWins_FallsThroughWhenExhausted`; `BudgetCapDenies`; `GroupCounterSharedAcrossPolicies`; `DisabledPolicyIgnored`. |
|
||||
| `policyselect_account_realstore_test.go` | Account budget rules: `AccountCeilingBindsEvenWithUncappedPolicy` (min-wins); `AccountGroupCeiling`; `AccountTargetUsersBindsOnlyThatUser`; `AccountRuleRecordsToOwnWindow`. |
|
||||
| `reconcile_test.go` | `FirstSynth_EmitsCreate`; `NoChange_EmitsNothingExtra` (re-push as Modified — verify desired); `PolicyRemoved_EmitsDelete`; `NilProxyController_NoOp`; `EmptyAccountID_NoOp`; `ClusterFromMapping`. |
|
||||
| `wire_shape_test.go` | `TestSynthesizedService_WireShape` — proto-shape lockdown via `ToProtoMapping`. Catches "service not matching" (mapping reaches proxy but no SNI/HTTP route). Asserts ID, Domain, Mode, AuthToken, `Private`, `Auth.Oidc=false`, one path `/` + `https://noop.invalid/`, 8 middlewares with correct slot enums, router config `auth_header_value="Bearer sk-test-key"`. |
|
||||
| `labelgen/labelgen_test.go` | `PickUnique_{DeterministicWithSeededRng,AvoidsTakenWordsWhenMostAreReserved,FallsBackWhenAllReserved}`; `UniqueWords_DropsDuplicates`. |
|
||||
| `types/consumption_test.go` | `WindowStart_{AlignedToUnixEpoch,WithinWindowConverges,AcrossWindowsDiverges,DifferentWindowsHaveDifferentBuckets,SubMinuteAndMinuteAlignment,ZeroWindowReturnsInputUTC}`. Bucket alignment so multi-node reads converge. |
|
||||
| `agentnetwork_realstack_test.go` | `ProviderCRUD_FansOutToProxyAndClientPeers` — no-mock end-to-end through real account manager + network-map + agentnetwork: provider create propagates the updated map to both proxy peer and client peer with the synth DNS surface. |
|
||||
| `agentnetwork_budgetrule_realstack_test.go` | `BudgetRuleCRUD_RealManager`; `UpdateSettings_PreservesImmutableAndTogglesCollection`. |
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- **`MergedGuardrails.TokenLimits/Budget/Retention` emit at zero** (`synthesizer.go:940-948`); real enforcement is `Policy.Limits` via `llm_limit_check`. Future cleanup implied.
|
||||
- **Session keys picked from first enabled provider by created_at** (`pickServiceSessionKeys`, `synthesizer.go:270`). Existing session cookies survive provider edits only while the first-by-CreatedAt provider stays in place. Document for operators.
|
||||
- **Reconcile failures silently swallowed** (`reconcile.go:42-44`). Persistent failures keep the proxy out of sync until the next reconcile.
|
||||
- **`scoreCandidates` exposes only the LAST exhaustion's deny code** when multiple policies are exhausted.
|
||||
- **`bootstrapSettingsIfNeeded` failure is non-fatal to provider create** (`manager.go:200`): provider lands, synth is no-op until the next provider create retries the bootstrap.
|
||||
- **Budget rules do not trigger a reconcile** (`manager.go:476-477`). Request-time evaluation only; new rules take effect on the next request without a proxy push.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- **Upstream:** [shared/api](10-shared-api.md), [management/store](20-management-store.md), reverseproxy `service`/`proxy`/`sessionkey` packages, `management/server/permissions` + `activity`.
|
||||
- **Downstream:** [management/handlers (HTTP wiring)](22-management-handlers-wiring.md), [proxy/middleware-builtin](31-proxy-middleware-builtin.md), network-map controller (`injectAllProxyPolicies` fan-out).
|
||||
- **End-to-end flow:** [../01-end-to-end-flows.md](../01-end-to-end-flows.md) — "Provider create → reconcile → proxy push → peer map refresh" and "request → policy select → record" diagrams.
|
||||
- **Top-level:** [../00-overview.md](../00-overview.md)
|
||||
203
docs/agent-networks/modules/22-management-handlers-wiring.md
Normal file
203
docs/agent-networks/modules/22-management-handlers-wiring.md
Normal file
@@ -0,0 +1,203 @@
|
||||
# management/handlers + wiring — HTTP API + gRPC delivery
|
||||
|
||||
> **Risk level:** Medium — the surface is mostly additive, but two changes are load-bearing: `injectAllProxyPolicies` runs on every per-peer compute, and `shallowCloneMapping` must round-trip `Private` (a missed field silently breaks every MODIFIED).
|
||||
> **Backward-compat impact:** Additive on the wire (new routes, new RPCs, new proto fields, new gorm column on `AccessLogEntry`). One management-internal break: `nbhttp.NewAPIHandler` gains a trailing `agentNetworkManager` parameter; `nil` is tolerated and silently skips route registration.
|
||||
|
||||
## Module boundary
|
||||
|
||||
This module is the seam between the public Agent Network HTTP API and the proxy fleet that serves agent traffic. North side: a `/api/agent-network/*` surface (providers, policies, guardrails, budget rules, settings, consumption) on the existing gorilla router, delegating to `agentnetwork.Manager`. Handlers are thin — they translate `api.*` ↔ `types.*`, validate shape, forward. RBAC and event emission stay inside the manager (`manager.go:680-682`).
|
||||
|
||||
South side: `ProxyServiceServer` (`proxy.go`) learns to (a) ship synth services to a proxy on initial snapshot, (b) resolve agent-network domains in `getServiceByDomain` for OIDC/session/tunnel-peer flows, (c) gate LLM requests via `CheckLLMPolicyLimits` + `RecordLLMUsage`, (d) preserve `Private` through `shallowCloneMapping` so per-proxy live updates don't silently flip services public. The network_map controller prepends synth services to `account.Services` on every per-peer compute; `accesslogentry.go` gains an indexed `AgentNetwork` column so the dashboard can filter cheaply.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `handlers/agentnetwork/providers_handler.go` | Catalog + provider CRUD + central `AddEndpoints` |
|
||||
| `handlers/agentnetwork/policies_handler.go` | Policy CRUD + shared `validatePolicy*` |
|
||||
| `handlers/agentnetwork/guardrails_handler.go` | Guardrail CRUD |
|
||||
| `handlers/agentnetwork/budget_handler.go` | Account-level budget rule CRUD |
|
||||
| `handlers/agentnetwork/settings_handler.go` | GET (200+`null` if unbootstrapped) + PUT toggles |
|
||||
| `handlers/agentnetwork/consumption_handler.go` | Read-only consumption rows |
|
||||
| `handlers/agentnetwork/handlers_test.go` | Real-store fixture; wire round-trip + validation |
|
||||
| `handlers/agentnetwork/budget_handler_test.go` | Budget-rule + settings toggles |
|
||||
| `server/http/handler.go` | New `agentNetworkManager` arg; conditional `AddEndpoints` |
|
||||
| `server/permissions/modules/module.go` | New `AgentNetwork` module key |
|
||||
| `internals/server/boot.go` | Wires synthesiser adapter + limits service into proxy server |
|
||||
| `internals/server/modules.go` | `AgentNetworkManager()` lazy-create node |
|
||||
| `internals/controllers/network_map/controller/controller.go` | `injectAllProxyPolicies` replaces 4 `InjectProxyPolicies` calls |
|
||||
| `internals/controllers/network_map/controller/repository.go` | `SynthesizeAgentNetworkServices` repo method |
|
||||
| `internals/modules/reverseproxy/service/service.go` | `MiddlewareConfig`, capture limits, `AgentNetwork`, `DisableAccessLog` + proto |
|
||||
| `internals/modules/reverseproxy/accesslogs/accesslogentry.go` | Indexed `AgentNetwork bool` from proto |
|
||||
| `internals/shared/grpc/proxy.go` | Synth wiring, 2 RPCs, domain fallback, `Private` in clone |
|
||||
| `internals/shared/grpc/proxy_clone_test.go` | Locks every `ProxyMapping` field minus `AuthToken` |
|
||||
| `server/activity/codes.go` | 13 new activity codes (125-137) |
|
||||
|
||||
## HTTP routes added
|
||||
|
||||
All routes inherit the platform's auth middleware. Perms enforced inside `agentnetwork.Manager.requirePermission` (`manager.go:680-682`) on `modules.AgentNetwork`. Permission column shows the `op` passed to `requirePermission` — read = `Read`, etc.
|
||||
|
||||
| Method | Path | Perm | Handler |
|
||||
| ------ | ---- | ---- | ------- |
|
||||
| GET | `/agent-network/catalog/providers` | authn only | `providers_handler.go:43` |
|
||||
| GET | `/agent-network/providers` | read | `providers_handler.go:57` |
|
||||
| POST | `/agent-network/providers` | create | `providers_handler.go:97` |
|
||||
| GET | `/agent-network/providers/{providerId}` | read | `providers_handler.go:77` |
|
||||
| PUT | `/agent-network/providers/{providerId}` | update | `providers_handler.go:132` |
|
||||
| DELETE | `/agent-network/providers/{providerId}` | delete | `providers_handler.go:172` |
|
||||
| GET | `/agent-network/policies` | read | `policies_handler.go:32` |
|
||||
| POST | `/agent-network/policies` | create | `policies_handler.go:72` |
|
||||
| GET | `/agent-network/policies/{policyId}` | read | `policies_handler.go:52` |
|
||||
| PUT | `/agent-network/policies/{policyId}` | update | `policies_handler.go:102` |
|
||||
| DELETE | `/agent-network/policies/{policyId}` | delete | `policies_handler.go:142` |
|
||||
| GET | `/agent-network/guardrails` | read | `guardrails_handler.go:25` |
|
||||
| POST | `/agent-network/guardrails` | create | `guardrails_handler.go:65` |
|
||||
| GET | `/agent-network/guardrails/{guardrailId}` | read | `guardrails_handler.go:45` |
|
||||
| PUT | `/agent-network/guardrails/{guardrailId}` | update | `guardrails_handler.go:95` |
|
||||
| DELETE | `/agent-network/guardrails/{guardrailId}` | delete | `guardrails_handler.go:135` |
|
||||
| GET | `/agent-network/budget-rules` | read | `budget_handler.go:24` |
|
||||
| POST | `/agent-network/budget-rules` | create | `budget_handler.go:64` |
|
||||
| GET | `/agent-network/budget-rules/{ruleId}` | read | `budget_handler.go:44` |
|
||||
| PUT | `/agent-network/budget-rules/{ruleId}` | update | `budget_handler.go:95` |
|
||||
| DELETE | `/agent-network/budget-rules/{ruleId}` | delete | `budget_handler.go:135` |
|
||||
| GET | `/agent-network/settings` | read | `settings_handler.go:53` (200+`null` if no row) |
|
||||
| PUT | `/agent-network/settings` | update | `settings_handler.go:27` |
|
||||
| GET | `/agent-network/consumption` | read | `consumption_handler.go:21` |
|
||||
|
||||
## gRPC RPCs added (or modified)
|
||||
|
||||
| RPC | Direction | Trigger |
|
||||
| --- | --------- | ------- |
|
||||
| `CheckLLMPolicyLimits` | proxy→mgmt unary | Pre-flight gate; returns allow/deny, selected policy, attribution group, window, deny code+reason (`proxy.go:259-301`). `Unimplemented` when limits service is nil. |
|
||||
| `RecordLLMUsage` | proxy→mgmt unary | Post-flight write of tokens+cost against policy-window dimensions + every applicable account budget rule (`proxy.go:303-349`). `window_seconds==0` ⇒ no policy cap, only account fan-out runs. |
|
||||
| `GetMappingUpdate`/`SendServiceUpdate` (stream) | mgmt→proxy | Snapshot (`proxy.go:752-780`) now appends `SynthesizeServicesForCluster`. Live updates use `SendServiceUpdateToCluster` + `shallowCloneMapping`. |
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### HTTP request lifecycle
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant DB as Dashboard
|
||||
participant R as gorilla.Router (/api)
|
||||
participant H as handler (agentnetwork)
|
||||
participant M as agentnetwork.Manager
|
||||
participant S as store.Store
|
||||
participant AM as accountManager (StoreEvent)
|
||||
|
||||
DB->>R: POST /api/agent-network/providers
|
||||
R->>H: createProvider (auth mw sets UserAuth)
|
||||
H->>H: GetUserAuthFromContext + validate(req)
|
||||
H->>M: CreateProvider(userID, provider, bootstrapCluster)
|
||||
M->>M: requirePermission(AgentNetwork, Create)
|
||||
M->>S: SaveAgentNetworkProvider
|
||||
M->>AM: StoreEvent(AgentNetworkProviderCreated)
|
||||
M-->>H: created provider
|
||||
H-->>DB: 200 + api.AgentNetworkProvider JSON
|
||||
```
|
||||
|
||||
### Synth-service delivery via gRPC
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant P as Proxy
|
||||
participant G as ProxyServiceServer
|
||||
participant SM as service.Manager (persisted)
|
||||
participant SA as synthesizerAdapter
|
||||
participant AN as SynthesizeServicesForCluster
|
||||
participant ST as store.Store
|
||||
|
||||
Note over P,G: Initial snapshot
|
||||
P->>G: GetMappingUpdate (stream open)
|
||||
G->>SM: GetServicesForCluster(conn.address)
|
||||
SM-->>G: persisted []*Service
|
||||
G->>SA: SynthesizeServicesForCluster(conn.address)
|
||||
SA->>AN: SynthesizeServicesForCluster(store, clusterAddr)
|
||||
AN->>ST: walk every account; read providers/policies/settings
|
||||
AN-->>SA: in-memory []*Service
|
||||
SA-->>G: []*Service
|
||||
G->>P: response (persisted + synth)
|
||||
|
||||
Note over G,P: Per-request live update
|
||||
G->>G: SendServiceUpdateToCluster(update, clusterAddr)
|
||||
G->>G: shallowCloneMapping(update) %% Private MUST survive
|
||||
G->>P: response with single mapping
|
||||
```
|
||||
|
||||
End-to-end: HTTP write persists rows and emits an activity event; the manager then triggers `proxyController.SendServiceUpdate` so proxies re-render. **The snapshot path is the only one that calls into the synthesiser** — on stream open it pulls persisted services then appends synth services for the cluster. Synth services are never persisted. For OIDC/session/tunnel-peer flows, `getServiceByDomain` falls back to `SynthesizeServicesForCluster(clusterFromDomain(domain))` when persisted lookup misses (`proxy.go:1763-1793`). The network_map contribution is orthogonal: per-peer compute prepends the same synth services to `account.Services` before `InjectProxyPolicies`.
|
||||
|
||||
## Permissions model added
|
||||
|
||||
- `permissions/modules/module.go:22` adds `AgentNetwork Module = "agent_network"`, registered in `All` (`module.go:42`). Standard `operations.{Read,Create,Update,Delete}` matrix.
|
||||
- Handlers don't call `permissionsManager` directly — they extract `UserAuth` and delegate to `agentnetwork.Manager`, which gates every mutation through `requirePermission` (`manager.go:168, 308, 549`, etc.). Confirm your role-set provider has `agent_network` rows for owner/admin/user/billing-admin before merging.
|
||||
- `getCatalogProviders` (`providers_handler.go:43`) intentionally skips RBAC — catalog is global static data.
|
||||
|
||||
## Activity codes added
|
||||
|
||||
`activity/codes.go:244-274` adds Activities 125-137 + string/code mappings (`codes.go:428-444`), following `<domain>.<resource>.<action>` (e.g., `agent_network.provider.create`). Audit-log exporters / SIEM forwarders need to know the new codes.
|
||||
|
||||
## Invariants
|
||||
|
||||
- **Synth services are never persisted.** Snapshot appends after `serviceManager.GetServicesForCluster` (`proxy.go:761-770`); network_map prepends before `InjectProxyPolicies` (`controller.go:117-126`).
|
||||
- **`shallowCloneMapping` must round-trip every `ProxyMapping` field except `AuthToken`** — `proxy_clone_test.go:50-58` enforces via `gproto.Equal`. The bug it guards: a missing `Private` made every MODIFIED arrive `private=false`, the proxy skipped `ValidateTunnelPeer`, `UserGroups` stayed empty, `llm_router` denied `no_authorised_provider`; a restart "fixed" it because the snapshot uses the original mapping.
|
||||
- **Limit-window floor is 60s** (`policies_handler.go:189-220`); enabled cap with both per-group and per-user at zero is rejected. Budget rules reuse the same validator (`budget_handler.go:170`).
|
||||
- **Manager is optional at boot.** `NewAPIHandler` registers routes only when non-nil (`handler.go:129`); `ProxyServiceServer` returns `Unimplemented` from both RPCs when limits service is unwired (`proxy.go:262-265, 306-309`).
|
||||
- **Settings GET on an unbootstrapped account returns 200 + `null`** (`settings_handler.go:65-72`) — not 404.
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
- **`injectAllProxyPolicies` runs on every per-peer compute**: `controller.go:163, 309, 415, 681`. `sendUpdateAccountPeers` is the target of the buffered fan-out — synth runs once per debounced account-update tick **and** once per direct `UpdateAccountPeer`. Cost is O(providers + policies × users-per-group) per account under `LockingStrengthNone`. No per-account synth cache — verify it fits the buffer interval for your largest tenant.
|
||||
- **`clusterFromDomain` strips at the first `.`** (`proxy.go:1784-1792`). A zero-dot domain returns `""` and the synth call walks every account. Confirm no path reaches this with a malformed/internal domain.
|
||||
- **Account-budget `RecordConsumption` fans out even when `window_seconds == 0`** (`proxy.go:341-348`) — intentional. Verify the proxy never sends `RecordLLMUsage` for a request that wasn't actually allowed.
|
||||
|
||||
### Security
|
||||
- Every handler extracts `UserAuth` via `nbcontext.GetUserAuthFromContext` before any work. Routes live behind the standard `/api` mux; bypass list is not extended.
|
||||
- `CheckLLMPolicyLimits` / `RecordLLMUsage` ride the existing **proxy → mgmt** gRPC connection auth. No additional token check inside the RPCs — they trust the connection. Confirm the proxy-side token-verification interceptor in this package gates both.
|
||||
- `RecordLLMUsage` only validates `account_id != ""` (`proxy.go:317-319`). A compromised proxy can attribute cost to any account in its cluster — was already true for prior RPCs but is louder now that data drives denials.
|
||||
|
||||
### Concurrency
|
||||
- `SetAgentNetworkSynthesizer` / `SetAgentNetworkLimitsService` write under `s.mu.Lock`; read paths copy the interface under read lock (`proxy.go:236-247, 260-263, 304-307`). Same pattern as existing `serviceManager`/`proxyController` setters.
|
||||
- Manager writes use `LockingStrengthUpdate`; synth reads use `LockingStrengthNone` — read-after-write via the proxy snapshot can observe a stale view by up to one fan-out tick.
|
||||
- Network_map controller is single-threaded per account; cross-account is parallel.
|
||||
|
||||
### Backward compatibility
|
||||
- `proxy_clone_test.go` is the regression net; any new `ProxyMapping` field must be cloned or explicitly nulled in the test.
|
||||
- `AccessLogEntry` adds indexed `AgentNetwork bool` — implicit AutoMigrate; deploy story must handle table-rewrite cost on high-volume access-log tables.
|
||||
- `TargetOptions` gains seven `omitempty` JSON fields (`service.go:69-94`); on-wire shape stays compatible. `targetOptionsToProto` tests all fields when deciding nil (`service.go:551-556`).
|
||||
- `NewAPIHandler` signature changes — every caller must pass `agentNetworkManager`; `nil` is supported.
|
||||
|
||||
### Observability
|
||||
- 13 new activity codes via `accountManager.StoreEvent` in the manager — confirm dashboard's audit-log UI maps them.
|
||||
- `AccessLogEntry.AgentNetwork` is indexed for the dashboard's agent-network log filter.
|
||||
- New RPCs log at error level on store/selector failures (`proxy.go:284, 327, 332, 348`). Snapshot synth failures degrade to warnings — stream is not aborted (`proxy.go:765`).
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test | Locks down |
|
||||
| ---- | ---------- |
|
||||
| `handlers_test.go::TestPolicyHandler_WindowSecondsRoundTrip` | GET carries `window_seconds`; legacy `window_hours`/`window_days` absent. |
|
||||
| `handlers_test.go::TestPolicyHandler_RejectsSubMinuteWindow` | POST `<60s` returns 4xx. |
|
||||
| `handlers_test.go::TestConsumptionHandler_EmptyAccountReturnsArray` | `/consumption` returns `[]` — never null. |
|
||||
| `handlers_test.go::TestConsumptionHandler_PopulatedAccountListsRows` | RecordConsumption×2 surfaces both with correct tokens/cost/window. |
|
||||
| `budget_handler_test.go::TestBudgetRuleHandler_RoundTrip` | Targets + PolicyLimits shape round-trip. |
|
||||
| `budget_handler_test.go::TestBudgetRuleHandler_ListReturnsArray` | Empty-list shape. |
|
||||
| `budget_handler_test.go::TestBudgetRuleHandler_{RejectsMissingName,RejectsSubMinuteWindow}` | Validation rejections are 4xx. |
|
||||
| `budget_handler_test.go::TestSettingsHandler_GetExposesCollectionToggles` | All four toggles + computed `Endpoint`. |
|
||||
| `proxy_clone_test.go::TestShallowCloneMapping_PreservesAllFieldsExceptAuthToken` | Future-proofs clone; every field round-trips, `AuthToken` dropped. |
|
||||
|
||||
Handler tests use a real sqlite store + real manager + always-allow permissions mock (`handlers_test.go:53-75`). Create/update/delete success paths flow through `accountManager.StoreEvent` which the fixture doesn't wire — covered by manager-level no-mock tests outside this module.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- No pagination on any list endpoint; no bulk endpoints.
|
||||
- Synth result is not cached — every snapshot and every per-peer compute repeats the store walk.
|
||||
- `getSettings` returning `200 + null` is a deliberate dashboard concession.
|
||||
- No rate-limiting beyond the global `/api` rate limiter.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Upstream: [shared/api](10-shared-api.md), [management/agentnetwork](21-management-agentnetwork.md), [management/store](20-management-store.md)
|
||||
- Downstream: [proxy/runtime](33-proxy-runtime.md)
|
||||
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level: [../00-overview.md](../00-overview.md)
|
||||
215
docs/agent-networks/modules/30-proxy-middleware-framework.md
Normal file
215
docs/agent-networks/modules/30-proxy-middleware-framework.md
Normal file
@@ -0,0 +1,215 @@
|
||||
# proxy/middleware-framework — generic plugin system
|
||||
|
||||
> **Risk level:** **High** — every proxied request transits this chain. Budget exhaustion, panic recovery, or chain-close bugs hit the hot path for all targets, not just agent-network ones.
|
||||
> **Backward-compat impact:** Additive at the proxy. The `middleware` and `bodytap` packages are new (`proxy/internal/middleware/middleware.go:1`, `proxy/internal/middleware/bodytap/request.go:13`); existing proxy targets keep working until a chain is bound to them via `Manager.Rebuild`.
|
||||
|
||||
This module is the **framework only** — no LLM/agent-network domain knowledge is required, since every example built into it is generic.
|
||||
|
||||
## Module boundary
|
||||
|
||||
This module is the **framework only**: slots, chains, registry, dispatcher, accumulator, body-tap, output filters. No middleware *implementation* lives here — those land in `proxy/internal/middleware/builtin/*` (covered in module 31). The package contract is:
|
||||
|
||||
1. The proxy hands a `Manager` to its config-apply path. The synth pushes per-path `PathTargetBinding` lists (`proxy/internal/middleware/manager.go:26`) into `Manager.Rebuild`, which resolves each spec via the `Registry`/`Resolver` (`proxy/internal/middleware/registry.go:81-121`) and produces an immutable `Chain` keyed by `serviceID|pathID` (`proxy/internal/middleware/manager.go:410-412`).
|
||||
2. The reverse-proxy handler captures the request body via `bodytap.CaptureRequest`, calls `Chain.RunRequest`, applies returned mutations (already filtered by `chain.applyMutations`), forwards to the upstream behind a `bodytap.CapturingResponseWriter`, then calls `Chain.RunResponse` and `Chain.RunTerminal`.
|
||||
3. Middlewares are inert plugins that receive a deep-cloned `Input` and return an `Output` whose decision/mutations are clamped by the dispatcher's `filterOutput` (`proxy/internal/middleware/dispatcher.go:149-172`).
|
||||
|
||||
Everything that crosses the framework boundary in either direction is value-typed and deep-copied — middlewares cannot mutate the live request directly, and the framework cannot inadvertently leak middleware-owned slices into the request hot path.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `proxy/internal/middleware/middleware.go` | `Middleware` + `Factory` interfaces. |
|
||||
| `proxy/internal/middleware/types.go` | `Slot`, `FailMode`, `Decision`, all limit constants, `Input`/`Output`/`Mutations`/`UpstreamRewrite`/`AuthHeader` value types. |
|
||||
| `proxy/internal/middleware/spec.go` | Apply-time `Spec` (validated wire shape + runtime-injected fields) and `Clone`. |
|
||||
| `proxy/internal/middleware/registry.go` | `Registry` (factory map, RWMutex) and `Resolver` (Spec → bound `Middleware`). |
|
||||
| `proxy/internal/middleware/manager.go` | `Manager`, `chainTable` reverse index, `Rebuild`/`Invalidate*`, async chain close. |
|
||||
| `proxy/internal/middleware/chain.go` | `Chain.RunRequest`/`RunResponse`/`RunTerminal`, mutation gating, `cloneInputFor`. |
|
||||
| `proxy/internal/middleware/chain_test.go` | Metadata threading, LIFO response order, rewrite gating, UserGroups propagation, terminal accumulation. |
|
||||
| `proxy/internal/middleware/dispatcher.go` | Timeout/panic recovery, fail-mode, error classification, `filterOutput`. |
|
||||
| `proxy/internal/middleware/decision.go` | `RenderDenyResponse`, deny-code regex, status clamp. |
|
||||
| `proxy/internal/middleware/headerpolicy.go` | Compile-in header denylist + `FilterHeaderMutations`. |
|
||||
| `proxy/internal/middleware/bodypolicy.go` | `ValidateBodyReplace` / `ApplyBodyReplace` smuggling guards. |
|
||||
| `proxy/internal/middleware/keys.go` | Metadata key namespace constants. |
|
||||
| `proxy/internal/middleware/metadata.go` | `Accumulator` — allowlist, per-mw/per-request byte caps, redaction. |
|
||||
| `proxy/internal/middleware/metrics.go` | OTel instrument bundle (`proxy.middleware.*`). |
|
||||
| `proxy/internal/middleware/redaction.go` | `Scan` — PEM/JWT/AWS/bearer/Luhn-validated CC patterns. |
|
||||
| `proxy/internal/middleware/bodytap/request.go` | Capture + replay reader, `Budget` semaphore, bypass reason codes. |
|
||||
| `proxy/internal/middleware/bodytap/response.go` | `CapturingResponseWriter` (tee with `PassthroughWriter` for Flusher/Hijacker preservation). |
|
||||
|
||||
## Slot model
|
||||
|
||||
Three slots, declared per-middleware exactly once (`proxy/internal/middleware/types.go:27-41`):
|
||||
|
||||
- **`SlotOnRequest`** (`Slot=1`) — runs **before** the upstream call, in registration order. May `DecisionDeny`, may emit `Mutations` (header add/remove, body replace, `UpstreamRewrite`) when both `Spec.CanMutate` and `Middleware.MutationsSupported()` are true. May emit metadata. Each middleware in the slot sees metadata that earlier ones in the same slot just emitted (`proxy/internal/middleware/chain.go:144-178`) — this is how the framework gives middlewares an intra-slot side channel without a global bag.
|
||||
- **`SlotOnResponse`** (`Slot=2`) — runs **after** the upstream returns, in **reverse** registration order. Cannot deny (clamped in `dispatcher.filterOutput`, `proxy/internal/middleware/dispatcher.go:153-157`). May still mutate response headers in principle, but the current chain only forwards `RewriteUpstream` from on_request, so on_response mutations are observe-only in practice. Threads the same per-slot metadata view as on_request.
|
||||
- **`SlotTerminal`** (`Slot=3`) — runs **after** every on_response middleware has emitted, in registration order. Sees the full accumulated bag plus prior terminal emissions (`chain.go:221-245`). Cannot deny, cannot mutate (`dispatcher.go:168-170`). Designed for sinks (access log, metrics push, audit emitter).
|
||||
|
||||
Splitting a feature across slots (e.g. "parse on the way out, ship on terminal") is the explicit architectural choice — `types.go:7-15` and `types.go:22-25` make it clear no middleware participates in more than one slot.
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### Chain dispatch
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
participant H as proxy HTTP handler
|
||||
participant BT as bodytap.CaptureRequest
|
||||
participant CH as Chain
|
||||
participant DI as Dispatcher
|
||||
participant MW as Middleware (per slot)
|
||||
participant US as Upstream
|
||||
participant CW as CapturingResponseWriter
|
||||
|
||||
H->>BT: CaptureRequest(r, cfg, budget)
|
||||
BT-->>H: body[], truncated, release()
|
||||
H->>CH: RunRequest(ctx, r, Input, Accumulator)
|
||||
loop on_request, registration order
|
||||
CH->>CH: cloneInputFor(in, OnRequest)
|
||||
CH->>DI: Invoke(ctx, spec, mw, call)
|
||||
DI->>MW: mw.Invoke(callCtx, in)
|
||||
MW-->>DI: Output{decision, metadata, mutations?}
|
||||
DI->>DI: filterOutput (clamp deny, gate mutations)
|
||||
DI-->>CH: filtered Output
|
||||
CH->>CH: Accumulator.Emit (allowlist + caps + redact)
|
||||
alt DecisionDeny
|
||||
CH-->>H: denied, merged, rewrite
|
||||
else allow
|
||||
CH->>CH: applyMutations(r, m) and capture rewrite
|
||||
end
|
||||
end
|
||||
CH-->>H: nil, merged, rewrite
|
||||
H->>US: ProxyRequest (with rewrite/mutations applied)
|
||||
US-->>CW: bytes (streamed, tee'd into cap-bounded buf)
|
||||
CW-->>H: passthrough complete
|
||||
H->>CH: RunResponse(ctx, Input{RespBody:CW.Body(),...}, acc)
|
||||
loop on_response, REVERSE order (LIFO)
|
||||
CH->>DI: Invoke (same wrappers)
|
||||
end
|
||||
H->>CH: RunTerminal(ctx, Input{Metadata:full bag}, acc)
|
||||
H->>BT: release() + CW.Release()
|
||||
```
|
||||
|
||||
### Body-tap mechanics (request + response)
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph req[Request capture — bodytap.CaptureRequest]
|
||||
R0[r.Body] --> R1{cfg.MaxRequestBytes > 0?\nUpgrade absent?\nContent-Type allowed?\nCL <= cap?}
|
||||
R1 -- no --> R2[bypass = reason\nbody = nil\nr.Body untouched]
|
||||
R1 -- yes --> R3[Budget.Acquire(cap)]
|
||||
R3 -- denied --> R4[bypass=BypassBudget]
|
||||
R3 -- ok --> R5[io.LimitReader(r.Body, cap+1)\nio.ReadAll]
|
||||
R5 --> R6{len > cap?}
|
||||
R6 -- truncated --> R7[viewable = buf[:cap]\nr.Body = replayReadCloser{buf, tail}]
|
||||
R6 -- whole --> R8[r.Body = NopCloser(bytes.Reader(buf))\nclose original]
|
||||
R7 --> R9[(release captured\nbudget on req end)]
|
||||
R8 --> R9
|
||||
end
|
||||
|
||||
subgraph resp[Response capture — CapturingResponseWriter]
|
||||
W0[client] -.-> CW[Write(p)]
|
||||
CW --> P1[PassthroughWriter.Write(p)\n— bytes leave to client first]
|
||||
P1 --> P2{!stopped?}
|
||||
P2 -- yes --> P3{remaining = cap - buf.Len()}
|
||||
P3 --> P4[buf.Write(p[:take])\nset truncated if take<n]
|
||||
P2 -- no --> P5[silent drop into the tee\n(client write already done)]
|
||||
end
|
||||
```
|
||||
|
||||
The body-tap is the highest-leak-risk surface in this module; three details matter:
|
||||
|
||||
1. **Request capture is "read-and-replay", not "read-and-forward".** `CaptureRequest` always swaps `r.Body` for either a `bytes.Reader` (whole body fit) or a `replayReadCloser` that replays the captured prefix then drains the remaining stream from the original body (`bodytap/request.go:178-201`). This means the **upstream still sees the full body even when the tap truncates**. The original `r.Body` is **not** closed in the truncated branch — `replayReadCloser.Close()` only closes the tail (`bodytap/request.go:199-201`), which is the same reader, so close once on request end is correct, but reviewers should confirm the upstream proxy always reads to EOF (otherwise the tail is leaked).
|
||||
2. **Response capture is a write-through tee.** `CapturingResponseWriter.Write` forwards to the underlying writer **first** (`bodytap/response.go:116-117`), then tees into `buf` under its own mutex. Client never blocks on the tee. `Flusher`/`Hijacker` are preserved via the embedded `responsewriter.PassthroughWriter`. SSE/chunked streams flow through untouched; middlewares only see the bounded prefix.
|
||||
3. **Budget is a single shared semaphore.** `Manager` constructs one `bodytap.Budget` at startup (`manager.go:138-144`, default `256 MiB` from `bodytap/request.go:39`). Every capture pre-acquires its full `MaxRequestBytes` / `MaxResponseBytes` from the budget regardless of actual body size; that prevents a flood of small captures from collectively exceeding the cap, but it also means a misconfigured `MaxRequestBytes = 1 MiB` with 256 concurrent requests already exhausts the default budget. Reviewers should sanity-check the operator-facing defaults that ship with synth-service.
|
||||
|
||||
The framework explicitly aborts capture (and increments `proxy.middleware.capture_bypass_total`) before reading the first byte when `Upgrade`/`Connection: upgrade` is set (`bodytap/request.go:120-125`), when the content-type isn't in the allowlist (`bodytap/request.go:126-128`), or when the advertised `Content-Length` already exceeds the cap (`bodytap/request.go:131-133`). This is the right place to make sure WebSocket upgrades and large file uploads never reach the buffer.
|
||||
|
||||
## Public contracts
|
||||
|
||||
- **`Middleware` interface** (`middleware.go:14-36`): `ID()`, `Version()`, `Slot()`, `AcceptedContentTypes()`, `MetadataKeys()`, `MutationsSupported()`, `Invoke(ctx, *Input) (*Output, error)`, `Close()`. `MetadataKeys()` is the **closed set** the middleware is allowed to emit — the accumulator drops anything outside it (`metadata.go:71-75`). `Close` must be idempotent (called even when `Invoke` was never reached).
|
||||
- **`Factory` interface** (`middleware.go:44-47`): `ID()`, `New(rawConfig []byte) (Middleware, error)`. `RawConfig` is opaque JSON bytes on the wire (`spec.go:6-12`); each factory owns its own typed config.
|
||||
- **`Decision` type** (`types.go:59-69`): `Allow=0`, `Deny=1`, `Passthrough=2`. Default-zero is permissive — important because every middleware that omits `Decision` gets `Allow`. Dispatcher clamps `Deny` to `Passthrough` outside `SlotOnRequest` (`dispatcher.go:153-157`).
|
||||
- **`Mutations`** (`types.go:196-201`): `HeadersAdd`/`HeadersRemove` (filtered through `headerpolicy.go`), `BodyReplace` (gated through `bodypolicy.go`), and `RewriteUpstream`. `RewriteUpstream` is **last-write-wins** within the on_request slot (`chain.go:170-172`, locked down by `TestChain_RunRequest_LatestRewriteWins`).
|
||||
- **Metadata propagation keys** (`keys.go`): all keys live in a single file and follow `^[a-z][a-z0-9_-]*(\.[a-z0-9_-]*)+$` (`metadata.go:8`). Framework-injected error tagging uses `mw.<id>.error_kind` (`keys.go:81`) so operators can distinguish framework-emitted entries from middleware-emitted ones.
|
||||
|
||||
## Invariants
|
||||
|
||||
- **Per-request context isolation.** `cloneInputFor` deep-copies every mutable field (`Headers`, `RespHeaders`, `Metadata`, `Body`, `RespBody`, `UserGroups`, `UserGroupNames`) before each invocation (`chain.go:286-308`). A misbehaving middleware that mutates `in.Headers` only corrupts its own copy.
|
||||
- **Body-tap bounded by capture limit.** Request side uses `io.LimitReader(r.Body, limit+1)` (`bodytap/request.go:152`) — the `+1` is how the code detects truncation (`bodytap/request.go:160`); the surfaced buffer is sliced back down to `limit`. Response side stops teeing once `buf.Len() >= cap` (`bodytap/response.go:121-133`). Neither side can grow the buffer past the configured cap.
|
||||
- **Headers/body redaction order.** Accumulator runs `Scan(value)` **before** counting cost (`metadata.go:81-82`), so the byte budgets are computed against post-redaction sizes. `Scan` order is PEM → JWT → AWS key → bearer → Luhn-validated CC (`redaction.go:25-51`) — the comment block in `redaction.go:8-13` is explicit that this is best-effort, not DLP.
|
||||
- **No middleware can starve the chain.** Every invocation runs inside `context.WithTimeout(ctx, clampTimeout(spec.Timeout))` in a separate goroutine (`dispatcher.go:51-94`), with the deadline race-`select`ed against the result channel. A blocked middleware fires the timeout path, gets fail-mode'd, and `IncError(kind=timeout)`. Timeouts are clamped to `[10ms, 5s]` (`types.go:80-86`, `dispatcher.go:174-185`).
|
||||
- **Panic recovery.** `recover()` captures the panic, logs only the type + a 4 KiB stack prefix (no panic value — avoids leaking secrets the middleware was processing), and produces a `panicError` that flows through fail-mode (`dispatcher.go:64-76`).
|
||||
- **Chain immutability + atomic swap.** `chainTable` is cloned on every `Rebuild`/`Invalidate*` and swapped via `atomic.Pointer` (`manager.go:44-69`, `manager.go:221-300`). Readers (`ChainFor`) are lock-free; writers serialise on `writeMu`. The retired chain is `Close`-d in a background goroutine bounded by `chainCloseTimeout = 2 * MaxTimeout` (`manager.go:21-22`, `manager.go:326-346`), so in-flight invocations finish on the old chain after the swap.
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
|
||||
- **Chain ordering deterministic from synth output?** `Manager.buildChain` iterates `b.Specs` in slice order and appends to `bound` (`manager.go:366-391`); `NewChain` then partitions by slot but **preserves slice order within each slot** (`chain.go:50-60`). So order on the wire = order observed at runtime. Synth must therefore emit specs in the intended execution order — there is no per-spec `Priority` field. Worth flagging.
|
||||
- **Decision short-circuit semantics.** `RunRequest` returns immediately on `DecisionDeny` (`chain.go:164-167`) **with the metadata accumulated so far** plus the `denied.Metadata`. Callers that ignore `merged` on deny will lose framework-injected `mw.<id>.error_kind` entries. The proxy runtime is the only caller; confirm it always feeds `merged` into the access log on the deny path as well.
|
||||
- **`UpstreamRewrite` `AuthHeader` bypass** (`types.go:218-235`). The `AuthHeader`/`StripHeaders` fields *intentionally* bypass the header denylist on the basis that the proxy itself rewrites auth. The denylist still blocks middleware-emitted `HeadersAdd: Authorization=...`. This is a delicate carve-out — review the runtime consumer to confirm only the trusted upstream-build path unpacks `AuthHeader`, never the generic `applyMutations` loop.
|
||||
- **`replayReadCloser.Close` only closes the tail** (`bodytap/request.go:199-201`). The replay buffer doesn't own a resource, so this is correct, but it conflates "replay finished" with "underlying body closed". If a caller `Close()`s without reading to EOF, the original body is closed but the captured prefix is lost; harmless for the proxy path (upstream always reads to EOF) but worth a doc-comment.
|
||||
|
||||
### Security
|
||||
|
||||
- **Body-tap memory bounds.** Discussed above — bounded by `MaxBodyCapBytes = 1 MiB` per direction (`types.go:77`) and the shared `Budget` (default 256 MiB). The concerning case is the **deep-copy in `cloneInputFor`** (`chain.go:300-306`): every middleware invocation gets its **own copy** of `Body` and `RespBody`. A chain of N middlewares with a 1 MiB body allocates N MiB of transient bytes per request. With `MaxMiddlewaresPerChain = 16` (`types.go:103`) that's up to 16 MiB extra per in-flight request. Worth pricing into the budget model.
|
||||
- **Header redaction completeness.** `denyHeaders` (`headerpolicy.go:5-17`) covers the auth/forwarding family and framing (`Content-Length`, `Transfer-Encoding`, `Trailer`). `denyHeaderPrefixes` covers `X-Authenticated-*`, `X-Forwarded-*`, `X-Remote-*`, `X-NetBird-*`. Notably absent: `Range`, `If-Match`/`If-None-Match` (mutation could cause cache poisoning), `Origin`/`Referer`. Not necessarily wrong, but worth a deliberate decision.
|
||||
- **Metadata key collisions across middlewares.** The accumulator has no cross-middleware uniqueness check; two middlewares with the same key in their allowlist can both emit it, and both copies land in `merged` (`metadata.go:51-99`). Downstream consumers must tolerate duplicates. Worth documenting.
|
||||
- **Deny rendering.** `RenderDenyResponse` only allows codes matching `^[a-z][a-z0-9._-]{0,63}$` (`decision.go:9`), redacts/truncates message + detail values, caps `Details` at 8 entries (`decision.go:42-50`), clamps status to `[400,499]\{401}` (`decision.go:65-73`). The deny body type is fixed; middlewares cannot inject arbitrary JSON.
|
||||
|
||||
### Concurrency
|
||||
|
||||
- **Per-request state vs shared state in factories.** Each `Factory.New` is called once per chain build; the returned `Middleware` instance is **shared across all requests** for that chain. `Invoke` must be reentrant. The framework does not enforce this — a buggy middleware that holds per-call state on the struct will silently race. Suggest a `// Invoke must be safe for concurrent use` doc on the interface.
|
||||
- **`chainTable` clone-on-write** is correct, but `addChain`/`removeChain` mutate the *cloned* table before the swap (`manager.go:71-108`), and they're called under `writeMu`. Readers only ever see the post-swap pointer. Good.
|
||||
- **`Chain.inflight` WaitGroup**. `Run*` does `Add(1)`/`Done()` (`chain.go:142-143`, `chain.go:194-195`, `chain.go:225-226`); `Close` waits on it bounded by ctx (`chain.go:75-85`). One concern: a *new* `RunRequest` can `Add(1)` *after* `Close` started waiting if the caller still holds a stale chain pointer. `WaitGroup` does not panic on this if the count was already > 0 at `Wait` time, but it does panic if `Add` happens after `Wait` returns and another `Wait` runs. `Close` is documented one-shot, so single-`Wait` is fine, but callers must drop the chain reference before calling `Close`. Worth a code comment near `Close`.
|
||||
- **Goroutine leaks.** `Dispatcher.Invoke` spawns one goroutine per call and *always* writes to a buffered (cap=1) channel (`dispatcher.go:62-76`), so even if the timeout fires the goroutine completes its send and exits. No leak.
|
||||
- **`closeChainsAsync`** detaches retired chains into a goroutine (`manager.go:326-346`). If `Manager` is never GC'd this is fine, but there's no shutdown hook to wait on outstanding closes. Reviewers should confirm the proxy shutdown path explicitly drains in-flight requests before tearing down `Manager`, or accept that the last chain-close round may be cut short on exit.
|
||||
|
||||
### Performance
|
||||
|
||||
- **Allocations per request.** `cloneInputFor` allocates new slices for `Headers`, `RespHeaders`, `Metadata`, `Body`, `RespBody`, `UserGroups`, `UserGroupNames` — once per middleware per request. For a typical 5-middleware chain on a 1 KiB body that's ~10 small slice allocs plus one `Body` copy each. Not a hot-path crisis, but `sync.Pool` for the per-call `Input` would be a natural follow-up.
|
||||
- **Accumulator allocates a fresh `allowSet` per `Emit` call** (`metadata.go:55-58`). One per middleware per slot pass = up to 48 per request. Cheap, but worth noting.
|
||||
- **Regex cost.** `Scan` runs five regex passes on every accepted metadata value (`redaction.go:25-51`). Bounded by `MaxMetadataValueBytes = 4 KiB` so worst case is small.
|
||||
|
||||
### Observability
|
||||
|
||||
- **Per-middleware metrics.** `proxy.middleware.requests_total{middleware,target_id,outcome}` (`metrics.go:34-41`), `duration_ms`, `invocations_total`, `errors_total{kind}`, `metadata_rejected_total{reason}`, `header_mutation_blocked_total{header}`, `capture_bypass_total{reason}`. Comprehensive surface; operators can alert on `errors_total{kind=panic}` and `errors_total{kind=timeout}` separately. **Latency histogram is in milliseconds with default OTel buckets** — for a 10ms–5s timeout range default buckets cover OK, but a custom bucket set centred on 1–500ms would resolve the agent-network response-parser tail better.
|
||||
- **Decision logs.** Panic logs (`dispatcher.go:69`) include `request_id`, type, and stack but not the panic value (safe). `Chain.Close` logs middleware-close errors at debug (`chain.go:91`). `applyMutations` logs body-replace rejections at warn (`chain.go:278`). No log on the deny path itself — by design, since the access-log terminal middleware is expected to record outcomes.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| `proxy/internal/middleware/chain_test.go:77` | `RunRequest` threads metadata across on_request middlewares (regression for the "later mw can't see earlier mw's emissions" bug). |
|
||||
| `chain_test.go:110` | `RunResponse` reverse-order threading. |
|
||||
| `chain_test.go:142` | `cost_meter`-shaped scenario: response_parser registered after cost_meter still emits *before* cost_meter sees the bag (guards the `cost.skipped=missing_tokens` regression). |
|
||||
| `chain_test.go:178` | `UpstreamRewrite` last-write-wins. |
|
||||
| `chain_test.go:206` | No middleware emits → nil rewrite. |
|
||||
| `chain_test.go:224` | Rewrite filtered when `CanMutate=false`. |
|
||||
| `chain_test.go:245` | `Input.UserGroups` propagates verbatim through `cloneInputFor`. |
|
||||
| `chain_test.go:304` | Terminal middlewares see the full accumulated bag + prior terminal emissions. |
|
||||
|
||||
**Gaps** worth raising with the author:
|
||||
- No direct test for `Dispatcher.Invoke` timeout / panic / fail-mode behaviour at the framework level (covered indirectly by built-in tests, but a unit test pinning `errors_total{kind=...}` labels would be cheap insurance).
|
||||
- No test for `bodytap.CaptureRequest` truncated replay (the upstream-sees-full-body invariant is exactly the kind of thing a regression would silently break).
|
||||
- No test for `Budget` exhaustion behaviour under concurrency.
|
||||
- No test for `Manager.InvalidateMiddleware` + `LiveServiceCheck` race (the auth-revocation race the comment at `manager.go:33-38` calls out is the load-bearing reason for `LiveServiceCheck`).
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- **No middleware-to-middleware RPC.** Side-channel is metadata only.
|
||||
- **No streaming body inspection.** Middlewares see a bounded prefix; SSE / chunked parsing happens against that prefix in the response middleware.
|
||||
- **No per-spec priority.** Order is registration order in the spec slice.
|
||||
- **No retry / circuit-breaker** on middleware errors. Fail-mode is binary (open/closed) and per-spec.
|
||||
- **Mutations cannot rewrite the request URL path or query** — only `RewriteUpstream` can change scheme/host (+ optional path replacement, see `types.go:218-235`).
|
||||
- **Redaction is best-effort.** Explicitly documented in `redaction.go:8-13`. Not a DLP solution.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Upstream wire shape: [../modules/10-shared-api.md](10-shared-api.md) (Spec/RawConfig encoding from management).
|
||||
- Built-in middlewares using this framework: [../modules/31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md).
|
||||
- Runtime wiring (where `Manager`, `Chain`, and `bodytap` are consumed by the HTTP handler): [../modules/33-proxy-runtime.md](33-proxy-runtime.md).
|
||||
- End-to-end request flow including capture + chain dispatch: [../01-end-to-end-flows.md](../01-end-to-end-flows.md).
|
||||
- Top-level architecture: [../00-overview.md](../00-overview.md).
|
||||
365
docs/agent-networks/modules/31-proxy-middleware-builtin.md
Normal file
365
docs/agent-networks/modules/31-proxy-middleware-builtin.md
Normal file
@@ -0,0 +1,365 @@
|
||||
# proxy/middleware-builtin — the LLM chain
|
||||
|
||||
The registry-mounted middleware set the proxy executes on every agent-network
|
||||
LLM request. The two highest-blast-radius areas are the **capture-pointer
|
||||
semantics** and the **limit_check ⇒ limit_record** record-once invariant.
|
||||
|
||||
Sibling module: [32-proxy-llm-parsers.md](./32-proxy-llm-parsers.md) — the SDK
|
||||
adapters + pricing catalog this chain delegates to.
|
||||
|
||||
---
|
||||
|
||||
## Module boundary
|
||||
|
||||
This module is the registry-mounted middleware set the proxy executes on
|
||||
every agent-network LLM request. Each sub-package registers itself via
|
||||
`init()`
|
||||
([builtin.go:32–34](../../../proxy/internal/middleware/builtin/builtin.go));
|
||||
the proxy server anonymous-imports the set
|
||||
([all_test.go:11–19](../../../proxy/internal/middleware/builtin/all_test.go))
|
||||
so the registry is populated at boot. The chain is wired by the management
|
||||
synthesiser and executed by the framework
|
||||
(`proxy/internal/middleware/{chain,dispatcher,accumulator}.go` — both out
|
||||
of scope). Everything here reads from / writes to one envelope: the
|
||||
`middleware.KV` metadata bag plus `middleware.Mutations` for header/body
|
||||
rewrites.
|
||||
|
||||
## The 8 middlewares
|
||||
|
||||
| Name | Slot | Inputs (metadata read) | Outputs (metadata written) | Side effects |
|
||||
|---|---|---|---|---|
|
||||
| `llm_request_parser` | OnRequest | `Input.{URL,Body,BodyTruncated}` | `llm.{provider,model,stream,request_prompt_raw,capture_truncated}` | none |
|
||||
| `llm_router` | OnRequest | `llm.model`, `Input.{URL,UserGroups}` | `llm.{resolved_provider_id,authorising_groups}`, `llm_policy.{decision,reason}` | upstream rewrite + auth strip/inject |
|
||||
| `llm_limit_check` | OnRequest | `llm.{resolved_provider_id,model}`, `Input.{AccountID,UserID,UserGroups}` | `llm.{selected_policy_id,attribution_group_id,attribution_window_seconds}`, `llm_policy.{decision,reason}` | gRPC `CheckLLMPolicyLimits` |
|
||||
| `llm_identity_inject` | OnRequest | `llm.{resolved_provider_id,authorising_groups}`, `Input.{UserEmail,UserID,UserGroups,UserGroupNames}` | none | header strip/inject + optional body rewrite |
|
||||
| `llm_guardrail` | OnRequest | `llm.{model,request_prompt_raw}` | `llm_policy.{decision,reason}`, `llm.request_prompt` | none (model allowlist deny) |
|
||||
| `llm_response_parser` | OnResponse | `llm.provider`, `Input.{RespHeaders,RespBody,Status}` | `llm.{input,output,total,cached_input,cache_creation}_tokens`, `llm.response_completion` | none |
|
||||
| `cost_meter` | OnResponse | `llm.{provider,model}`, token buckets | `cost.usd_total` or `cost.skipped` | pricing lookup |
|
||||
| `llm_limit_record` | OnResponse | `llm.{attribution_group_id,attribution_window_seconds,input_tokens,output_tokens}`, `cost.usd_total` | none | gRPC `RecordLLMUsage` |
|
||||
|
||||
[all_test.go:26–40](../../../proxy/internal/middleware/builtin/all_test.go)
|
||||
locks the ID set; adding or removing one is a conscious extension.
|
||||
|
||||
## Files
|
||||
|
||||
| File | LOC | Notes |
|
||||
|---|---:|---|
|
||||
| `builtin.go` | 86 | Registry + `FactoryContext` (ctx, data dir, meter, logger, mgmt client) |
|
||||
| `all_test.go` | 41 | Locks the 8-ID registry surface |
|
||||
| `agentnetwork_chain_integration_test.go` | 319 | Live sqlite + real gRPC bufconn; gate→recorder wire path |
|
||||
| `llm_request_parser/*` | 162 / 66 / 356 | Provider detection, body parse, prompt extraction with capture-pointer gating |
|
||||
| `llm_router/*` | 385 / 84 / 586 | Three-pass route selection (model → groups → path-prefix) |
|
||||
| `llm_limit_check/*` | 196 / 38 / 182 | Pre-flight `CheckLLMPolicyLimits` (2s, fail-open) |
|
||||
| `llm_identity_inject/*` | 440 / 108 / 666 | HeaderPair (LiteLLM) + JSONMetadata (Portkey) + ExtraHeaders |
|
||||
| `llm_guardrail/*` | 176 / 82 / 75 / 219 / 217 | Model allowlist + optional prompt capture with PII redaction |
|
||||
| `llm_response_parser/*` | 258 / 222 / 43 / 433 / 169 / 111 | Buffered + SSE accumulation; AWS event-stream accumulator (`streaming_bedrock.go`) for Bedrock; capture-pointer gates completion emit |
|
||||
| `cost_meter/*` | 181 / 84 / 439 | Token → USD via `proxy/internal/llm/pricing` |
|
||||
| `llm_limit_record/*` | 144 / 35 / 191 | Post-flight `RecordLLMUsage` (5s, debug-on-error) |
|
||||
|
||||
## Per-middleware
|
||||
|
||||
### llm_request_parser
|
||||
|
||||
Detects the LLM provider via `llm.DetectParser` (URL sniff) or by name via
|
||||
`llm.ParserByName` when synthesiser stamps `provider_id`
|
||||
([middleware.go:96–99](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)).
|
||||
**Path-routed providers short-circuit first:** `parseVertexPath` and
|
||||
`parseBedrockPath` ([middleware.go:85–94](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go))
|
||||
pull the model + vendor out of the URL before parser selection runs — Vertex
|
||||
from `/v1/projects/.../publishers/{pub}/models/{model}:{action}` (publisher →
|
||||
vendor via `vertexPublisherVendor`), Bedrock from `/model/{id}/{action}` with
|
||||
`normalizeBedrockModel` stripping the region prefix + version suffix. See
|
||||
[50-path-routed-providers.md](./50-path-routed-providers.md) for the full path
|
||||
grammar. For body-routed providers it decodes the body into `RequestFacts`
|
||||
(model + stream) and extracts the prompt. On
|
||||
`capture_prompt=true` (or absent — see capture-pointer semantics below) the
|
||||
prompt is run through `llm_guardrail.RedactPII` when `redact_pii=true` and
|
||||
truncated rune-safely to 3500 bytes
|
||||
([middleware.go:109–122](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)).
|
||||
**Key invariant:** redaction is parser-side, not guardrail-side — access-log
|
||||
reads `llm.request_prompt_raw` directly.
|
||||
|
||||
### llm_router
|
||||
|
||||
Three-pass route selection in `matchRoute`
|
||||
([middleware.go:241–300](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)):
|
||||
filter by `Models` claim → vendor-pin (a vendor-tagged request never crosses to
|
||||
another vendor's route) → filter by `AllowedGroupIDs` intersection → model
|
||||
precedence over path → tie-break by longest `UpstreamPath` prefix match.
|
||||
Model-miss returns `llm_policy.model_not_routable`; known-but-unauthorised
|
||||
returns `llm_policy.no_authorised_provider`. **Key invariant:** auth-header
|
||||
strip+inject rides on `UpstreamRewrite.{StripHeaders,AuthHeader}`
|
||||
([middleware.go:606–646](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
— NOT `HeadersAdd/HeadersRemove` — because the framework's mutation gate
|
||||
blocks `Authorization` on the generic header path.
|
||||
|
||||
**Path-routed providers route before the model table.** `Invoke` checks
|
||||
`isVertexPath` / `isBedrockPath`
|
||||
([middleware.go:138–216](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
ahead of the model lookup, so a path-carried model can't be claimed by a
|
||||
same-vendor body-routed provider. `matchPathRoute` enforces the route's `Models`
|
||||
allowlist (empty = catch-all) even though the model came from the URL.
|
||||
Two path-only behaviours:
|
||||
- **Vertex unmeterable publisher** — when `llm_request_parser` emits no
|
||||
`llm.provider` (e.g. Gemini/`google`), the router denies with
|
||||
`llm_policy.unmeterable_publisher` (403) rather than forward it uncounted.
|
||||
- **GCP token minting** — when the route carries `GCPServiceAccountKeyB64`
|
||||
(set from a `keyfile::` api_key), `gcpBearer` mints + caches a short-lived
|
||||
OAuth2 token per request instead of injecting a static value; a bad key or
|
||||
unreachable token endpoint denies with `llm_policy.upstream_auth_failed`
|
||||
(502). Bedrock uses its static bearer token directly (no minting).
|
||||
- **`/bedrock` prefix** — an optional `/bedrock` gateway-namespace prefix is
|
||||
accepted and stripped via `RewriteUpstream.StripPathPrefix` so the native
|
||||
`/model/...` path reaches the upstream.
|
||||
|
||||
Full treatment in [50-path-routed-providers.md](./50-path-routed-providers.md).
|
||||
|
||||
### llm_limit_check
|
||||
|
||||
Pre-flight gate. Reads `llm.resolved_provider_id`, calls
|
||||
`CheckLLMPolicyLimits` with a 2s context timeout
|
||||
([middleware.go:24, 97–106](../../../proxy/internal/middleware/builtin/llm_limit_check/middleware.go)),
|
||||
on allow stamps `llm.selected_policy_id`, `llm.attribution_group_id`,
|
||||
`llm.attribution_window_seconds`. **Key invariant:** fail-open. Nil
|
||||
`MgmtClient`, empty provider id, or RPC error returns `allowNoAttribution()`
|
||||
— management outage doesn't take down every LLM request. Operators audit via
|
||||
the access-log; a future flag may switch this to fail-closed.
|
||||
|
||||
### llm_identity_inject
|
||||
|
||||
Dispatches per-rule between LiteLLM-shaped `HeaderPair`
|
||||
([middleware.go:169](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go))
|
||||
and Portkey-shaped `JSONMetadata`
|
||||
([middleware.go:292](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go)).
|
||||
Identity is the peer's email (or `UserID` fallback); tags are the
|
||||
**authorising-groups intersection** emitted by `llm_router`, not the full
|
||||
`UserGroups` — a peer in 5 groups authorised under 1 only tags as that 1.
|
||||
**Anti-spoof:** every `HeadersAdd` is preceded by a `HeadersRemove` of the
|
||||
same name; the framework runs `Remove` before `Add` so client-supplied
|
||||
identity never reaches the upstream. Body-level inject (`tags_in_body`,
|
||||
`end_user_id_in_body`) is skipped on empty / truncated / non-JSON bodies so
|
||||
header attribution stays intact.
|
||||
|
||||
### llm_guardrail
|
||||
|
||||
Model allowlist deny + optional prompt-capture-with-redaction. Allowlist
|
||||
match is case-insensitive via `normaliseModel`; empty allowlist disables the
|
||||
check. Prompt capture reads `llm.request_prompt_raw` and emits
|
||||
`llm.request_prompt` only when `prompt_capture.enabled`
|
||||
([middleware.go:149–165](../../../proxy/internal/middleware/builtin/llm_guardrail/middleware.go)).
|
||||
**Key invariant:** `RedactPII` is the exported function the parsers call —
|
||||
single PII contract across all three keys.
|
||||
|
||||
### llm_response_parser
|
||||
|
||||
Buffered and SSE paths share one `Invoke`
|
||||
([middleware.go:102–127](../../../proxy/internal/middleware/builtin/llm_response_parser/middleware.go)):
|
||||
content-type sniffing dispatches to `invokeBuffered` (JSON, status<400) or
|
||||
`invokeStreaming` (text/event-stream, partial bodies tolerated). Streaming
|
||||
delegates to `accumulateStream`
|
||||
([streaming.go:21–30](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go))
|
||||
using `llm.NewScanner`. A third path, `accumulateBedrockStream`
|
||||
([streaming_bedrock.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go)),
|
||||
decodes the AWS binary event-stream (`application/vnd.amazon.eventstream`)
|
||||
returned by Bedrock's `-stream` actions — InvokeModel `chunk` frames wrap a
|
||||
base64 Anthropic event, Converse frames carry text + a trailing usage block.
|
||||
Cached / cache-creation buckets emit only when non-zero, preserving the existing
|
||||
token schema.
|
||||
|
||||
### cost_meter
|
||||
|
||||
Reads `llm.provider` + `llm.model` + token buckets, looks up per-1k rate via
|
||||
`pricing.Loader`, emits `cost.usd_total` or a closed-set `cost.skipped`
|
||||
reason (`missing_provider/model/tokens`, `unparseable_tokens`, `zero_tokens`,
|
||||
`unknown_model`). Loader's hot-reload goroutine is bound to proxy-lifetime
|
||||
context via `startReloader`. **Key invariant:** provider-shape switch lives
|
||||
in `pricing.Table.Cost` (sibling doc) — `cost_meter` stays provider-agnostic.
|
||||
|
||||
### llm_limit_record
|
||||
|
||||
Post-flight write. Always returns `DecisionAllow`; response has already been
|
||||
served so RPC errors mustn't surface (logged at `Debugf`). Skip-on-no-signal
|
||||
at line 81 (zero tokens + zero cost). **Key invariant:** the
|
||||
skip-on-missing-attribution guard at line 98 is a safety net independent of
|
||||
the framework's deny short-circuit — if the gate denied and the framework
|
||||
still runs the recorder, the recorder skips on absent
|
||||
`UserID`+`groupID`+`UserGroups` and no phantom counter materialises.
|
||||
|
||||
## Full-chain diagram (canonical order)
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[HTTP request] --> B[llm_request_parser<br/>OnRequest]
|
||||
B -->|llm.provider, llm.model,<br/>llm.stream, llm.request_prompt_raw| C[llm_router<br/>OnRequest]
|
||||
C -->|llm.resolved_provider_id,<br/>llm.authorising_groups,<br/>upstream rewrite + auth| D[llm_limit_check<br/>OnRequest]
|
||||
D -->|deny path| Z1[403 llm_policy.*]
|
||||
D -->|allow + llm.selected_policy_id,<br/>llm.attribution_group_id,<br/>llm.attribution_window_seconds| E[llm_identity_inject<br/>OnRequest]
|
||||
E -->|header strip+inject<br/>+ optional body rewrite| F[llm_guardrail<br/>OnRequest]
|
||||
F -->|deny: model_blocked| Z2[403 llm_policy.model_blocked]
|
||||
F -->|allow + llm.request_prompt| G[upstream LLM call]
|
||||
G --> H[llm_response_parser<br/>OnResponse]
|
||||
H -->|llm.{input,output,total,cached_input,cache_creation}_tokens,<br/>llm.response_completion| I[cost_meter<br/>OnResponse]
|
||||
I -->|cost.usd_total or cost.skipped| J[llm_limit_record<br/>OnResponse]
|
||||
J --> K[response to client]
|
||||
```
|
||||
|
||||
## limit_check ⇒ limit_record record-once invariant
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant LC as llm_limit_check
|
||||
participant M as management gRPC
|
||||
participant U as upstream LLM
|
||||
participant LR as llm_limit_record
|
||||
participant DB as sqlite consumption table
|
||||
|
||||
LC->>M: CheckLLMPolicyLimits (2s)
|
||||
alt allow
|
||||
M-->>LC: selected_policy_id, attribution_group_id, window_s
|
||||
LC->>U: stamps attribution metadata
|
||||
U-->>LR: response + tokens (via llm_response_parser + cost_meter)
|
||||
LR->>M: RecordLLMUsage (5s, debug-on-error)
|
||||
M->>DB: increment (user, group, window) row
|
||||
else deny
|
||||
M-->>LC: llm_policy.token_cap_exceeded
|
||||
Note over LR: framework short-circuits; even if invoked,<br/>recorder skips on absent UserID+groupID+UserGroups
|
||||
else mgmt nil / rpc error
|
||||
LC-->>LC: allowNoAttribution() — fail open
|
||||
Note over LR: no window_s ⇒ recorder books only account-level<br/>budget rules (which run independently)
|
||||
end
|
||||
```
|
||||
|
||||
The integration test
|
||||
[agentnetwork_chain_integration_test.go](../../../proxy/internal/middleware/builtin/agentnetwork_chain_integration_test.go)
|
||||
exercises all three branches against a real sqlite store + bufconn gRPC —
|
||||
no mocks. Tests: `TestChain_AllowPath_StampsAttributionAndRecordsCounter`
|
||||
(line 130), `TestChain_DenyPath_GateRejectsAndNoConsumptionWritten` (line
|
||||
207), `TestChain_CapExhaustTransition` (line 265).
|
||||
|
||||
## Public contracts (per-middleware JSON config)
|
||||
|
||||
| Middleware | Config shape |
|
||||
|---|---|
|
||||
| `llm_request_parser` | `{provider_id?, redact_pii?, capture_prompt?: *bool}` ([factory.go:19–37](../../../proxy/internal/middleware/builtin/llm_request_parser/factory.go)) |
|
||||
| `llm_router` | `{providers: [{id, models, upstream_scheme, upstream_host, upstream_path?, auth_header_name, auth_header_value, allowed_group_ids}]}` |
|
||||
| `llm_limit_check` | `{}` — pulls `MgmtClient` from `FactoryContext` |
|
||||
| `llm_identity_inject` | `{providers: [{provider_id, header_pair?|json_metadata?, extra_headers?}]}` |
|
||||
| `llm_guardrail` | `{model_allowlist: []string, prompt_capture: {enabled, redact_pii}}` |
|
||||
| `llm_response_parser` | `{redact_pii?, capture_completion?: *bool}` |
|
||||
| `cost_meter` | `{pricing_path?}` (basename inside data-dir; defaults `pricing.yaml`) |
|
||||
| `llm_limit_record` | `{}` — same pattern as `llm_limit_check` |
|
||||
|
||||
All factories accept empty / null / `{}` / whitespace as zero-value config;
|
||||
only structurally invalid JSON is rejected so misconfig surfaces at chain
|
||||
build time.
|
||||
|
||||
## Invariants
|
||||
|
||||
1. **limit_check ↔ limit_record paired.** They MUST appear together. Gate
|
||||
stamps attribution metadata on the request leg; recorder reads it on the
|
||||
response leg. If a chain contains only the recorder, the
|
||||
skip-on-missing-attribution guard at
|
||||
[llm_limit_record/middleware.go:81–87, 98–103](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go)
|
||||
keeps counters consistent but no enforcement runs. Only-gate means
|
||||
counters never tick and headroom appears infinite.
|
||||
|
||||
2. **`capture_prompt` / `capture_completion` pointer semantics.** Both are
|
||||
`*bool`. `nil` = "preserve legacy emit" (back-compat default for
|
||||
non-agent-network callers and pre-toggle tests). `false` = suppress the
|
||||
key entirely (access-log row carries zero prompt / completion content).
|
||||
`true` = emit. The synthesiser sets the pointer explicitly to the
|
||||
account's `EnablePromptCollection` toggle. The handling lives
|
||||
in [llm_request_parser/factory.go:55–61](../../../proxy/internal/middleware/builtin/llm_request_parser/factory.go)
|
||||
and the symmetric [llm_response_parser/middleware.go:62–68](../../../proxy/internal/middleware/builtin/llm_response_parser/middleware.go);
|
||||
a missing pointer must not be treated as `false` (that would suppress
|
||||
capture for legacy non-agent-network callers).
|
||||
`redact_pii` is an orthogonal `bool` controlling **form** of emitted
|
||||
content, not whether it's emitted.
|
||||
|
||||
3. **`redact_pii` is parser-side.** Both parsers import
|
||||
`llm_guardrail.RedactPII` and run it BEFORE stamping the metadata bag.
|
||||
Load-bearing because the access-log sink reads `llm.request_prompt_raw`
|
||||
and `llm.response_completion` directly — by the time `llm_guardrail`
|
||||
runs its own pass on `llm.request_prompt`, the raw key has already been
|
||||
stamped. Tests: `TestInvoke_RedactPii_RedactsBeforeEmittingRawPrompt`,
|
||||
`TestInvoke_RedactPii_RedactsCompletionBeforeEmit`.
|
||||
|
||||
4. **Metadata allowlist enforcement.** Every middleware declares
|
||||
`MetadataKeys()`. The framework accumulator drops any KV outside that
|
||||
allowlist. When adding a new key, also extend the docstring in
|
||||
`middleware/keys.go`.
|
||||
|
||||
5. **Closed deny-code set.** All deny paths emit one of:
|
||||
`llm_policy.model_not_routable`, `llm_policy.no_authorised_provider`,
|
||||
`llm_policy.model_blocked`, `llm_policy.token_cap_exceeded`,
|
||||
`llm_policy.unmeterable_publisher` (path-routed Vertex publisher with no
|
||||
parser → 403), `llm_policy.upstream_auth_failed` (GCP token mint failure →
|
||||
502), or the management-supplied code on `llm_limit_check`. These surface
|
||||
verbatim; arbitrary middleware text never reaches the wire.
|
||||
|
||||
## Things to scrutinise
|
||||
|
||||
**Correctness.** `llm_router` model match treats an empty `Models` slice as
|
||||
"claim every model"
|
||||
([middleware.go:238–248](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
for gateway-style providers — confirm no real provider record ships with an
|
||||
empty `Models` by accident. Path-prefix tie-break falls back to declaration
|
||||
order when no candidate prefix-matches, so the synthesiser must emit a
|
||||
deterministic order. `llm_limit_record` discards `strconv.ParseInt` errors
|
||||
([middleware.go:78–80](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go))
|
||||
— relies on `llm_response_parser` always emitting parseable values; spot-check
|
||||
the streaming partial path on truncated bodies.
|
||||
|
||||
**Security.** Auth headers must NEVER appear on `Mutations.HeadersAdd/Remove`
|
||||
for the router — a direct headers path would bypass the framework gate. The
|
||||
capture-pointer handling is the kind of place a bug ships PII to logs
|
||||
silently; every synthesiser config path must set the pointer explicitly.
|
||||
`llm_identity_inject` body inject silently skips on a
|
||||
non-object `metadata` field
|
||||
([middleware.go:262–270](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go))
|
||||
— header path still attributes, but body-level tag-budget enforcement
|
||||
doesn't run for that request.
|
||||
|
||||
**Concurrency.** `cost_meter` shares a `pricing.Loader` via
|
||||
`atomic.Pointer[Table]`; readers always see a consistent table. Every
|
||||
middleware is a stateless value receiver. Integration test uses real bufconn
|
||||
gRPC — race detector is the meaningful bar.
|
||||
|
||||
**Perf.** Hot path is `lookupKV` linear scan over <10 KVs; `cost_meter.Cost`
|
||||
is O(1); SSE accumulation is single-pass. No map allocation per call.
|
||||
|
||||
**Observability.** Every deny stamps `llm_policy.decision=deny` and a
|
||||
matching `llm_policy.reason` — access-log can pivot on either.
|
||||
`llm_limit_record` only logs at `Debugf` on RPC failure
|
||||
([middleware.go:125–130](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go));
|
||||
operators need an alternate signal (metric on `RecordLLMUsage` failures) for
|
||||
counter accuracy.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| File | Tests | Notes |
|
||||
|---|---:|---|
|
||||
| `all_test.go` | 1 | Registry surface lock |
|
||||
| `agentnetwork_chain_integration_test.go` | 3 | Allow/deny/cap-exhaust vs live sqlite + bufconn gRPC |
|
||||
| `llm_request_parser/middleware_test.go` | 18 | `provider_id` bypass, redaction, capture-pointer, rune-safe truncation |
|
||||
| `llm_router/middleware_test.go` | 19 | Three-pass match, deny codes, path-prefix tie-break, header strip+inject |
|
||||
| `llm_limit_check/middleware_test.go` | 6 | Allow/deny, fail-open on nil mgmt / RPC error, attribution stamping |
|
||||
| `llm_identity_inject/middleware_test.go` | 28 | HeaderPair, JSONMetadata, ExtraHeaders, body inject, anti-spoof |
|
||||
| `llm_guardrail/middleware_test.go` | 15 | Allowlist case-insensitivity, prompt capture toggle, deny shape |
|
||||
| `llm_guardrail/redact_test.go` | 15 | Email, SSN, phone (E.164 + NA), bearer, IPv4; fixture-driven |
|
||||
| `llm_response_parser/middleware_test.go` | 18 | Buffered OAI+Anthro, capture-pointer, redact, truncation |
|
||||
| `llm_response_parser/streaming_test.go` | 7 | OAI usage frame, Anthro message_delta, truncated body best-effort |
|
||||
| `cost_meter/middleware_test.go` | 17 | Each skip reason, provider-shape, pricing loader integration |
|
||||
| `llm_limit_record/middleware_test.go` | 7 | Skip-on-no-signal, skip-on-missing-attribution, RPC failure swallowed |
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Sibling: [32-proxy-llm-parsers.md](./32-proxy-llm-parsers.md) — SDK adapters
|
||||
+ SSE framer + pricing loader.
|
||||
- Path-routed providers (Vertex AI + Bedrock), `keyfile::` credential, GCP
|
||||
token minting, `/bedrock` prefix:
|
||||
[50-path-routed-providers.md](./50-path-routed-providers.md).
|
||||
- Upstream config: `management/server/agentnetwork/synthesizer` (out of scope).
|
||||
- Framework: `proxy/internal/middleware/{chain,dispatcher,accumulator,registry}.go`.
|
||||
- Metadata key registry: `proxy/internal/middleware/keys.go`.
|
||||
- gRPC surface: `proto.ProxyServiceClient.{CheckLLMPolicyLimits,RecordLLMUsage}`.
|
||||
392
docs/agent-networks/modules/32-proxy-llm-parsers.md
Normal file
392
docs/agent-networks/modules/32-proxy-llm-parsers.md
Normal file
@@ -0,0 +1,392 @@
|
||||
# proxy/llm-parsers — SDK adapters + pricing + SSE
|
||||
|
||||
The runtime-agnostic LLM library: the OpenAI Responses API (`/v1/responses`)
|
||||
and the older Chat Completions API (`/v1/chat/completions`), the Anthropic
|
||||
Messages API (`/v1/messages`), the SSE wire format (`event:` / `data:` lines,
|
||||
`\n\n` framing, CRLF tolerance), and per-provider token accounting (OpenAI's
|
||||
cached-prompt **subset** vs Anthropic's cache_read **additive** model). The
|
||||
pricing table's per-provider cost formula is the highest-leverage place a
|
||||
small bug would silently mis-bill operators.
|
||||
|
||||
Sibling module: [31-proxy-middleware-builtin.md](./31-proxy-middleware-builtin.md)
|
||||
— the 8 middlewares that consume this package's parsers + pricing loader.
|
||||
|
||||
---
|
||||
|
||||
## Module boundary
|
||||
|
||||
`proxy/internal/llm` is the runtime-agnostic LLM library shared by every
|
||||
middleware that needs to understand provider-specific shapes. Zero
|
||||
proxy-framework dependencies:
|
||||
|
||||
- `parser.go` — `Parser` interface, `Provider` enum, public factories
|
||||
(`Parsers`, `DetectParser`, `ParserByName`).
|
||||
- `openai.go` / `anthropic.go` / `bedrock.go` — per-provider `Parser` impls.
|
||||
- `sse.go` — SSE scanner (`Scanner`, `Event`, `NewScanner`).
|
||||
- `errors.go` — sentinels callers branch on with `errors.Is`.
|
||||
- `pricing/` — embedded-default + hot-reload override table with
|
||||
symlink-safe Unix loader (build-tagged stub elsewhere).
|
||||
- `fixtures/` — captured request/response/stream bodies the tests replay.
|
||||
|
||||
The package carries zero proxy-framework dependencies so the same parsers can
|
||||
be reused later by a WASM adapter
|
||||
([parser.go:1–6](../../../proxy/internal/llm/parser.go)).
|
||||
|
||||
## Files
|
||||
|
||||
| File | LOC | Notes |
|
||||
|---|---:|---|
|
||||
| `parser.go` | 104 | Interface + factories + `Provider{Unknown,OpenAI,Anthropic}` enum |
|
||||
| `openai.go` | 347 | Chat Completions + Completions + Responses API; cached_tokens subset |
|
||||
| `openai_test.go` | 222 | 11 tests; fixture replay + cached/Responses-API matrix |
|
||||
| `anthropic.go` | 172 | Messages + legacy `/v1/complete`; cache_read + cache_creation additive |
|
||||
| `anthropic_test.go` | 154 | 7 tests including streaming-extraction-skipped contract |
|
||||
| `bedrock.go` | 190 | AWS Bedrock InvokeModel (snake_case) + Converse (camelCase) response shapes; model lives in URL path |
|
||||
| `bedrock_test.go` | — | InvokeModel + Converse usage shapes; AWS event-stream content-type → `ErrStreamingUnsupported` on buffered `ParseResponse` |
|
||||
| `sse.go` | 117 | `bufio`-backed scanner; CRLF normalised; trailing-event handling |
|
||||
| `sse_test.go` | 175 | 12 tests; fixture replay + multiline + size limits |
|
||||
| `parser_test.go` | 53 | `Parsers()`, `DetectParser`, provider enum values |
|
||||
| `errors.go` | 31 | 6 sentinels: `Err{Unknown,Unsupported}Provider/Model`, `Err{NotLLM,Malformed}Response`, `ErrStreamingUnsupported`, `ErrMalformedRequest` |
|
||||
| `pricing/pricing.go` | 421 | `Loader`, `Table`, `Entry`; embedded defaults + atomic swap + mtime reload |
|
||||
| `pricing/pricing_unix.go` | 69 | `O_NOFOLLOW` + fstat-from-FD + 1 MiB cap |
|
||||
| `pricing/pricing_other.go` | 21 | Stub returning "not supported on this platform" |
|
||||
| `pricing/pricing_test.go` | 432 | 21 tests — symlink rejection, reload race, path traversal, oversize |
|
||||
| `pricing/defaults_pricing.yaml` | 85 | go:embed source of truth |
|
||||
| `fixtures/*` | 21–59 | OAI chat/responses/stream + Anthro messages/stream + pricing starter |
|
||||
|
||||
## Request body → parser dispatch
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[HTTP request<br/>URL + JSON body] --> B{ParserByName?<br/>provider_id config set}
|
||||
B -- yes --> P[matched Parser]
|
||||
B -- no --> C[DetectParser]
|
||||
C --> D{loop Parsers<br/>OpenAIParser, AnthropicParser}
|
||||
D -- DetectFromURL match --> P
|
||||
D -- no match --> X[ok=false<br/>middleware skips]
|
||||
P --> E[ParseRequest body]
|
||||
E -->|err: ErrMalformedRequest| Y[middleware emits provider only]
|
||||
E --> F[RequestFacts<br/>model + stream]
|
||||
P --> G[ExtractPrompt body]
|
||||
G --> H[joinMessages<br/>extractContentParts<br/>decodeStringOrJoin]
|
||||
H --> I[prompt text<br/>or empty]
|
||||
F --> J[stamps llm.model + llm.stream]
|
||||
I --> K[stamps llm.request_prompt_raw<br/>subject to capture_prompt gate]
|
||||
```
|
||||
|
||||
OpenAI's URL hints
|
||||
([openai.go:27–33](../../../proxy/internal/llm/openai.go)) include
|
||||
both `/v1/chat/completions` and the bare `/chat/completions` — the latter
|
||||
covers Cloudflare AI Gateway, which rewrites the canonical version segment.
|
||||
Anthropic's hints are `/v1/messages` and `/v1/complete`
|
||||
([anthropic.go:14–17](../../../proxy/internal/llm/anthropic.go)).
|
||||
Both implementations use case-insensitive substring matching so a proxy prefix
|
||||
strip / rewrite doesn't defeat detection.
|
||||
|
||||
`ParserByName` ([parser.go:93–103](../../../proxy/internal/llm/parser.go))
|
||||
is the **agent-network bypass**: the synthesiser knows which parser to use
|
||||
because it built the synth service from the catalog, so it stamps
|
||||
`provider_id` on the parser config and the middleware skips URL sniffing
|
||||
entirely. This is what makes the same parser set work whether the request
|
||||
flows to OpenAI direct, to LiteLLM, to Portkey, or to any gateway with a
|
||||
non-canonical URL shape.
|
||||
|
||||
**Path-routed providers (Vertex AI, Bedrock) bypass both `ParserByName` and
|
||||
`DetectParser`.** The model and the parser surface live in the URL path, so the
|
||||
request middleware extracts them directly (`parseVertexPath` /
|
||||
`parseBedrockPath`) before the parser-selection step. For Vertex the publisher
|
||||
segment picks the parser (`anthropic` → Anthropic parser; `google`/Gemini →
|
||||
none, request denied as unmeterable). For Bedrock the dedicated `BedrockParser`
|
||||
handles the response. Full treatment in
|
||||
[50-path-routed-providers.md](./50-path-routed-providers.md).
|
||||
|
||||
## Streaming response → SSE chunker → response parser → completion + token count
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant U as upstream LLM
|
||||
participant LR as llm_response_parser<br/>(OnResponse)
|
||||
participant S as llm.NewScanner<br/>(SSE framer)
|
||||
participant P as Parser-specific accumulator<br/>(accumulateOpenAIStream<br/>or accumulateAnthropicStream)
|
||||
|
||||
U-->>LR: text/event-stream<br/>(buffered prefix in RespBody)
|
||||
LR->>S: NewScanner(bytes.NewReader(body))
|
||||
loop until EOF or [DONE]
|
||||
S-->>LR: Event{Type, Data}
|
||||
LR->>P: dispatch per event.Type<br/>(OpenAI: data-only<br/>Anthropic: named events)
|
||||
P-->>P: accumulate completion text<br/>track usage from final frame
|
||||
end
|
||||
P-->>LR: llm.Usage + completion string
|
||||
LR->>LR: appendUsage stamps<br/>llm.{input,output,total,cached_input,cache_creation}_tokens
|
||||
LR->>LR: truncateCompletion(3500 bytes, rune-safe)
|
||||
LR->>LR: redactPII if redact_pii && captureCompletion
|
||||
```
|
||||
|
||||
`Scanner.Next`
|
||||
([sse.go:44–87](../../../proxy/internal/llm/sse.go)) returns one
|
||||
event per `\n\n` boundary; multiple `data:` lines join with `\n`; comment lines
|
||||
(starting with `:`) are skipped per the SSE spec; a trailing event without a
|
||||
closing blank line is still returned before `io.EOF` so a server that closes
|
||||
the connection cleanly doesn't lose the last frame
|
||||
([sse.go:55–58](../../../proxy/internal/llm/sse.go)). CRLF is
|
||||
normalised in `trimEOL` so fixtures captured from live servers replay
|
||||
unchanged.
|
||||
|
||||
## Per-provider
|
||||
|
||||
### OpenAI
|
||||
|
||||
[openai.go:54–67](../../../proxy/internal/llm/openai.go) defines
|
||||
`openAIRequest` with three prompt fields: `messages` (Chat Completions),
|
||||
`prompt` (legacy), `input` (Responses API). The decoder uses
|
||||
`json.RawMessage` so each shape is parsed lazily.
|
||||
|
||||
`ParseResponse`
|
||||
([openai.go:117–146](../../../proxy/internal/llm/openai.go))
|
||||
accepts both naming conventions: Chat Completions returns
|
||||
`prompt_tokens`/`completion_tokens`, Responses API returns
|
||||
`input_tokens`/`output_tokens`. `pickInt64` prefers Responses-API names and
|
||||
falls back — same parser handles both endpoints without per-route config.
|
||||
`openAICachedTokens` mirrors the fallback for
|
||||
`input_tokens_details.cached_tokens` vs `prompt_tokens_details.cached_tokens`.
|
||||
|
||||
**Key invariant:** `CachedInputTokens` for OpenAI is a SUBSET of
|
||||
`InputTokens`. The cost meter clamps to guard against malformed upstream
|
||||
responses where `cached > total`.
|
||||
|
||||
### Anthropic
|
||||
|
||||
[anthropic.go:37–49](../../../proxy/internal/llm/anthropic.go)
|
||||
defines `anthropicRequest` covering Messages API (`system` + `messages[]`)
|
||||
and legacy `/v1/complete` (`prompt` string). `ExtractPrompt` emits
|
||||
`system: <text>` first when present, then per-message `role: content`.
|
||||
|
||||
`ParseResponse`
|
||||
([anthropic.go:82–104](../../../proxy/internal/llm/anthropic.go))
|
||||
fills three independent token buckets: `InputTokens`, `CacheReadInputTokens`,
|
||||
`CacheCreationInputTokens`. Latter two are **additive** (not subset).
|
||||
`TotalTokens` sums all four so downstream dashboards render one "tokens"
|
||||
number without double-counting.
|
||||
|
||||
`ExtractCompletion` walks `content[]` `{type, text}` parts and concatenates
|
||||
non-empty text with newlines, falling back to legacy `completion`.
|
||||
|
||||
### Bedrock
|
||||
|
||||
[bedrock.go](../../../proxy/internal/llm/bedrock.go) implements the
|
||||
`Parser` interface for the AWS Bedrock runtime. Bedrock is **path-routed**: the
|
||||
model lives in the URL (`/model/{id}/{action}`), so the request middleware
|
||||
extracts it (see [50-path-routed-providers.md](./50-path-routed-providers.md))
|
||||
and `ParseRequest` is a deliberate no-op. The parser's real work is on the
|
||||
response leg, covering both Bedrock body shapes:
|
||||
|
||||
- **InvokeModel** — vendor-native. Anthropic-on-Bedrock returns snake_case usage
|
||||
(`input_tokens`, `output_tokens`, `cache_read_input_tokens`,
|
||||
`cache_creation_input_tokens`) with the same additive cache buckets as
|
||||
first-party Anthropic.
|
||||
- **Converse** — unified camelCase (`inputTokens`, `outputTokens`,
|
||||
`totalTokens`). `firstNonZero` folds the two naming conventions into one
|
||||
`Usage`; when Converse omits `totalTokens` the parser sums the buckets.
|
||||
|
||||
`ProviderName()` returns `"bedrock"` — its own `defaults_pricing.yaml` block,
|
||||
keyed by the **normalised** model id (region prefix + version suffix stripped by
|
||||
the request parser). `ParseResponse` returns `ErrStreamingUnsupported` for an
|
||||
AWS binary event-stream content-type (`application/vnd.amazon.eventstream`,
|
||||
`isAWSEventStream`) so the caller routes to the streaming accumulator instead.
|
||||
|
||||
### SSE framing
|
||||
|
||||
`Scanner` is `bufio`-backed, 64 KiB read buffer, 1 MiB max line so a
|
||||
malicious upstream can't blow process memory
|
||||
([sse.go:33–38, 97–100](../../../proxy/internal/llm/sse.go)).
|
||||
`splitField` strips one space after the `:` per the SSE spec. Documented
|
||||
`not safe for concurrent use`; every consumer creates a fresh scanner per
|
||||
response body. Streaming accumulators live in the middleware package
|
||||
([llm_response_parser/streaming.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go))
|
||||
but use `llm.NewScanner` so the framing contract stays here.
|
||||
|
||||
### Pricing catalog
|
||||
|
||||
`Table.Cost`
|
||||
([pricing.go:129–174](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
is the cost formula — most security-relevant math in this module:
|
||||
|
||||
| Provider | Formula |
|
||||
|---|---|
|
||||
| `openai` | `(inTokens − clamped) × InputPer1K + clamped × CachedInputPer1K + outTokens × OutputPer1K` where `clamped = min(cachedInput, inTokens)` |
|
||||
| `anthropic`, `bedrock` | `inTokens × InputPer1K + cachedInput × CacheReadPer1K + cacheCreation × CacheCreationPer1K + outTokens × OutputPer1K` |
|
||||
| default | `inTokens × InputPer1K + outTokens × OutputPer1K` |
|
||||
|
||||
`bedrock` shares the Anthropic additive-cache formula
|
||||
([pricing.go:172-174](../../../proxy/internal/llm/pricing/pricing.go)):
|
||||
Anthropic-on-Bedrock reports the same additive cache buckets, while non-Anthropic
|
||||
Bedrock models (Nova, Llama) simply report zero in those buckets so cost reduces
|
||||
to `input + output`.
|
||||
|
||||
Each per-bucket rate falls back to `InputPer1K` when zero — operators opt in
|
||||
to discounts by setting the field.
|
||||
|
||||
`Loader`
|
||||
([pricing.go:212–268](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
overlays an optional `pricing.yaml` from data-dir on top of the go:embed
|
||||
defaults. Atomic pointer swap means readers never observe a partial update.
|
||||
The mtime-poll reloader (30s default cadence) keeps the previous table on
|
||||
parse failure so cost annotation never goes blank during a botched edit.
|
||||
|
||||
`defaults_pricing.yaml` is the source of truth for built-in pricing.
|
||||
Operator overrides only carry the entries they want to change.
|
||||
|
||||
## Public contracts
|
||||
|
||||
**`Parser` interface**
|
||||
([parser.go:50–66](../../../proxy/internal/llm/parser.go)):
|
||||
|
||||
```go
|
||||
type Parser interface {
|
||||
Provider() Provider
|
||||
ProviderName() string
|
||||
DetectFromURL(path string) bool
|
||||
ParseRequest(body []byte) (RequestFacts, error)
|
||||
ParseResponse(status int, contentType string, body []byte) (Usage, error)
|
||||
ExtractPrompt(body []byte) string
|
||||
ExtractCompletion(status int, contentType string, body []byte) string
|
||||
}
|
||||
```
|
||||
|
||||
Adding a provider means implementing this interface and appending to the
|
||||
slice returned by `Parsers()` ([parser.go:78–84](../../../proxy/internal/llm/parser.go)).
|
||||
Order matters: `DetectFromURL` ties resolve by registration order.
|
||||
`Parsers()` today returns `{OpenAIParser, AnthropicParser, BedrockParser}`.
|
||||
|
||||
**`Provider` enum**
|
||||
([parser.go:8–18](../../../proxy/internal/llm/parser.go)):
|
||||
`ProviderUnknown = 0`, `ProviderOpenAI = 1`, `ProviderAnthropic = 2`,
|
||||
`ProviderBedrock = 3`. Numeric values are persisted in nothing today but treat
|
||||
them as wire-stable — new providers must take fresh numbers.
|
||||
|
||||
**`Pricing` lookup**
|
||||
([pricing.go:129](../../../proxy/internal/llm/pricing/pricing.go)):
|
||||
|
||||
```go
|
||||
func (t *Table) Cost(provider, model string, inTokens, outTokens, cachedInput, cacheCreation int64) (float64, bool)
|
||||
```
|
||||
|
||||
Nil-safe: `t.Cost` on a nil receiver returns `(0, false)`
|
||||
([pricing.go:130–132](../../../proxy/internal/llm/pricing/pricing.go)).
|
||||
`ok=false` means provider or model is absent from the loaded table; the caller
|
||||
emits `cost.skipped=unknown_model`.
|
||||
|
||||
## Invariants
|
||||
|
||||
1. **Cross-platform pricing build.** `pricing_unix.go` carries the only
|
||||
functional `loadPricing` (uses `syscall.O_NOFOLLOW` and `f.Stat()` on an
|
||||
open descriptor — both Unix-only). `pricing_other.go` is a build-tag
|
||||
fallback that returns `"not supported on this platform"`
|
||||
([pricing_other.go:14–16](../../../proxy/internal/llm/pricing/pricing_other.go)).
|
||||
The proxy is Linux-only in production today; a Windows port needs an
|
||||
equivalent path-as-handle implementation. Reviewers building on Windows
|
||||
should expect this surface to return an error at startup if an override
|
||||
file is configured.
|
||||
|
||||
2. **SSE scanner handles partial chunks.** A buffered prefix that doesn't end
|
||||
in `\n\n` still yields its accumulated event before `io.EOF`
|
||||
([sse.go:55–58](../../../proxy/internal/llm/sse.go)). Tests:
|
||||
`TestSSEScanner_OpenAIFixture`, `TestSSEScanner_AnthropicFixture`,
|
||||
`TestSSEScanner_MultilineData`, `TestSSEScanner_CRLF`. The streaming
|
||||
accumulators ride on this: `accumulateAnthropicStream` and
|
||||
`accumulateOpenAIStream` `break` on any scanner error to return partial
|
||||
usage rather than aborting
|
||||
([streaming.go:68–73, 144–150](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go)).
|
||||
|
||||
3. **`defaults_pricing.yaml` is the source of truth.** Compiled into the
|
||||
binary via `//go:embed`
|
||||
([pricing.go:29–30](../../../proxy/internal/llm/pricing/pricing.go)).
|
||||
`DefaultTable()` parses once and panics on parse failure
|
||||
([pricing.go:42–49](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
— by design: a broken embedded YAML must not ship to production.
|
||||
|
||||
4. **Loader path validation.** `resolveMiddlewareDataPath`
|
||||
([pricing.go:370–394](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
rejects absolute paths, traversal segments, and basenames that fail
|
||||
`basenameRegex = ^[a-zA-Z0-9._-]+$`. The resolved path must remain
|
||||
inside `baseDir` even after `filepath.Clean`. Tests:
|
||||
`TestNewLoader_PathValidation`, `TestNewLoader_PathValidation_Extended`,
|
||||
`TestNewLoader_SymlinkOutsideBaseDirRejected`, `TestNewLoader_SymlinkRejected`.
|
||||
|
||||
5. **Unix loader symlink safety.** `O_NOFOLLOW` on open, `f.Stat()` on the
|
||||
open descriptor (never re-stat by path), `info.Mode().IsRegular()` check,
|
||||
`io.LimitReader(f, maxPricingBytes+1)` with a final size assertion
|
||||
([pricing_unix.go:25–57](../../../proxy/internal/llm/pricing/pricing_unix.go)).
|
||||
A mid-read symlink swap is detected because the fstat is on the original
|
||||
fd. Test: `TestNewLoader_RejectsOversizedFile_FixesM4`.
|
||||
|
||||
6. **`yaml.NewDecoder(...).KnownFields(true)`**
|
||||
([pricing.go:397–398](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
rejects YAML files that carry fields not in the schema. A typo in an
|
||||
operator override file fails loud instead of silently zeroing rates.
|
||||
|
||||
## Things to scrutinise
|
||||
|
||||
**Correctness.** Verify OpenAI cached-prompt clamp at
|
||||
[pricing.go:147–149](../../../proxy/internal/llm/pricing/pricing.go)
|
||||
short-circuits before subtraction. `Anthropic.TotalTokens` sums all four
|
||||
buckets (in + out + cache_read + cache_creation) — downstream dashboards
|
||||
need to know this differs from `input + output`.
|
||||
`OpenAIParser.ExtractPrompt` falls through `messages → input → prompt`; a
|
||||
request sending all three reports only `messages` (uncommon but worth
|
||||
noting).
|
||||
|
||||
**Security.** `Scanner.maxLine = 1 MiB`; a 2 MiB single-line `data:` event
|
||||
errors from `Scanner.Next` and both accumulators stop with partial usage.
|
||||
Pricing file 1 MiB cap is orders of magnitude larger than realistic. Confirm
|
||||
new schema additions are mirrored in both `pricingFile` and `Entry`;
|
||||
`KnownFields(true)` will reject silently-typo'd operator overrides
|
||||
otherwise.
|
||||
|
||||
**Concurrency.** `Loader.table` is `atomic.Pointer[Table]`; readers never
|
||||
block or see a torn table. `Loader.Reload` is one goroutine, cancelled via
|
||||
context (`TestLoader_ReloadBackgroundLoopCancellation`). `DefaultTable()`
|
||||
uses `sync.Once`. Per-call `Scanner` instances mean no shared state across
|
||||
concurrent response-parser calls.
|
||||
|
||||
**Perf.** `Table.Cost` is two map lookups + multiplications, O(1).
|
||||
`Scanner.Next` is one `ReadString('\n')` per line. Pricing reload poll 30s.
|
||||
|
||||
**Observability.** Reload failures count via `metric.Int64Counter` keyed
|
||||
`plugin`; warning log rate-limited at 5 min so a broken file doesn't flood.
|
||||
Parser errors return sentinels — middleware uses `errors.Is` to map to the
|
||||
right `cost.skipped` reason.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| File | Tests | Coverage highlights |
|
||||
|---|---:|---|
|
||||
| `parser_test.go` | 3 | `Parsers()` shape lock, `DetectParser` URL matrix, provider enum stability |
|
||||
| `openai_test.go` | 11 | Chat Completions + Responses API + legacy `prompt`; cached-tokens subset for both naming conventions; fixture replays |
|
||||
| `anthropic_test.go` | 7 | Messages + legacy `/v1/complete`; streaming REJECTED on `ParseResponse` (must use scanner); fixture replays |
|
||||
| `sse_test.go` | 12 | Fixture replay both providers; multiline `data:`; CRLF; comment skip; trailing-event-without-blank-line; oversize rejection |
|
||||
| `pricing/pricing_test.go` | 21 | Provider-shape switch; cached-rate fallback; cached-clamp; symlink rejection (target outside basedir + symlink to file); path validation matrix; oversize rejection; reload-keeps-previous-on-parse-error; mtime change detection; goroutine cancellation |
|
||||
|
||||
**Fixtures** ([proxy/internal/llm/fixtures/](../../../proxy/internal/llm/fixtures/)):
|
||||
`openai_chat_completion.json` (chat.completions with usage),
|
||||
`openai_responses.json` (Responses API shape),
|
||||
`openai_stream.txt` (3 deltas + usage + `[DONE]`),
|
||||
`anthropic_messages.json` (Messages API non-streaming),
|
||||
`anthropic_stream.txt` (full 7-event sequence: message_start →
|
||||
content_block_{start,delta×2,stop} → message_delta (usage) → message_stop),
|
||||
`pricing.yaml` (realistic-pricing starter for operator overrides).
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Sibling: [31-proxy-middleware-builtin.md](./31-proxy-middleware-builtin.md)
|
||||
— the chain that calls `llm.Parsers()`, `llm.ParserByName`,
|
||||
`llm.NewScanner`, `pricing.NewLoader`.
|
||||
- Path-routed providers (Vertex AI + Bedrock), credential syntax, and the
|
||||
Bedrock AWS event-stream accumulator:
|
||||
[50-path-routed-providers.md](./50-path-routed-providers.md).
|
||||
- Direct callers: `llm_request_parser/middleware.go:82–94`,
|
||||
`llm_response_parser/middleware.go:113–123`,
|
||||
`llm_response_parser/streaming.go:65, 142`, `cost_meter/factory.go:49–57`.
|
||||
- Related elsewhere: the agent-network synthesiser stamping `provider_id`
|
||||
is covered in the management-side module guide; proxy server boot +
|
||||
`FactoryContext` construction is covered in the proxy-framework guide.
|
||||
194
docs/agent-networks/modules/33-proxy-runtime.md
Normal file
194
docs/agent-networks/modules/33-proxy-runtime.md
Normal file
@@ -0,0 +1,194 @@
|
||||
# proxy/runtime — translate + serve + log
|
||||
|
||||
> **Risk level:** High — every config push from management is translated here, and the chain runs on every HTTP request to a synth target.
|
||||
> **Backward-compat impact:** Additive at the wire (`PathTargetOptions.middlewares`, `agent_network`, `disable_access_log`, capture caps) and on the proxy `Server` struct (`MiddlewareDataDir`, `MiddlewareCaptureBudgetBytes`). Non-agent-network targets stay on the no-middleware fast path.
|
||||
|
||||
## Module boundary
|
||||
|
||||
Turns the synth-service wire format from `ProxyService.SyncMappings`/`GetMappingUpdate` into in-process middleware chains and runs them on top of the existing `httputil.ReverseProxy`. Four concerns: (a) **translate** — `proto.MiddlewareConfig` → validated `middleware.Spec` (proxy/middleware_translate.go) + self-register the eight built-ins (proxy/middleware_register.go); (b) **boot + rebuild** — construct the `middleware.Manager`, share the OTel meter, install the live-service check, rebuild per-path chains on every `addMapping`/`modifyMapping` (proxy/server.go); (c) **serve** — resolve chain at request time, capture bodies under a global budget, invoke `RunRequest`/`RunResponse`/`RunTerminal`, render deny responses, apply `UpstreamRewrite` (proxy/internal/proxy/reverseproxy.go); (d) **log + tag** — emit access-log entries with the new `agent_network` flag, gate emission on `EnableLogCollection` via `DisableAccessLog` (proxy/internal/accesslog).
|
||||
|
||||
**Inert for non-agent-network targets**: nil or empty chain → existing fast path (reverseproxy.go:127-139); `SuppressAccessLog` defaults false so the access-log middleware emits unchanged.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| proxy/middleware_translate.go | proto→Spec translation; slot/failmode/timeout mapping; caps |
|
||||
| proxy/middleware_translate_test.go | translator unit tests |
|
||||
| proxy/middleware_register.go | blank-imports the eight builtins for `init()` registration |
|
||||
| proxy/server.go | `initMiddlewareManager`, `rebuildMiddlewareChains`, `isLiveService`, `buildMiddlewareBindings`, new Server fields, `protoToMapping` stamps AgentNetwork/DisableAccessLog/CaptureConfig/Middlewares |
|
||||
| proxy/internal/proxy/reverseproxy.go | `WithMiddlewareManager`, chain dispatch, body capture, `applyUpstreamRewrite`/`Headers`, `buildRequestInput`, response-leg respInput identity fields |
|
||||
| proxy/internal/proxy/reverseproxy_test.go | `TestBuildRequestInput_PropagatesIdentityAndGroups` |
|
||||
| proxy/internal/proxy/context.go | `agentNetwork`, `suppressAccessLog`, `userGroupNames` on `CapturedData` |
|
||||
| proxy/internal/proxy/servicemapping.go | new `PathTarget` fields |
|
||||
| proxy/internal/proxy/agent_network_chain_realstack_test.go | end-to-end self-contained chain test |
|
||||
| proxy/internal/accesslog/logger.go | `logEntry.AgentNetwork` → `proto.AccessLog` |
|
||||
| proxy/internal/accesslog/middleware.go | reads `GetAgentNetwork()`; gates `l.log` on `!GetSuppressAccessLog()` |
|
||||
| proxy/internal/accesslog/middleware_test.go | suppress/default/preserves-usage assertions |
|
||||
| proxy/internal/auth/middleware_test.go | tunnel-peer group propagation contract |
|
||||
| proxy/internal/metrics/metrics.go | `Meter()` getter for the middleware manager |
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### Synth-service ingestion → translate → register → serve
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Management SyncMappings/GetMappingUpdate] --> B["processMappings\nserver.go:1492"]
|
||||
B --> C{Mapping type}
|
||||
C -->|CREATED| D["addMapping → setupHTTPMapping → updateMapping"]
|
||||
C -->|MODIFIED| E["modifyMapping → cleanupMappingRoutes → setupHTTPMapping → updateMapping"]
|
||||
C -->|REMOVED| F["removeMapping → cleanupMappingRoutes → invalidateMiddlewareChains"]
|
||||
D --> G["protoToMapping\nserver.go:2181"]
|
||||
E --> G
|
||||
G --> H["translateMiddlewareConfigs\nmiddleware_translate.go:55"]
|
||||
G --> I["translateMiddlewareCaptureConfig\nmiddleware_translate.go:18"]
|
||||
H --> J["[]middleware.Spec on PathTarget"]
|
||||
I --> K["*bodytap.Config on PathTarget"]
|
||||
J --> L["proxy.AddMapping\nservicemapping.go:118"]
|
||||
K --> L
|
||||
L --> M["rebuildMiddlewareChains\nserver.go:2017 → Manager.Rebuild"]
|
||||
F --> N["Manager.Invalidate(serviceID)"]
|
||||
```
|
||||
|
||||
### Per-request lifecycle through the chain + accesslog
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
participant C as Client
|
||||
participant M as accesslog.Middleware
|
||||
participant A as auth.Middleware (Protect)
|
||||
participant RP as ReverseProxy.ServeHTTP
|
||||
participant CH as middleware.Chain
|
||||
participant U as Upstream
|
||||
C->>M: HTTP request
|
||||
M->>M: NewCapturedData(requestID), WithCapturedData(ctx)
|
||||
M->>A: next.ServeHTTP
|
||||
A->>A: Private → ValidateTunnelPeer → stamp UserID/Email/Groups/GroupNames/AuthMethod
|
||||
A->>RP: next.ServeHTTP
|
||||
RP->>RP: findTargetForRequest → targetResult
|
||||
RP->>RP: stamp ServiceID/AccountID/AgentNetwork/SuppressAccessLog on CapturedData
|
||||
RP->>RP: resolveChain via Manager.ChainFor
|
||||
alt chain == nil or Empty
|
||||
RP->>U: httputil.ReverseProxy.ServeHTTP (fast path)
|
||||
else chain non-empty
|
||||
RP->>RP: bodytap.CaptureRequest (global budget)
|
||||
RP->>CH: RunRequest
|
||||
CH-->>RP: denyOutput? requestMeta + upstreamRewrite
|
||||
alt deny
|
||||
RP->>C: RenderDenyResponse
|
||||
else allow
|
||||
RP->>RP: capturingWriter + applyUpstreamRewrite/Headers
|
||||
RP->>U: httputil.ReverseProxy.ServeHTTP(respWriter)
|
||||
U-->>RP: response
|
||||
RP->>CH: RunResponse (respInput carries UserGroups)
|
||||
RP->>CH: RunTerminal (merged request+response metadata)
|
||||
end
|
||||
end
|
||||
RP-->>M: handler returns
|
||||
M->>M: build logEntry incl. AgentNetwork
|
||||
alt SuppressAccessLog == true
|
||||
M->>M: skip l.log; still trackUsage
|
||||
else default
|
||||
M->>M: l.log → goroutine SendAccessLog
|
||||
end
|
||||
```
|
||||
|
||||
### EnableLogCollection suppression path
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
S["agentnetwork.Settings.EnableLogCollection"] --> B["synthesizer: target.DisableAccessLog = !EnableLogCollection"]
|
||||
B --> P["proto PathTargetOptions.disable_access_log (field 13)"]
|
||||
P --> T["protoToMapping reads GetDisableAccessLog()\nserver.go:2211"]
|
||||
T --> M["PathTarget.DisableAccessLog\nservicemapping.go:47"]
|
||||
M --> R["ServeHTTP: cd.SetSuppressAccessLog\nreverseproxy.go:106"]
|
||||
R --> G["accesslog middleware: if !GetSuppressAccessLog l.log\nmiddleware.go:95"]
|
||||
R --> U["trackUsage unconditional — bandwidth telemetry preserved"]
|
||||
```
|
||||
|
||||
**Ingestion** lands as a `ProxyMapping` batch on `handleSyncMappingsStream`/`handleMappingStream`. `processMappings` dispatches to `addMapping`/`modifyMapping`/`removeMapping`; HTTP goes `setupHTTPMapping → updateMapping → protoToMapping`. `protoToMapping` (server.go:2181) is the single translation surface that materialises `[]middleware.Spec`, `*bodytap.Config`, `AgentNetwork`, `DisableAccessLog` onto each `PathTarget`; `updateMapping` finishes with `s.proxy.AddMapping(m)` (atomic swap under `mappingsMux`) and `s.rebuildMiddlewareChains(svcID, m)`.
|
||||
|
||||
At **request time** the access-log middleware stamps `CapturedData`; the auth chain runs (Private services lift `peer_group_ids` from `ValidateTunnelPeer` — auth/middleware_test.go:322). `ReverseProxy.ServeHTTP` resolves the chain; nil or empty → original `httputil.ReverseProxy`, no body capture. When a chain matches, body is captured under the global budget, `RunRequest` produces an `UpstreamRewrite` (`llm_router` selects a provider, rewrites scheme/host/path, injects `Authorization`), and `RunResponse`+`RunTerminal` run after the upstream returns. The terminal slot sees the merged metadata bag — that's how `llm_limit_record` ships the consumption sample. The **access-log** addition: `logEntry.AgentNetwork` from `GetAgentNetwork()` onto `proto.AccessLog.AgentNetwork`; the gate at middleware.go:95 honors `EnableLogCollection`, skipping `l.log` but keeping `trackUsage` so bandwidth telemetry survives.
|
||||
|
||||
## Public contracts touched
|
||||
|
||||
- `proxy.Server.MiddlewareDataDir` (string) — base dir for file-backed middleware config (server.go:238-241).
|
||||
- `proxy.Server.MiddlewareCaptureBudgetBytes` (int64) — process-wide capture cap; defaults to 256 MiB (server.go:248-250).
|
||||
- `proxy/internal/proxy.WithMiddlewareManager(*middleware.Manager) Option` — new option on `NewReverseProxy`; nil keeps the fast path (reverseproxy.go:48-56).
|
||||
- `proxy/internal/proxy.PathTarget` adds `Middlewares`, `CaptureConfig`, `AgentNetwork`, `DisableAccessLog` (servicemapping.go:27-51), all zero-default.
|
||||
- `proxy/internal/proxy.CapturedData` adds `agentNetwork`, `suppressAccessLog`, `userGroupNames` behind `sync.RWMutex`; slices deep-copied (context.go:47-66, 183-258).
|
||||
- `accesslog.logEntry.AgentNetwork` + `proto.AccessLog.AgentNetwork` (logger.go:131, 268).
|
||||
- `metrics.Metrics.Meter()` exposes the OTel meter for the middleware manager (metrics.go:53-58).
|
||||
|
||||
## Invariants
|
||||
|
||||
- **Synth-service updates are live (no proxy restart).** Every `MODIFIED` flows through `modifyMapping → cleanupMappingRoutes` (invalidates chains) `→ setupHTTPMapping → updateMapping → rebuildMiddlewareChains`. **ProxyMapping.Private preservation:** the relevant logic lives in `management/internals/shared/grpc/proxy.go:shallowCloneMapping`, not this module, but it surfaces here — if a `MODIFIED` synth service arrives `private=false`, auth skips `ValidateTunnelPeer`, `CapturedData.UserGroups` stays empty, and `llm_router` denies with `llm_policy.no_authorised_provider` until a management restart re-pushes the snapshot. This module assumes `mapping.GetPrivate()` is correct on every batch.
|
||||
- **`EnableLogCollection=false` suppresses access-log writes but middleware still runs.** Gate is one `if !cd.GetSuppressAccessLog()` immediately around `l.log(entry)` (middleware.go:95); `trackUsage` runs below the gate. Locked by `TestMiddleware_SuppressAccessLog_PreservesUsageTracking` (middleware_test.go:139).
|
||||
- **`agent_network` flag on access-log entries is set when the chain processed the request.** Source `target.AgentNetwork`, stamped at reverseproxy.go:105, read at accesslog/middleware.go:86.
|
||||
- **auth → builtin group propagation.** `Protect` writes `UserGroups`/`UserGroupNames`; `buildRequestInput` (reverseproxy.go:333) copies them into `middleware.Input`. The response-leg `respInput` (reverseproxy.go:196-223) also carries `UserEmail`/`UserGroups`/`UserGroupNames` — `llm_limit_record` needs `UserGroups` to ship `group_ids` so management's group-targeted budget rules match (comment at reverseproxy.go:211-215).
|
||||
- **Empty chains stay on the fast path.** `ServeHTTP` skips body capture and the run sequence when `chain == nil || chain.Empty()` (reverseproxy.go:127).
|
||||
- **Self-registration is the only way a builtin reaches the registry.** `middleware_register.go` blank-imports each builtin; `init()` adds the factory to `mwbuiltin.DefaultRegistry()`. Missing it → translator drops the entry with a warn (translate.go:97).
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
- **Translate edge cases** — drops on nil cfg, empty ID, unknown ID, UNSPECIFIED slot; each logs one warn; volume bounded by `MaxMiddlewaresPerChain`.
|
||||
- **Re-translate without dropping in-flight requests** — `Manager.Rebuild` is the only call from `rebuildMiddlewareChains`. Reverse proxy reads `ChainFor` once per request (reverseproxy.go:327) and runs the captured `*Chain` for the whole request. Verify in module 30 that `Rebuild` swaps atomically.
|
||||
- **ProxyMapping.Private preservation** — enforced management-side in `shallowCloneMapping`. Proxy-side regression catches: `TestProtect_PrivateService_TunnelPeerGroupsPropagate` + the integration test.
|
||||
- **Body-capture cleanup** — `defer releaseBudget()` (reverseproxy.go:145) and `defer capturingWriter.Release()` (reverseproxy.go:180) must run on every return; confirm no future `return` lands between acquisition and defer.
|
||||
- **`applyUpstreamRewrite` clones the URL** — `cloned := *orig` value-copies `*url.URL`; safe because overwritten fields are strings, not slices/maps (reverseproxy.go:285-292).
|
||||
|
||||
### Security
|
||||
- **Translate validates every config** — registry membership rejects unknown IDs; UNSPECIFIED slot drops; ID-less drops; raw config copied (not aliased) at translate.go:109.
|
||||
- **`AuthHeader`/`StripHeaders` only reachable via `UpstreamRewrite`** — regular mutation surface goes through the framework denylist (`Authorization`/`Cookie` blocked); only the router middleware can replace `Authorization` (reverseproxy.go:296-304). Confirm in module 30 nothing outside the proxy-trusted path populates `UpstreamRewrite.AuthHeader`.
|
||||
- **`stampNetBirdIdentity` strips client-sent values first** (reverseproxy.go:742-743) — anti-spoof for `X-NetBird-User`/`X-NetBird-Groups`; control chars filtered; comma-bearing labels dropped (reverseproxy_test.go:1217/:1243/:1193).
|
||||
- **Auth → group propagation** — `auth/middleware_test.go:322` and `:366` cover the contract. If auth ever stops calling `ValidateTunnelPeer` for Private services, every agent-network request silently denies.
|
||||
|
||||
### Concurrency
|
||||
- **Chain replacement under in-flight requests** — `findTargetForRequest` takes `mappingsMux.RLock`; `AddMapping` writes. `resolveChain` calls `ChainFor` once; even if `Rebuild` swaps mid-request, in-flight requests keep running on the captured pointer.
|
||||
- **`CapturedData` mutation across slots** — accessors take `sync.RWMutex`; slices deep-copied on both Set and Get. Verify no caller mutates the returned slice expecting it to land back.
|
||||
- **`Manager.Invalidate` race** — `removeMapping` invalidates after `cleanupMappingRoutes`; mapping read happens before chain resolution, so requests before invalidate run captured chains; later ones fail `findTargetForRequest`.
|
||||
- **`Logger.log` goroutine** — `logSem` caps at `maxLogWorkers = 4096`; overflow → `dropped.Add(1)` + debug log. Middleware test uses a buffered channel and 150ms negative-assertion window — review whether 150ms holds on slow CI.
|
||||
|
||||
### Backward compatibility
|
||||
- **Non-agent-network services unaffected** — `protoToMapping` reads new fields only when `opts != nil`; defaults leave `Middlewares`/`CaptureConfig` nil → chain resolves nil → fast path. Existing `reverseproxy_test.go` (non-chain) still passes.
|
||||
- **`disable_access_log` is proto field 13, default false** — every existing target unset; gate is no-op. Locked by `TestMiddleware_SuppressAccessLog_DefaultEmitsLog` (middleware_test.go:104).
|
||||
- **`Server` additions optional** — 256 MiB default when `MiddlewareCaptureBudgetBytes ≤ 0` (server.go:1997-2000).
|
||||
|
||||
### Performance
|
||||
- **Translate cost per push** — O(n) with per-entry registry lookup and `config_json` copy; negligible vs. the upstream gRPC unmarshal.
|
||||
- **Empty-chain hot path** — one `ChainFor` map lookup + one `chain.Empty()` check; no allocation delta vs. pre-PR.
|
||||
- **Body capture buffer churn** — `bodytap.CaptureRequest` allocates `MaxRequestBytes` per chain-hitting request; `releaseBudget` ties allocation to the 256 MiB proxy-wide budget. Confirm in module 30 the budget is a hard cap.
|
||||
|
||||
### Observability
|
||||
- **Metrics** — `Metrics.Meter()` shared with `middleware.NewMetrics` (server.go:1990-1993) so middleware instruments land in the same prometheus exporter. No new metrics defined here.
|
||||
- **Access-log accuracy** — every entry carries `AgentNetwork`; terminal-slot metadata merged into `CapturedData.Metadata` (reverseproxy.go:238-241).
|
||||
- **Deny logs at `Infof`** (reverseproxy.go:170) — review whether `Info` is too noisy at high deny rates; consider Debug or rate-limit.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| proxy/middleware_translate_test.go | Empty/nil → nil; field preservation; unknown ID skip; nil registry permissive; timeout clamping; fail-mode + slot incl. UNSPECIFIED-drop; empty-ID drop; truncation above + at `MaxMiddlewaresPerChain` |
|
||||
| proxy/internal/proxy/reverseproxy_test.go | Rewrite host/headers/cookies/query; trusted proxy; path forwarding; classifyProxyError; X-NetBird-User/Groups anti-spoof + CSV-join + control-char/comma rejection + fallback-to-ID; `TestBuildRequestInput_PropagatesIdentityAndGroups` (UserGroups/Email/GroupNames/AgentNetwork reach `middleware.Input`) |
|
||||
| proxy/internal/proxy/agent_network_chain_realstack_test.go | **The end-to-end integration test.** Drives a real agent-network request through `ReverseProxy.ServeHTTP` with the chain the synthesizer produces, against an in-process management gRPC (bufconn) backed by a real sqlite store + real `agentnetwork.Manager`, plus an `httptest` upstream — no external infrastructure or real LLM. Guarantees: (1) response-leg `respInput` carries `UserGroups` so `llm_limit_record` ships non-empty `group_ids` and the admin-group consumption row increments; (2) `RedactPii=true` redacts both prompt and completion on captured metadata; (3) the full chain runs against a real management stack. **Line 189-211 inlines the proto→Spec mapping** instead of calling the proxy's private `translateMiddlewareConfig` — keep that inline mirror in sync with `proxy/middleware_translate.go` or the test silently diverges from production. |
|
||||
| proxy/internal/accesslog/middleware_test.go | `SuppressAccessLog=true` skips `SendAccessLog` (150ms negative wait); default emits one send (2s positive); usage tracking runs under suppression |
|
||||
| proxy/internal/auth/middleware_test.go | `TestProtect_PrivateService_TunnelPeerGroupsPropagate` proves `peer_group_ids` reach `CapturedData.UserGroups`; `TestProtect_PrivateService_TunnelPeerDenied` proves rejected peers 403 without reaching the handler |
|
||||
|
||||
The integration test runs in a few seconds with no external infrastructure — exercising the real synthesizer, `Manager.Rebuild`, `ServeHTTP` dispatch, and `llm_limit_record` writing a real consumption row through the real `agentnetwork.Manager` over real gRPC.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- **Translator does not validate `RawConfig` JSON** — factory's job at `New([]byte)`. Confirm in module 30 that a per-binding factory failure doesn't poison the rest of the chain.
|
||||
- **No throttle on management push rate** — every `MODIFIED` triggers `Manager.Rebuild`. Mitigation upstream.
|
||||
- **Streaming responses (SSE)** — body capture is streaming-aware, but response-leg middleware runs only after the response completes; long SSE streams delay `llm_limit_record` until close.
|
||||
- **OIDC-only path doesn't carry tunnel-peer groups** — agent-network synth services rely on the Private tunnel-peer path; JWT groups claim is the only carrier for non-Private OIDC.
|
||||
- **`agent_network` flag on L4 entries** not added; HTTP-only.
|
||||
- **`mw.capture.bypass_reason` metadata key** documented at reverseproxy.go:151,184; namespace this in module 30/31 to avoid collisions.
|
||||
|
||||
## Cross-references
|
||||
- Upstream: [shared/api](10-shared-api.md), [proxy/middleware-framework](30-proxy-middleware-framework.md), [proxy/middleware-builtin](31-proxy-middleware-builtin.md), [proxy/llm-parsers](32-proxy-llm-parsers.md)
|
||||
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level: [../00-overview.md](../00-overview.md)
|
||||
228
docs/agent-networks/modules/40-dashboard.md
Normal file
228
docs/agent-networks/modules/40-dashboard.md
Normal file
@@ -0,0 +1,228 @@
|
||||
# dashboard — UI for agent-networks
|
||||
|
||||
This module documents code that lives in the **dashboard repo** (under
|
||||
`src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`), not
|
||||
in this repo. It is co-located here so backend readers see the full picture.
|
||||
|
||||
> **Risk level:** Medium. The new surface is isolated under `src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`, but it also reshapes the sidebar, splits `/peers`, renames `reverse-proxy/clusters` → `self-hosted-proxies`, and overlays the Control Center graph. Regressions here would be cross-cutting.
|
||||
> **Backward-compat impact:** Additive on the API side. Breaking on URL/navigation: `/peers` redirects to `/peers/devices` (src/app/(dashboard)/peers/page.tsx:7-15), `/reverse-proxy/clusters` was renamed to `/reverse-proxy/self-hosted-proxies`, the sidebar lost Access Control / Networks / Reverse Proxy / DNS / standalone Guardrails / Consumption / Activity (Navigation.tsx:165-171 — routes still resolve via URL), and the standalone `/agent-network/{access-log,consumption,global-controls}` routes are gone in favor of `/agent-network/observability`.
|
||||
|
||||
## Module boundary
|
||||
|
||||
The dashboard is the only place an operator interacts with agent-networks: provider catalog, configured providers, policies, guardrails, account-level budget rules, account settings (collection / redaction toggles), per-request access log, and consumption rollups all render, paginate, and edit here. Data flows in via SWR (`useFetchApi`) keyed by REST URL. One big context provider (`src/modules/agent-network/AIProvidersProvider.tsx`) aggregates five resources (providers, policies, guardrails, budget rules, settings) plus the proxy access-log stream filtered to `agent_network=true`, and exposes `add* / update* / toggle* / delete*` mutators that call through `useApiCall` and re-`mutate()` SWR. Pages mount the provider once at the top and compose presentational tables and modals beneath. The control-center page additionally fetches `/agent-network/{providers,policies}` directly (control-center/page.tsx:123-130) to overlay graph nodes.
|
||||
|
||||
## What the UI delivers
|
||||
|
||||
- **AI Observability** page with four tabs: Access Logs, Budget Dashboard,
|
||||
Budget Settings, Log Settings (replaces the standalone access-log,
|
||||
consumption, and global-controls routes).
|
||||
- **Providers** page: provider catalog + connect/edit wizard with per-vendor
|
||||
copy (LiteLLM, Portkey, Bifrost, Cloudflare, Vercel, OpenRouter, custom).
|
||||
- **Policies** page: group → provider authorization with per-policy Limits
|
||||
(minute-granular windows) + guardrail attach.
|
||||
- **Guardrails** page: reusable model-allowlist + prompt-capture sets.
|
||||
- **Account controls**: Log Collection / Prompt Collection / Redact PII toggles.
|
||||
- **Budget rules**: account-level rules reusing the policy Limits UI.
|
||||
- **Control Center overlay**: provider + agent-policy nodes on the graph.
|
||||
- **Navigation + peers reshaping**: peers split into Devices / Agents,
|
||||
`reverse-proxy/clusters` renamed to `self-hosted-proxies`, sidebar
|
||||
repackaged for agent-network focus.
|
||||
|
||||
## Surface added
|
||||
|
||||
### New pages
|
||||
|
||||
| Route | Purpose | Backing module(s) |
|
||||
| ----- | ------- | ----------------- |
|
||||
| `/agent-network` | Redirect to `/agent-network/providers` | page.tsx:7-15 |
|
||||
| `/agent-network/providers` | List + connect providers; header surfaces per-account base URL | providers/page.tsx + AgentProvidersTable + AIProviderModal |
|
||||
| `/agent-network/policies` | Group → Provider authorization with per-policy Limits + Guardrail attach | policies/page.tsx + AgentPoliciesTable + AgentPolicyModal |
|
||||
| `/agent-network/guardrails` | Reusable guardrail sets (model allowlist + prompt capture) | guardrails/page.tsx + AgentGuardrailsTable + AgentGuardrailModal |
|
||||
| `/agent-network/observability` | Tabs: Access Logs / Budget Dashboard / Budget Settings / Log Settings | observability/page.tsx |
|
||||
| `/peers/devices`, `/peers/agents` | Split of `/peers`, shared via `PeersListView` keyed by `kind` | peers/{devices,agents}/page.tsx |
|
||||
| `/reverse-proxy/self-hosted-proxies` | Renamed from `clusters` | self-hosted-proxies/page.tsx |
|
||||
|
||||
Removed in favor of `/agent-network/observability`: `/agent-network/access-log`, `/agent-network/consumption`, `/agent-network/global-controls`.
|
||||
|
||||
### New modules under src/modules/agent-network
|
||||
|
||||
| File | Role |
|
||||
| ---- | ---- |
|
||||
| AIProvidersProvider.tsx (~1158 LOC) | Aggregates every agent-network resource via SWR; normalises snake↔camel; exposes mutators; holds wizard-open state |
|
||||
| AIProviderModal.tsx (~1268 LOC) | Connect / edit provider wizard with per-vendor copy (Bifrost, Portkey, LiteLLM, Cloudflare, Vercel, OpenRouter, custom) |
|
||||
| AIProviderLogo + useProviderCatalog | Catalog-driven brand swatch + SWR hook over `/agent-network/catalog/providers` |
|
||||
| AgentPoliciesTable + AgentPolicyModal + AgentPolicyGuardrailsTab + AgentPolicyLimitsTab | Policies; modal has 3 tabs (Rule, Limits, Guardrails) |
|
||||
| AgentGuardrailsTable + AgentGuardrailModal + AgentGuardrailBrowseModal + AgentGuardrailChecksCell | Guardrails CRUD + attach-from-policy |
|
||||
| AgentBudgetRulesTable + AgentBudgetRuleModal | Account-level budget rules; modal reuses AgentPolicyLimitsTab verbatim |
|
||||
| AgentAccountControlsCard | Three account-wide toggles (Log Collection / Prompt Collection / Redact PII) |
|
||||
| AgentAccessLogTable + AgentAccessLogExpandedRow | Access log on `/events/proxy?agent_network=true` |
|
||||
| AgentConsumptionPanel + AgentConsumptionTable | Token + cost panel: charts + counter table |
|
||||
| table/AgentProvidersTable + AgentProviderActionCell | Providers table + per-row actions |
|
||||
| data/mockData.ts | Domain types and a few residual `MOCK_*` constants (see scrutinize) |
|
||||
|
||||
### Touched non-agent-network areas
|
||||
|
||||
- **control-center**: agent-network overlay (provider + agent-policy nodes); removed the All Networks dropdown; hid the Networks tab in FlowSelector (FlowSelector.tsx:9-14 — enum value kept so `?tab=networks` still type-checks); wrapped `ControlCenterView` in `AIProvidersProvider` (page.tsx:73-83); `agentPolicyNode` clicks routed to a separate state slot (page.tsx:1871-1874). New node renderers: nodes/ProviderNode.tsx, nodes/AgentPolicyNode.tsx (registered at utils/nodes.ts:21-22).
|
||||
- **peers**: Split into Devices and Agents sub-routes; shared via `PeersListView` keyed by `kind` (PeersListView.tsx:24-95). New compact-toolbar `UserFilterSelector` (users/UserFilterSelector.tsx).
|
||||
- **reverse-proxy**: Folder rename `clusters/` → `self-hosted-proxies/`; deleted `ClustersFeaturesCell.tsx`, `ClusterTypeIndicator.tsx`; new ReverseProxyClusterTargetSelector for cluster target type; Private toggle on target modal; body-capture knobs removed; new ReverseProxyEventExpandedRow.
|
||||
- **events**: `ReverseProxyEventsUserCell` rewritten with user + peer fallback (ReverseProxyEventsUserCell.tsx:14-21), shared with the access-log table.
|
||||
- **navigation**: Full repackaging in Navigation.tsx — Agent Network items flattened (no collapsible parent), distinct icons per item; Access Control, Networks, Reverse Proxy, DNS, standalone Guardrails, Consumption, Activity removed (still URL-reachable, per lines 165-171).
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### Page → Provider → Table/Modal hierarchy
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
Nav[Navigation.tsx]
|
||||
Nav --> ProvidersPage[/agent-network/providers/]
|
||||
Nav --> PoliciesPage[/agent-network/policies/]
|
||||
Nav --> GuardrailsPage[/agent-network/guardrails/]
|
||||
Nav --> ObsPage[/agent-network/observability/]
|
||||
|
||||
ProvidersPage --> AIPP1[AIProvidersProvider]
|
||||
PoliciesPage --> AIPP2[AIProvidersProvider]
|
||||
GuardrailsPage --> AIPP3[AIProvidersProvider]
|
||||
ObsPage --> AIPP4[AIProvidersProvider]
|
||||
ObsPage -.wraps.-> GroupsProvider
|
||||
ObsPage -.wraps.-> PeersProvider
|
||||
|
||||
AIPP1 --> ProvTable[AgentProvidersTable]
|
||||
ProvTable --> ProvModal[AIProviderModal]
|
||||
AIPP2 --> PolTable[AgentPoliciesTable]
|
||||
PolTable --> PolModal[AgentPolicyModal]
|
||||
PolModal --> PolGuardTab[AgentPolicyGuardrailsTab]
|
||||
PolModal --> PolLimitsTab[AgentPolicyLimitsTab]
|
||||
PolGuardTab --> GuardBrowse[AgentGuardrailBrowseModal]
|
||||
PolGuardTab --> GuardModal[AgentGuardrailModal]
|
||||
AIPP3 --> GuardTable[AgentGuardrailsTable]
|
||||
GuardTable --> GuardModal
|
||||
AIPP4 --> Tabs[Tabs]
|
||||
Tabs --> AccessLog[AgentAccessLogTable]
|
||||
Tabs --> Consumption[AgentConsumptionPanel]
|
||||
Tabs --> BudgetRules[AgentBudgetRulesTable]
|
||||
Tabs --> AccountCtl[AgentAccountControlsCard]
|
||||
BudgetRules --> BudgetModal[AgentBudgetRuleModal]
|
||||
BudgetModal -.reuses.-> PolLimitsTab
|
||||
```
|
||||
|
||||
### AI Observability tab page
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
Page[AIObservabilityPage] --> RA[RestrictedAccess<br/>permission.services.read]
|
||||
RA --> GP[GroupsProvider]
|
||||
GP --> PP[PeersProvider]
|
||||
PP --> AIP[AIProvidersProvider]
|
||||
AIP --> Tabs[Tabs / TabsList]
|
||||
Tabs --> T1[Access Logs<br/>AgentAccessLogTable]
|
||||
Tabs --> T2[Budget Dashboard<br/>AgentConsumptionPanel]
|
||||
Tabs --> T3[Budget Settings<br/>AgentBudgetRulesTable]
|
||||
Tabs --> T4[Log Settings<br/>AgentAccountControlsCard]
|
||||
T1 -.GET.-> EP[/events/proxy?agent_network=true/]
|
||||
T2 -.GET poll 5s.-> CONS[/agent-network/consumption/]
|
||||
T3 -.GET/PUT.-> BR[/agent-network/budget-rules/]
|
||||
T4 -.GET/PUT.-> ST[/agent-network/settings/]
|
||||
```
|
||||
|
||||
### Data fetch path
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
Page[Page component] --> Prov[AIProvidersProvider]
|
||||
Prov -->|useFetchApi| SWR[(SWR cache<br/>key = URL)]
|
||||
SWR -.GET.-> P[/agent-network/providers/]
|
||||
SWR -.GET.-> POL[/agent-network/policies/]
|
||||
SWR -.GET.-> G[/agent-network/guardrails/]
|
||||
SWR -.GET.-> BR[/agent-network/budget-rules/]
|
||||
SWR -.GET ignoreError.-> ST[/agent-network/settings/]
|
||||
SWR -.GET.-> CAT[/agent-network/catalog/providers/]
|
||||
SWR -.GET pageSize=100.-> EVT[/events/proxy agent_network=true/]
|
||||
Prov --> Mut[useApiCall.post/put/del]
|
||||
Mut -.on success.-> MutateSWR[SWR mutate keys]
|
||||
Prov --> Children[Tables / Modals via useAIProviders]
|
||||
```
|
||||
|
||||
Every list view reaches management through SWR over `/api/agent-network/*`. The provider context maps snake-case payloads to camelCase domain types (`fromAPI`, `policyFromAPI`, `guardrailFromAPI`, `budgetRuleFromAPI`, `settingsFromAPI`, `accessLogFromAPI` — AIProvidersProvider.tsx:138-562) and back via matching `*ToRequest` adaptors. The access log piggy-backs on `/events/proxy` with `agent_network=true&page_size=100` (line 707-709) and decodes LLM-specific fields from per-event `metadata`. Group IDs on events are resolved to current names through the surrounding GroupsProvider catalog (lines 515-521, 717-731) — no extra round trip. Mutators run `*ToRequest`, await `useApiCall.post/put/del`, call SWR `mutate()`, then `notify`. Errors caught and surfaced via `notify` — no exceptions escape into render. The Connect Provider modal's open state lives in the provider itself (`isWizardOpen` at lines 732-735) so the providers-page empty-state CTA and the table's + button share one modal. Control-center re-fetches `/agent-network/{providers,policies}` directly on top of `AIProvidersProvider` — SWR de-dupes but the code path is harder to reason about.
|
||||
|
||||
## Public contracts consumed
|
||||
|
||||
- `GET/POST /api/agent-network/providers`, `PUT/DELETE /:id`
|
||||
- `GET/POST /api/agent-network/policies`, `PUT/DELETE /:id`
|
||||
- `GET/POST /api/agent-network/guardrails`, `PUT/DELETE /:id`
|
||||
- `GET/POST /api/agent-network/budget-rules`, `PUT/DELETE /:id`
|
||||
- `GET/PUT /api/agent-network/settings` (ignoreError-tolerant; 404 = not yet bootstrapped — auto-bootstrap on first provider create via `bootstrap_cluster` field — AIProvidersProvider.tsx:737-760)
|
||||
- `GET /api/agent-network/catalog/providers` (read-only declarative; backend owns vendor list, IDs, brand colors, models, extra_headers, identity_injection — useProviderCatalog.ts:6-95)
|
||||
- `GET /api/agent-network/consumption` (polled every 5s on Budget Dashboard — ConsumptionPanel.tsx:53,65-71)
|
||||
- `GET /api/events/proxy?agent_network=true&page_size=100` (shared with Proxy Events)
|
||||
- `permission?.services?.read` gates every agent-network route via RestrictedAccess.
|
||||
|
||||
`AIProviderId` is a closed union in dashboard types (data/mockData.ts:8-21) but the converter tolerates anything the backend ships — unknown ids fall through to `"custom"` (AIProvidersProvider.tsx:497-506). Catalog values are pure read-through: anything declared in `extra_headers` renders in the modal automatically, copy keyed by header name (`EXTRA_HEADER_UI` in AIProviderModal.tsx:61-89), labeled-fallback for unknown ones.
|
||||
|
||||
## Invariants
|
||||
|
||||
- Provider context wrap order on user-attribution pages: `GroupsProvider > PeersProvider > AIProvidersProvider` (observability/page.tsx:87-89). Reverse it and access-log group resolution silently drops names.
|
||||
- Every agent-network route checks `permission?.services?.read` via `RestrictedAccess` (observability/page.tsx:85, providers/page.tsx:184, policies/page.tsx:53, guardrails/page.tsx:55).
|
||||
- Modal `key={open ? 1 : 0}` pattern is used to force unmount/remount on close so internal `useState` resets between edits (AgentBudgetRuleModal.tsx:60, AgentPolicyModal.tsx:66). Removing this would leak prior-row state into a new-row session.
|
||||
- `mockData.ts` is the canonical home for ALL agent-network domain types; `MOCK_*` constants must never reach a production code path. One leak remains (below).
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
|
||||
- **Tab-state URL hand-off is one-way.** observability/page.tsx:53-58 reads `?tab=` on mount (despite the file comment at line 28 saying URL hand-off is future) but `setTab` does NOT push back, so reload preserves the chosen tab only if it came in via the link. Inconsistent with control-center (page.tsx:1817-1831).
|
||||
- **Provider overlay runs only in `applySingleGroupView` / `applyPeerView`** (control-center/page.tsx:557, 1159-1166). User view does NOT show providers — if agent-network is a primary lens, that's a gap.
|
||||
- **Two useEffects race to invalidate the control-center layout.** page.tsx:1655-1657 drops `layoutInitialized` when `agentPolicies` / `agentProviders` arrive; the main effect (1786-1799) also lists them as deps. Functional but fragile — watch for flash-of-empty-graph.
|
||||
- **`updateProvider` / `updatePolicy` / `updateBudgetRule` use `??` on `enabled`** (AIProvidersProvider.tsx:784, 859, 1018). Toggle paths are safe; any caller sending `enabled: false` thinking "leave it off" gets `existing.enabled` instead. Audit modal callers.
|
||||
- **Form validation in modals is minimal.** Window-seconds picker — mockData.ts:209-215 documents "minimum 60 — one minute" but there is no matching UI guard in PolicyLimitsTab; the backend validator is the enforcement point.
|
||||
|
||||
### Security
|
||||
|
||||
- **No client-side enforcement claims** — every cap, allowlist, and toggle is display + edit; proxy is the source of truth for deny decisions (AccessLogTable.tsx:177-191 renders backend-emitted `denyReason` as-is).
|
||||
- **Prompt display is gated by what the backend stamps.** When `enable_prompt_collection` is OFF the proxy must not put prompt/completion into event metadata; the dashboard renders whatever it gets verbatim (AccessLogTable lines 532-534, AccessLogExpandedRow.tsx:42-57). No UI filter on top of backend collection switches.
|
||||
- Account Controls disables `Redact PII` when `Prompt Collection` is off (AgentAccountControlsCard.tsx:122) and clears it on off-transition (line 100), but relies on backend to enforce the same gate at write — confirm PUT handler rejects `redact_pii=true && enable_prompt_collection=false`.
|
||||
- **Bifrost identity-header overrides**: empty-string vs nil semantics documented in AIProvidersProvider.tsx:772-781 ("omitted = preserve, empty = explicit clear"). Mishandling could leak group attribution to a header the operator thought disabled. Focused read of Bifrost code path in AIProviderModal.tsx recommended.
|
||||
|
||||
### Accessibility
|
||||
|
||||
- Observability TabsList (observability/page.tsx:96-113) uses the shared Tabs component — should inherit Radix roving-tabindex. All four TabsTriggers carry only icon + text, no `aria-label`; fine because text is visible.
|
||||
- Modal focus traps are inherited from the shared Modal; agent-network modals don't override them. Quick keyboard pass recommended.
|
||||
- `EndpointBadge` Copy button (providers/page.tsx:66-76) has an `aria-label`, good.
|
||||
|
||||
### Performance
|
||||
|
||||
- `AgentConsumptionPanel` polls `/agent-network/consumption` every 5s (ConsumptionPanel.tsx:53,70). Tab switches unmount the panel, so the poll stops — verify in network panel.
|
||||
- `AgentAccessLogTable` is hard-capped at 100 rows via `page_size=100` (AIProvidersProvider.tsx:707-709). Server-side pagination is future work; high-traffic tenants miss everything past row 100 — known limitation.
|
||||
- Observability page mounts providers ONCE at page level (observability/page.tsx:87-89); tab switches keep SWR cache hot. Moving the provider mount inside `TabsContent` would re-fetch the access log on every switch.
|
||||
|
||||
### Visual consistency
|
||||
|
||||
- The observability tab style mirrors peers/page.tsx. Outer Tabs `pt-4 pb-0 mb-0`, TabsList `px-8` (observability/page.tsx:94-96) — confirm chrome height matches so the page doesn't visually jump.
|
||||
- Sidebar: `Boxes` for Providers, `AccessControlIcon` for Policies, `TelescopeIcon` for AI Observability (Navigation.tsx:113,120,133). Reusing `AccessControlIcon` makes Policies look identical to the (now hidden) Access Control item — if Access Control ever comes back, they collide.
|
||||
- `AgentNetworkIcon` is used in breadcrumbs on every agent-network page but NOT in the sidebar (per-page icons instead). Deliberate departure — record so it doesn't get reverted.
|
||||
|
||||
## Test coverage
|
||||
|
||||
- **Cypress**: One file (`cypress/e2e/test.cy.ts`) covering only the install-page copy-to-clipboard flow. NOTHING covers agent-network UI.
|
||||
- **Component / unit tests**: `src/utils/version.test.ts` is the only `.test.*` file in the repo. The agent-network modules ship without component tests.
|
||||
- Data-cy hooks exist on key controls: `save-account-controls` (AgentAccountControlsCard.tsx:71), `enable-log-collection`, `enable-prompt-collection`, `redact-pii`, plus existing `data-cy={policy.name}` / `data-cy={provider.name}` on ActiveInactiveRow. Sufficient hooks for Cypress flows; none written yet.
|
||||
- **Tooling gap (pre-existing):** `npm run lint` (`next lint`) is broken in Next 16 — the `lint` subcommand was removed from the Next CLI in 16.x, so the dashboard effectively has no working lint gate. The fix is to add either a flat-config `eslint .` script or wire ESLint via an explicit `eslint-config-next` invocation.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- **`data/mockData.ts` still contains `MOCK_GROUPS`, `MOCK_PROVIDERS`, `MOCK_PEERS`.** Only `MOCK_GROUPS` is referenced from production — AgentPoliciesTable.tsx:45,76 uses it as a name-lookup fallback when a policy references a group ID the real GroupsProvider doesn't know about. `MOCK_PROVIDERS` / `MOCK_PEERS` are unreferenced; safe to delete. The file is `/* eslint-disable */` so dead-code warnings don't flag them.
|
||||
- **Tab-state URL hand-off on observability page is one-way** (read-only).
|
||||
- **Access log hard-capped at 100 rows**; no server-side pagination.
|
||||
- **No optimistic updates.** All mutations are round-trip; failures rollback via SWR revalidation.
|
||||
- **`FlowView.NETWORKS` retained but hidden** from FlowSelector (FlowSelector.tsx:9-14). Old `?tab=networks` links still route to the hidden view because `applyNetworksView` still runs.
|
||||
- **Redirects are not query-preserving** — `router.replace("/peers/devices")` (peers/page.tsx:13) strips any incoming filter params.
|
||||
- **Control-center cross-fetches** `/agent-network/{providers,policies}` directly on top of `AIProvidersProvider`. Could be collapsed.
|
||||
- **Sidebar permanently hides Access Control, Networks, Reverse Proxy, standalone Guardrails, DNS, Activity, Consumption.** Routes still resolve via URL (Navigation.tsx:165-171); intentional.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Upstream API contracts: [shared/api](10-shared-api.md)
|
||||
- Backend persistence: [management/store](20-management-store.md)
|
||||
- Backend handler wiring: [management/handlers + wiring](22-management-handlers-wiring.md)
|
||||
- End-to-end flow narrative: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level overview: [../00-overview.md](../00-overview.md)
|
||||
251
docs/agent-networks/modules/50-path-routed-providers.md
Normal file
251
docs/agent-networks/modules/50-path-routed-providers.md
Normal file
@@ -0,0 +1,251 @@
|
||||
# path-routed providers — Vertex AI + Bedrock
|
||||
|
||||
This guide pulls the **path-routed** provider story together in one place
|
||||
because it crosses the catalog, the synthesiser, the request parser, and the
|
||||
router. The relevant building blocks are the `llm_router` /
|
||||
`llm_request_parser` middlewares
|
||||
([31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md)), the
|
||||
per-provider parser surface ([32-proxy-llm-parsers.md](32-proxy-llm-parsers.md)),
|
||||
and the synthesiser's catalog → `ProviderRoute` mapping
|
||||
([21-management-agentnetwork.md](21-management-agentnetwork.md)).
|
||||
|
||||
Sibling modules: [31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md)
|
||||
(router + request parser) and [32-proxy-llm-parsers.md](32-proxy-llm-parsers.md)
|
||||
(Bedrock parser + pricing).
|
||||
|
||||
---
|
||||
|
||||
## What "path-routed" means
|
||||
|
||||
Most catalog providers carry the model in the request **body** (`{"model": …}`),
|
||||
so `llm_router` selects an upstream by matching the model name against each
|
||||
provider's `Models` claim. Two providers instead carry the model in the **URL
|
||||
path**, so they are routed by path before the model/vendor table is consulted:
|
||||
|
||||
| Catalog id | Style flag | Request path shape |
|
||||
|---|---|---|
|
||||
| `vertex_ai_api` | `IsVertexPathStyle` → `ProviderRoute.Vertex` | `/v1/projects/{project}/locations/{region}/publishers/{publisher}/models/{model}:{action}` |
|
||||
| `bedrock_api` | `IsBedrockPathStyle` → `ProviderRoute.Bedrock` | `/model/{modelId}/{action}` (optionally behind `/bedrock`) |
|
||||
|
||||
The catalog declares the style with
|
||||
[`catalog.IsVertexPathStyle` / `catalog.IsBedrockPathStyle`](../../../management/server/agentnetwork/catalog/catalog.go)
|
||||
and the synthesiser copies the result onto the router route as the `Vertex` /
|
||||
`Bedrock` booleans
|
||||
([synthesizer.go:450-451](../../../management/server/agentnetwork/synthesizer.go)).
|
||||
On the request leg `llm_router.Invoke` dispatches `isVertexPath` / `isBedrockPath`
|
||||
**before** the model lookup
|
||||
([llm_router/middleware.go:138-216](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
so a model the parser extracted from the path can't be claimed by a same-vendor
|
||||
*body-routed* provider (e.g. `claude-*` on `api.anthropic.com`).
|
||||
|
||||
## Google Vertex AI (`vertex_ai_api`)
|
||||
|
||||
### Catalog entry
|
||||
|
||||
`KindProvider`, parser surface left unset on the catalog entry — the request
|
||||
parser picks the parser from the URL **publisher** segment, not from
|
||||
`ParserID`. Upstream host is `<region>-aiplatform.googleapis.com`
|
||||
(`https://aiplatform.googleapis.com` for the `global` location). The catalog
|
||||
lists the Claude-on-Vertex lineup (`claude-opus-4-*`, `claude-sonnet-4-*`,
|
||||
`claude-haiku-4-5`, `claude-fable-5`) at the same per-token rates as the
|
||||
first-party Anthropic entry
|
||||
([catalog.go:333-363](../../../management/server/agentnetwork/catalog/catalog.go)).
|
||||
|
||||
### Credential — service-account OAuth (`keyfile::`)
|
||||
|
||||
Vertex does **not** accept a static API key. The operator sets the provider
|
||||
`api_key` to:
|
||||
|
||||
```
|
||||
keyfile::<base64 of the GCP service-account JSON key>
|
||||
```
|
||||
|
||||
The synthesiser recognises the `keyfile::` prefix in `providerAuthHeader`
|
||||
([synthesizer.go:897-903](../../../management/server/agentnetwork/synthesizer.go)),
|
||||
emits **no** static auth value, and carries the base64 key material on the
|
||||
route as `GCPServiceAccountKeyB64`
|
||||
([factory.go:56-61](../../../proxy/internal/middleware/builtin/llm_router/factory.go)).
|
||||
At request time the router mints a short-lived OAuth2 access token from the key
|
||||
(cloud-platform scope) and injects `Authorization: Bearer <access-token>` —
|
||||
never the key itself
|
||||
([llm_router/middleware.go:621-692](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)):
|
||||
|
||||
- One auto-refreshing `oauth2.TokenSource` is cached per key (keyed by a
|
||||
SHA-256 of the base64 material), so token minting happens once and refreshes
|
||||
amortise across requests.
|
||||
- Mint / refresh is bounded by a 10s timeout HTTP client (`gcpTokenTimeout`) so
|
||||
a slow Google token endpoint can't hang the request.
|
||||
- A malformed key or an unreachable token endpoint fails the request with
|
||||
`llm_policy.upstream_auth_failed` at HTTP **502** (an upstream problem, not a
|
||||
policy denial) — see `denyUpstreamAuth`.
|
||||
|
||||
### Metering — Anthropic-on-Vertex only
|
||||
|
||||
The request parser extracts `{publisher, model, action}` from the path
|
||||
(`parseVertexPath`, [llm_request_parser/middleware.go:237-263](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)),
|
||||
strips the `@version` suffix from the model, and maps the publisher to a parser
|
||||
surface via `vertexPublisherVendor`:
|
||||
|
||||
- `anthropic` → `llm.provider="anthropic"` → metered through the Anthropic
|
||||
parser, priced under the **`anthropic`** block in `defaults_pricing.yaml`
|
||||
(the parser emits the standard Anthropic provider label, so Vertex Claude
|
||||
reuses first-party Anthropic prices).
|
||||
- `openai` → `llm.provider="openai"` (reserved; not in the catalog lineup
|
||||
today).
|
||||
- anything else (notably `google` / Gemini) → empty vendor → **no parser**.
|
||||
|
||||
**Gemini is intentionally denied as unmeterable.** When the parser emits no
|
||||
`llm.provider` for a Vertex publisher, `llm_router` returns
|
||||
`llm_policy.unmeterable_publisher` (403) rather than forwarding the request
|
||||
uncounted — serving it would bypass token / budget metering
|
||||
([llm_router/middleware.go:144-162, 712-728](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)).
|
||||
A Gemini parser would lift this restriction; until then the `google` publisher
|
||||
is omitted from the catalog.
|
||||
|
||||
> Caveat: cross-region inference profiles in `eu` / `apac` carry a ~10% price
|
||||
> premium that the base per-token rates do **not** model — cost annotations for
|
||||
> those regions read low. Operators who need exact regional billing override
|
||||
> the affected entries in `pricing.yaml`.
|
||||
|
||||
## AWS Bedrock (`bedrock_api`)
|
||||
|
||||
### Catalog entry
|
||||
|
||||
`KindProvider`, upstream host `bedrock-runtime.<region>.amazonaws.com`. Metered
|
||||
models are the Anthropic-on-Bedrock lineup (`anthropic.claude-*`) plus Amazon
|
||||
Nova and Llama 3.3 entries
|
||||
([catalog.go:300-332](../../../management/server/agentnetwork/catalog/catalog.go)).
|
||||
Anthropic-on-Bedrock reuses the first-party Claude prices (with additive cache
|
||||
buckets); Nova / Llama report no cache, so cost is `input + output`.
|
||||
|
||||
### Credential — static bearer token
|
||||
|
||||
Bedrock uses the **AWS Bedrock API key** as a static bearer. The operator sets
|
||||
the provider `api_key` directly (no `keyfile::` prefix); the catalog template
|
||||
is `Authorization: Bearer ${API_KEY}`
|
||||
([catalog.go:306-307](../../../management/server/agentnetwork/catalog/catalog.go)).
|
||||
No token minting — the synthesiser substitutes the key into the template and
|
||||
the router injects the resulting `Authorization` header after stripping inbound
|
||||
vendor auth (including client-supplied AWS SigV4 material: `X-Amz-Date`,
|
||||
`X-Amz-Security-Token`, `X-Amz-Content-Sha256`, see `strippedAuthHeaders`).
|
||||
|
||||
### Model id form — cross-region inference profiles
|
||||
|
||||
Bedrock model ids in the request path must be the cross-region
|
||||
**inference-profile** form, e.g.
|
||||
`eu.anthropic.claude-sonnet-4-5-20250929-v1:0`. The bare
|
||||
`anthropic.claude-…` id is rejected by AWS. `normalizeBedrockModel`
|
||||
([llm_request_parser/middleware.go:398-414](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go))
|
||||
strips the region prefix (`us.` / `eu.` / `apac.` / `global.`), an optional ARN
|
||||
wrapper, and the `-YYYYMMDD-vN[:N]` version/throughput suffix so the normalised
|
||||
id (`anthropic.claude-sonnet-4-5`) matches the catalog/pricing key.
|
||||
|
||||
### Supported endpoints + actions
|
||||
|
||||
`/model/{modelId}/{action}` where action ∈ `invoke`,
|
||||
`invoke-with-response-stream`, `converse`, `converse-stream`
|
||||
([llm_request_parser/middleware.go:363-390](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)).
|
||||
`invoke` / `converse` are non-streaming; the `-stream` actions set the streaming
|
||||
flag.
|
||||
|
||||
- **InvokeModel** body uses the vendor-native shape — for Anthropic that means
|
||||
`"anthropic_version":"bedrock-2023-05-31"` and snake_case usage with additive
|
||||
cache buckets.
|
||||
- **Converse** uses the unified camelCase shape with a precomputed `totalTokens`.
|
||||
- The `BedrockParser` reads both shapes on the response leg
|
||||
([bedrock.go](../../../proxy/internal/llm/bedrock.go)); the request parser
|
||||
doesn't need to distinguish them (`ParseRequest` is a no-op — model + stream
|
||||
come from the path).
|
||||
|
||||
### Streaming — AWS binary event-stream
|
||||
|
||||
The `-stream` actions return `application/vnd.amazon.eventstream` (the AWS
|
||||
binary event-stream framing), and streaming **is metered**.
|
||||
`accumulateBedrockStream`
|
||||
([llm_response_parser/streaming_bedrock.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go))
|
||||
decodes the frames with `aws-sdk-go-v2/aws/protocol/eventstream`:
|
||||
|
||||
- InvokeModel `chunk` frames wrap a base64 `{"bytes":…}` payload carrying a
|
||||
vendor-native (Anthropic) stream event — folded through the shared Anthropic
|
||||
stream accumulator.
|
||||
- Converse `contentBlockDelta` frames carry text; the trailing `metadata` frame
|
||||
carries the final usage block.
|
||||
- A truncated stream (cut at the body-tap capture cap) decodes best-effort:
|
||||
frames up to the cut are applied and partial usage is returned.
|
||||
|
||||
### Optional `/bedrock` gateway-namespace prefix
|
||||
|
||||
Clients may place an optional `/bedrock` prefix before the native path
|
||||
(`/bedrock/model/{modelId}/{action}`) to disambiguate Bedrock from other
|
||||
providers that also use `/model/...`. Both the request parser
|
||||
(`trimBedrockNamespace`) and the router (`splitBedrockNamespace`) accept it.
|
||||
When the prefix is present, the router sets
|
||||
`RewriteUpstream.StripPathPrefix = "/bedrock"` so the **native** path
|
||||
(`/model/...`) is what reaches `bedrock-runtime.<region>.amazonaws.com`
|
||||
([llm_router/middleware.go:168-184, 320-348](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)).
|
||||
|
||||
## Model allowlist on path-routed providers
|
||||
|
||||
Because the model lives in the URL rather than the body, a path-routed provider
|
||||
credential could otherwise be used for any model the upstream supports. The
|
||||
router still enforces the route's `Models` allowlist via `matchPathRoute`
|
||||
([llm_router/middleware.go:370-416](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)):
|
||||
|
||||
1. Filter to routes of the matching style (`Vertex` / `Bedrock`).
|
||||
2. Filter to routes whose `AllowedGroupIDs` authorise the caller's groups
|
||||
(else `no_authorised_provider`).
|
||||
3. Filter to routes that **claim the requested model**. As with body-routed
|
||||
providers, an **empty `Models` list = catch-all** (serve any model);
|
||||
a non-empty list serves only the listed models (else `model_not_routable`).
|
||||
4. Multiple survivors disambiguate by longest `UpstreamPath` prefix match.
|
||||
|
||||
So an operator who lists explicit models on a Vertex/Bedrock provider gets a
|
||||
hard allowlist; an operator who leaves `Models` empty accepts every model the
|
||||
upstream serves (still subject to the unmeterable-publisher gate on Vertex).
|
||||
|
||||
Model-less OpenAI endpoints (`GET /v1/models`) are **never** routed to a
|
||||
Vertex/Bedrock provider — `matchModelless` skips path-routed routes
|
||||
([llm_router/middleware.go:427-462](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
so a model-listing call can't be rewritten onto an upstream that would 404 it.
|
||||
|
||||
## Catalog ↔ pricing cross-check
|
||||
|
||||
Catalog prices and context windows are cross-checked against LiteLLM's
|
||||
`model_prices_and_context_window.json`. The proxy's embedded
|
||||
`defaults_pricing.yaml` covers **every metered first-party model** the catalog
|
||||
enumerates — guarded by
|
||||
`TestDefaultTable_FirstPartyModelCoverage`
|
||||
([pricing/defaults_coverage_test.go](../../../proxy/internal/llm/pricing/defaults_coverage_test.go)),
|
||||
which fails if a catalog model has no embedded price. Bedrock entries are keyed
|
||||
by the **normalised** id the request parser emits (region prefix + version
|
||||
suffix stripped). Vertex Claude carries no Bedrock-style prefix, so it prices
|
||||
straight off the `anthropic` block.
|
||||
|
||||
## Things to scrutinise
|
||||
|
||||
**Security.** The Vertex service-account key is never forwarded — only a minted
|
||||
short-lived bearer. Confirm the key material stays out of access logs (it lives
|
||||
on `ProviderRoute.GCPServiceAccountKeyB64`, not in any emitted metadata key).
|
||||
The unmeterable-publisher deny is the only thing standing between an
|
||||
operator-misconfigured Vertex provider and unmetered Gemini traffic; verify
|
||||
`vertexPublisherVendor` stays conservative (deny by default for unknown
|
||||
publishers).
|
||||
|
||||
**Correctness.** `normalizeBedrockModel` is the join between the wire id and the
|
||||
pricing key — a model that normalises to something not in `defaults_pricing.yaml`
|
||||
meters at `cost.skipped=unknown_model` rather than failing the request. The
|
||||
`/bedrock` prefix strip must run on both the parser side (so the model is
|
||||
extracted) and the router side (so the upstream path is native); a regression in
|
||||
either silently breaks the other.
|
||||
|
||||
**Metering caveats.** eu/apac cross-region Bedrock + Vertex profiles carry a
|
||||
~10% premium not modelled by base pricing — flagged in both the catalog comment
|
||||
and `defaults_pricing.yaml`. Operators needing exact regional billing override
|
||||
the relevant entries.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Router + request-parser detail: [31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md)
|
||||
- Bedrock parser + pricing + SSE / event-stream: [32-proxy-llm-parsers.md](32-proxy-llm-parsers.md)
|
||||
- Catalog → route synthesis + `keyfile::` handling: [21-management-agentnetwork.md](21-management-agentnetwork.md)
|
||||
- Overview: [../00-overview.md](../00-overview.md)
|
||||
2
go.mod
2
go.mod
@@ -35,6 +35,7 @@ require (
|
||||
github.com/DeRuina/timberjack v1.4.2
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.38.3
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1
|
||||
github.com/aws/aws-sdk-go-v2/config v1.31.6
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.18.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
|
||||
@@ -156,7 +157,6 @@ require (
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
|
||||
github.com/awnumar/memcall v0.4.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6 // indirect
|
||||
|
||||
616
infrastructure_files/getting-started-enterprise.sh
Executable file
616
infrastructure_files/getting-started-enterprise.sh
Executable file
@@ -0,0 +1,616 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# NetBird Enterprise — Getting Started
|
||||
# Single-node bootstrap for a self-hosted NetBird Enterprise stack with the
|
||||
# embedded identity provider. Owner is created via first-login flow.
|
||||
|
||||
SED_STRIP_PADDING='s/=//g'
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null; then
|
||||
echo "docker-compose"
|
||||
return
|
||||
fi
|
||||
if docker compose --help &> /dev/null; then
|
||||
echo "docker compose"
|
||||
return
|
||||
fi
|
||||
echo "docker-compose is not installed or not in PATH. See https://docs.docker.com/engine/install/" > /dev/stderr
|
||||
exit 1
|
||||
}
|
||||
|
||||
check_openssl() {
|
||||
if ! command -v openssl &> /dev/null; then
|
||||
echo "openssl is not installed or not in PATH." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
rand_secret() {
|
||||
openssl rand -base64 32 | sed "$SED_STRIP_PADDING"
|
||||
}
|
||||
|
||||
rand_b64_key() {
|
||||
openssl rand -base64 32
|
||||
}
|
||||
|
||||
check_nb_domain() {
|
||||
local domain="$1"
|
||||
if [[ -z "$domain" ]]; then
|
||||
echo "The domain cannot be empty." > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
if [[ "$domain" == "netbird.example.com" ]]; then
|
||||
echo "The domain cannot be netbird.example.com" > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
if [[ "$domain" =~ ^[0-9.]+$ ]]; then
|
||||
echo "An IP address is not allowed. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
if [[ ! "$domain" =~ ^[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?(\.[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)+$ ]]; then
|
||||
echo "The value '$domain' is not a valid FQDN. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
check_domain_resolves() {
|
||||
local domain="$1"
|
||||
if command -v getent &> /dev/null && getent hosts "$domain" &> /dev/null; then return 0; fi
|
||||
if command -v host &> /dev/null && host "$domain" &> /dev/null; then return 0; fi
|
||||
if command -v dig &> /dev/null && [[ -n "$(dig +short "$domain" 2>/dev/null)" ]]; then return 0; fi
|
||||
if command -v nslookup &> /dev/null && nslookup "$domain" &> /dev/null; then return 0; fi
|
||||
return 1
|
||||
}
|
||||
|
||||
read_nb_domain() {
|
||||
local value=""
|
||||
echo -n "Enter the FQDN for NetBird (must resolve via DNS, e.g. netbird.my-domain.com): " > /dev/stderr
|
||||
read -r value < /dev/tty
|
||||
if ! check_nb_domain "$value"; then
|
||||
read_nb_domain
|
||||
return
|
||||
fi
|
||||
if ! check_domain_resolves "$value"; then
|
||||
echo "" > /dev/stderr
|
||||
echo "Warning: '$value' does not resolve via DNS from this host." > /dev/stderr
|
||||
echo "Caddy will not be able to issue TLS certificates until it does." > /dev/stderr
|
||||
local confirm=""
|
||||
echo -n "Continue anyway? [y/N]: " > /dev/stderr
|
||||
read -r confirm < /dev/tty
|
||||
if [[ ! "$confirm" =~ ^[Yy]$ ]]; then
|
||||
read_nb_domain
|
||||
return
|
||||
fi
|
||||
fi
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
read_required() {
|
||||
local prompt="$1"
|
||||
local value=""
|
||||
while [[ -z "$value" ]]; do
|
||||
echo -n "$prompt: " > /dev/stderr
|
||||
read -r value < /dev/tty
|
||||
if [[ -z "$value" ]]; then
|
||||
echo "Value cannot be empty." > /dev/stderr
|
||||
fi
|
||||
done
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
read_secret() {
|
||||
local prompt="$1"
|
||||
local value=""
|
||||
while [[ -z "$value" ]]; do
|
||||
echo -n "$prompt: " > /dev/stderr
|
||||
read -rs value < /dev/tty
|
||||
echo "" > /dev/stderr
|
||||
if [[ -z "$value" ]]; then
|
||||
echo "Value cannot be empty." > /dev/stderr
|
||||
fi
|
||||
done
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
# read_yes_no "<prompt>" [<default y|n>]
|
||||
read_yes_no() {
|
||||
local prompt="$1"
|
||||
local default="${2:-n}"
|
||||
local hint
|
||||
if [[ "$default" == "y" ]]; then
|
||||
hint="[Y/n]"
|
||||
else
|
||||
hint="[y/N]"
|
||||
fi
|
||||
echo -n "${prompt} ${hint}: " > /dev/stderr
|
||||
local ans=""
|
||||
read -r ans < /dev/tty
|
||||
if [[ -z "$ans" ]]; then
|
||||
ans="$default"
|
||||
fi
|
||||
case "$ans" in
|
||||
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
|
||||
*) echo "no" ;;
|
||||
esac
|
||||
}
|
||||
|
||||
wait_postgres() {
|
||||
set +e
|
||||
echo -n "Waiting for postgres to become ready"
|
||||
local counter=1
|
||||
while true; do
|
||||
if $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" &> /dev/null; then
|
||||
break
|
||||
fi
|
||||
if [[ $counter -eq 60 ]]; then
|
||||
echo ""
|
||||
echo "Postgres is taking too long. Recent logs:"
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
|
||||
exit 1
|
||||
fi
|
||||
echo -n " ."
|
||||
sleep 2
|
||||
counter=$((counter + 1))
|
||||
done
|
||||
echo " done"
|
||||
set -e
|
||||
}
|
||||
|
||||
init_environment() {
|
||||
check_openssl
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
|
||||
if [[ -f .env ]] || [[ -f docker-compose.yml ]] || [[ -f config.yaml ]] || [[ -f Caddyfile ]]; then
|
||||
echo "Generated files already exist in $(pwd)."
|
||||
echo "If you want to reinitialize the environment, please remove them first:"
|
||||
echo " $DOCKER_COMPOSE_COMMAND down --volumes # removes all containers and volumes"
|
||||
echo " rm -f .env docker-compose.yml Caddyfile config.yaml"
|
||||
echo "Be aware this will remove all data from the database."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "NetBird Enterprise bootstrap"
|
||||
echo ""
|
||||
echo "Traffic flow:"
|
||||
echo " Enables traffic events logging on the management server."
|
||||
echo " When enabled, the NetBird stack also runs NATS along with two"
|
||||
echo " additional containers: netbird-receiver (the traffic log receiver"
|
||||
echo " service) and netbird-enricher (the traffic log enricher service)."
|
||||
echo " It still has to be turned on from the dashboard settings afterwards."
|
||||
echo " See https://docs.netbird.io/manage/activity/traffic-events-logging"
|
||||
NETBIRD_TRAFFIC_FLOW=$(read_yes_no "Enable traffic flow" "n")
|
||||
|
||||
echo ""
|
||||
NETBIRD_DOMAIN=$(read_nb_domain)
|
||||
|
||||
echo ""
|
||||
|
||||
NETBIRD_LICENSE_KEY=$(read_secret "Enter license key (input hidden)")
|
||||
|
||||
GHCR_USERNAME="netbirdExtAccess1"
|
||||
GHCR_TOKEN=$(read_secret "Enter GHCR token (input hidden)")
|
||||
|
||||
POSTGRES_USER="netbird"
|
||||
POSTGRES_DB="netbird"
|
||||
POSTGRES_PASSWORD=$(rand_secret)
|
||||
NETBIRD_ENCRYPTION_KEY=$(rand_b64_key)
|
||||
NETBIRD_RELAY_AUTH_SECRET=$(rand_secret)
|
||||
|
||||
POSTGRES_DSN="host=postgres user=${POSTGRES_USER} password=${POSTGRES_PASSWORD} dbname=${POSTGRES_DB} port=5432 sslmode=disable TimeZone=UTC"
|
||||
NETBIRD_RELAY_ENDPOINT="rels://${NETBIRD_DOMAIN}:443"
|
||||
|
||||
echo ""
|
||||
echo "Selected:"
|
||||
echo " Traffic flow: ${NETBIRD_TRAFFIC_FLOW}"
|
||||
echo " Domain: ${NETBIRD_DOMAIN}"
|
||||
echo ""
|
||||
echo "Rendering files into $(pwd) ..."
|
||||
install -m 600 /dev/null .env
|
||||
render_env >> .env
|
||||
render_docker_compose > docker-compose.yml
|
||||
|
||||
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' docker-compose.yml && rm -f docker-compose.yml.bak
|
||||
fi
|
||||
render_caddyfile > Caddyfile
|
||||
install -m 600 /dev/null config.yaml
|
||||
render_config_yaml >> config.yaml
|
||||
|
||||
echo "Logging in to ghcr.io ..."
|
||||
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
|
||||
unset GHCR_TOKEN
|
||||
|
||||
echo ""
|
||||
echo "Pulling images ..."
|
||||
$DOCKER_COMPOSE_COMMAND pull
|
||||
|
||||
echo ""
|
||||
echo "Starting postgres ..."
|
||||
$DOCKER_COMPOSE_COMMAND up -d postgres
|
||||
sleep 2
|
||||
wait_postgres
|
||||
|
||||
echo ""
|
||||
echo "Starting remaining services ..."
|
||||
$DOCKER_COMPOSE_COMMAND up -d
|
||||
|
||||
echo ""
|
||||
echo "Done."
|
||||
echo ""
|
||||
echo "Dashboard: https://${NETBIRD_DOMAIN}"
|
||||
echo ""
|
||||
echo "Open the dashboard in a browser to complete the first-login owner setup."
|
||||
echo "All configuration and secrets are stored (mode 600) in $(pwd)/.env"
|
||||
echo ""
|
||||
echo "Tail logs:"
|
||||
echo " cd $(pwd) && $DOCKER_COMPOSE_COMMAND logs -f netbird-server caddy"
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Renderers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
render_env() {
|
||||
cat <<EOF
|
||||
# Generated by getting-started-enterprise.sh
|
||||
# Holds all configuration and secrets for the stack. Mode 600.
|
||||
|
||||
# Features (set by the script; don't edit without re-running)
|
||||
NETBIRD_TRAFFIC_FLOW_ENABLED=${NETBIRD_TRAFFIC_FLOW}
|
||||
|
||||
# Domain
|
||||
NETBIRD_DOMAIN=${NETBIRD_DOMAIN}
|
||||
|
||||
# Image tags. Default to "latest"
|
||||
NETBIRD_DASHBOARD_TAG=${NETBIRD_DASHBOARD_TAG:-latest}
|
||||
NETBIRD_SERVER_TAG=${NETBIRD_SERVER_TAG:-latest}
|
||||
EOF
|
||||
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
NETBIRD_ENRICHER_TAG=${NETBIRD_ENRICHER_TAG:-latest}
|
||||
NETBIRD_RECEIVER_TAG=${NETBIRD_RECEIVER_TAG:-latest}
|
||||
EOF
|
||||
fi
|
||||
|
||||
cat <<EOF
|
||||
|
||||
# License keys
|
||||
EOF
|
||||
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
cat <<EOF
|
||||
NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
EOF
|
||||
fi
|
||||
cat <<EOF
|
||||
NETBIRD_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||
EOF
|
||||
|
||||
cat <<EOF
|
||||
|
||||
# Postgres
|
||||
POSTGRES_USER=${POSTGRES_USER}
|
||||
POSTGRES_DB=${POSTGRES_DB}
|
||||
POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN=${POSTGRES_DSN}
|
||||
|
||||
# Relay
|
||||
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT}
|
||||
NETBIRD_RELAY_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
|
||||
|
||||
# Datastore encryption
|
||||
NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}
|
||||
|
||||
# Dashboard OIDC scopes
|
||||
NETBIRD_AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES:-openid profile email groups}
|
||||
EOF
|
||||
}
|
||||
|
||||
render_docker_compose() {
|
||||
render_compose_header
|
||||
render_compose_common
|
||||
render_compose_server
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
render_compose_flow
|
||||
fi
|
||||
render_compose_postgres
|
||||
render_compose_footer
|
||||
}
|
||||
|
||||
render_compose_header() {
|
||||
cat <<'EOF'
|
||||
x-default: &default
|
||||
restart: unless-stopped
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: '500m'
|
||||
max-file: '2'
|
||||
|
||||
services:
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_common() {
|
||||
cat <<'EOF'
|
||||
caddy:
|
||||
<<: *default
|
||||
image: caddy:2
|
||||
container_name: netbird-caddy
|
||||
networks: [netbird]
|
||||
environment:
|
||||
- CADDY_SECURE_DOMAIN=${NETBIRD_DOMAIN}
|
||||
ports:
|
||||
- '443:443'
|
||||
- '443:443/udp'
|
||||
- '80:80'
|
||||
volumes:
|
||||
- netbird_caddy_data:/data
|
||||
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||
|
||||
dashboard:
|
||||
<<: *default
|
||||
image: ghcr.io/netbirdio/dashboard-cloud:${NETBIRD_DASHBOARD_TAG}
|
||||
container_name: netbird-dashboard
|
||||
networks: [netbird]
|
||||
environment:
|
||||
- NETBIRD_MGMT_API_ENDPOINT=https://${NETBIRD_DOMAIN}
|
||||
- NETBIRD_MGMT_GRPC_API_ENDPOINT=https://${NETBIRD_DOMAIN}
|
||||
- AUTH_AUDIENCE=netbird-dashboard
|
||||
- AUTH_CLIENT_ID=netbird-dashboard
|
||||
- AUTH_CLIENT_SECRET=
|
||||
- AUTH_AUTHORITY=https://${NETBIRD_DOMAIN}/oauth2
|
||||
- USE_AUTH0=false
|
||||
- AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES}
|
||||
- AUTH_REDIRECT_URI=/nb-auth
|
||||
- AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
|
||||
- NETBIRD_TOKEN_SOURCE=accessToken
|
||||
- NGINX_SSL_PORT=443
|
||||
- LETSENCRYPT_DOMAIN=
|
||||
- LETSENCRYPT_EMAIL=
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_server() {
|
||||
cat <<'EOF'
|
||||
netbird-server:
|
||||
<<: *default
|
||||
image: ghcr.io/netbirdio/netbird-server-cloud:${NETBIRD_SERVER_TAG}
|
||||
container_name: netbird-server
|
||||
networks: [netbird]
|
||||
depends_on:
|
||||
dashboard:
|
||||
condition: service_started
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
ports:
|
||||
- '3478:3478/udp'
|
||||
volumes:
|
||||
- netbird_data:/var/lib/netbird
|
||||
- ./config.yaml:/etc/netbird/config.yaml
|
||||
command: ["--config", "/etc/netbird/config.yaml"]
|
||||
environment:
|
||||
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_flow() {
|
||||
cat <<'EOF'
|
||||
nats:
|
||||
<<: *default
|
||||
image: nats:2
|
||||
container_name: netbird-nats
|
||||
networks: [netbird]
|
||||
volumes:
|
||||
- netbird_nats_data:/data
|
||||
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
|
||||
|
||||
enricher:
|
||||
<<: *default
|
||||
image: ghcr.io/netbirdio/flow-enricher-cloud:${NETBIRD_ENRICHER_TAG}
|
||||
container_name: netbird-enricher
|
||||
networks: [netbird]
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
nats:
|
||||
condition: service_started
|
||||
volumes:
|
||||
- netbird_enricher:/var/lib/netbird
|
||||
environment:
|
||||
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
- NB_DATADIR=/var/lib/netbird
|
||||
- NB_MANAGEMENT_STORE_ENGINE=postgres
|
||||
- NB_MANAGEMENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||
- NB_TRAFFIC_EVENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||
- NB_TRAFFIC_EVENT_STORE_ENGINE=postgres
|
||||
- NB_MANAGEMENT_STORE_KEY=${NETBIRD_ENCRYPTION_KEY}
|
||||
- NB_FLOW_ADAPTER_TYPE=nats
|
||||
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
|
||||
- NB_FLOW_NATS_STREAM=traffic-events
|
||||
- NB_METRICS_PORT=9091
|
||||
- NB_PERSISTENCE_RETENTION_PERIOD=168h
|
||||
|
||||
receiver:
|
||||
<<: *default
|
||||
image: ghcr.io/netbirdio/flow-receiver-cloud:${NETBIRD_RECEIVER_TAG}
|
||||
container_name: netbird-receiver
|
||||
networks: [netbird]
|
||||
depends_on:
|
||||
nats:
|
||||
condition: service_started
|
||||
environment:
|
||||
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
- NB_FLOW_LISTEN_PORT=80
|
||||
- NB_FLOW_ADAPTER_TYPE=nats
|
||||
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
|
||||
- NB_FLOW_NATS_STREAM=traffic-events
|
||||
- NB_FLOW_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_postgres() {
|
||||
cat <<'EOF'
|
||||
postgres:
|
||||
<<: *default
|
||||
image: postgres:17
|
||||
container_name: netbird-postgres
|
||||
networks: [netbird]
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
|
||||
- POSTGRES_DB=${POSTGRES_DB}
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
volumes:
|
||||
- netbird_postgres:/var/lib/postgresql/data
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_footer() {
|
||||
cat <<'EOF'
|
||||
volumes:
|
||||
netbird_data:
|
||||
EOF
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
cat <<'EOF'
|
||||
netbird_nats_data:
|
||||
netbird_enricher:
|
||||
EOF
|
||||
fi
|
||||
cat <<'EOF'
|
||||
netbird_postgres:
|
||||
netbird_caddy_data:
|
||||
|
||||
networks:
|
||||
netbird:
|
||||
EOF
|
||||
}
|
||||
|
||||
render_caddyfile() {
|
||||
cat <<'EOF'
|
||||
{
|
||||
servers :80,:443 {
|
||||
protocols h1 h2c h2 h3
|
||||
}
|
||||
}
|
||||
|
||||
(security_headers) {
|
||||
header * {
|
||||
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
|
||||
X-Content-Type-Options "nosniff"
|
||||
X-Frame-Options "SAMEORIGIN"
|
||||
X-XSS-Protection "1; mode=block"
|
||||
-Server
|
||||
Referrer-Policy strict-origin-when-cross-origin
|
||||
}
|
||||
}
|
||||
|
||||
:80 {
|
||||
redir https://{$CADDY_SECURE_DOMAIN}{uri} permanent
|
||||
}
|
||||
|
||||
{$CADDY_SECURE_DOMAIN}:443 {
|
||||
import security_headers
|
||||
# Signal (gRPC over h2c)
|
||||
reverse_proxy /signalexchange.SignalExchange/* h2c://netbird-server:80
|
||||
# Management (gRPC over h2c + HTTP)
|
||||
reverse_proxy /management.ManagementService/* h2c://netbird-server:80
|
||||
reverse_proxy /api/* netbird-server:80
|
||||
reverse_proxy /ws-proxy/* netbird-server:80
|
||||
# Embedded IdP (OAuth2 endpoints served by netbird server)
|
||||
reverse_proxy /oauth2/* netbird-server:80
|
||||
# Relay (WebSocket multiplexed on the same port)
|
||||
reverse_proxy /relay* netbird-server:80
|
||||
EOF
|
||||
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
cat <<'EOF'
|
||||
# Flow receiver (gRPC over h2c)
|
||||
reverse_proxy /flow.FlowService/* h2c://receiver:80
|
||||
EOF
|
||||
fi
|
||||
|
||||
cat <<'EOF'
|
||||
# Dashboard
|
||||
reverse_proxy /* dashboard:80
|
||||
}
|
||||
EOF
|
||||
}
|
||||
|
||||
render_config_yaml() {
|
||||
cat <<EOF
|
||||
# NetBird Enterprise server configuration.
|
||||
# Generated by getting-started-enterprise.sh. Mode 600.
|
||||
|
||||
server:
|
||||
listenAddress: ":80"
|
||||
exposedAddress: "https://${NETBIRD_DOMAIN}:443"
|
||||
|
||||
metricsPort: 9090
|
||||
healthcheckAddress: ":9000"
|
||||
|
||||
logLevel: "info"
|
||||
logFile: "console"
|
||||
|
||||
# TLS is terminated by Caddy in front; leave this block empty.
|
||||
tls:
|
||||
certFile: ""
|
||||
keyFile: ""
|
||||
letsencrypt:
|
||||
enabled: false
|
||||
|
||||
authSecret: "${NETBIRD_RELAY_AUTH_SECRET}"
|
||||
dataDir: "/var/lib/netbird/"
|
||||
|
||||
disableAnonymousMetrics: false
|
||||
disableGeoliteUpdate: false
|
||||
|
||||
auth:
|
||||
issuer: "https://${NETBIRD_DOMAIN}/oauth2"
|
||||
localAuthDisabled: false
|
||||
signKeyRefreshEnabled: false
|
||||
dashboardRedirectURIs:
|
||||
- "https://${NETBIRD_DOMAIN}/nb-auth"
|
||||
- "https://${NETBIRD_DOMAIN}/nb-silent-auth"
|
||||
cliRedirectURIs:
|
||||
- "http://localhost:53000/"
|
||||
|
||||
store:
|
||||
engine: "postgres"
|
||||
dsn: "${POSTGRES_DSN}"
|
||||
encryptionKey: "${NETBIRD_ENCRYPTION_KEY}"
|
||||
|
||||
activityStore:
|
||||
engine: "postgres"
|
||||
dsn: "${POSTGRES_DSN}"
|
||||
EOF
|
||||
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
|
||||
trafficFlow:
|
||||
enabled: true
|
||||
address: "https://${NETBIRD_DOMAIN}:443"
|
||||
interval: "60s"
|
||||
EOF
|
||||
fi
|
||||
}
|
||||
|
||||
init_environment
|
||||
@@ -351,6 +351,11 @@ initialize_default_values() {
|
||||
NETBIRD_STUN_PORT=3478
|
||||
|
||||
# Docker images
|
||||
# Record whether the operator explicitly pinned the server/proxy images via
|
||||
# env vars, so the agent-network preset can pick its own defaults without
|
||||
# clobbering an explicit override.
|
||||
NETBIRD_SERVER_IMAGE_EXPLICIT=${NETBIRD_SERVER_IMAGE:+true}
|
||||
NETBIRD_PROXY_IMAGE_EXPLICIT=${NETBIRD_PROXY_IMAGE:+true}
|
||||
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
|
||||
# Combined server replaces separate signal, relay, and management containers
|
||||
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
|
||||
@@ -398,7 +403,53 @@ configure_domain() {
|
||||
return 0
|
||||
}
|
||||
|
||||
apply_agent_network_preset() {
|
||||
# Agent-network turnkey install: built-in Traefik + NetBird Proxy with
|
||||
# NB_PROXY_PRIVATE=true, dashboard locked to agent-network-only mode.
|
||||
# Bypasses every reverse-proxy / proxy / CrowdSec prompt. The only
|
||||
# inputs we still need from the operator are the domain (handled by
|
||||
# configure_domain via NETBIRD_DOMAIN env var or interactive prompt)
|
||||
# and the ACME email — both honor env vars first and fall back to a
|
||||
# prompt only when unset. CrowdSec is intentionally off.
|
||||
REVERSE_PROXY_TYPE="0"
|
||||
ENABLE_PROXY="true"
|
||||
ENABLE_CROWDSEC="false"
|
||||
|
||||
# Agent-network ships dedicated server/proxy images. Honor an explicit
|
||||
# env override; otherwise pin the agent-network builds.
|
||||
if [[ "${NETBIRD_SERVER_IMAGE_EXPLICIT}" != "true" ]]; then
|
||||
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:0.74.0-rc.1"
|
||||
fi
|
||||
if [[ "${NETBIRD_PROXY_IMAGE_EXPLICIT}" != "true" ]]; then
|
||||
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:0.74.0-rc.1"
|
||||
fi
|
||||
|
||||
if [[ -n "${NETBIRD_LETSENCRYPT_EMAIL}" ]]; then
|
||||
TRAEFIK_ACME_EMAIL="${NETBIRD_LETSENCRYPT_EMAIL}"
|
||||
else
|
||||
TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email)
|
||||
fi
|
||||
|
||||
echo "" > /dev/stderr
|
||||
echo "Agent-network preset enabled (NETBIRD_AGENT_NETWORK=true):" > /dev/stderr
|
||||
echo " - reverse proxy: built-in Traefik" > /dev/stderr
|
||||
echo " - NetBird Proxy: enabled with NB_PROXY_PRIVATE=true" > /dev/stderr
|
||||
echo " - server image: ${NETBIRD_SERVER_IMAGE}" > /dev/stderr
|
||||
echo " - proxy image: ${NETBIRD_PROXY_IMAGE}" > /dev/stderr
|
||||
echo " - dashboard: NETBIRD_AGENT_NETWORK_ONLY=true" > /dev/stderr
|
||||
echo " - CrowdSec: disabled" > /dev/stderr
|
||||
echo " - Let's Encrypt email: ${TRAEFIK_ACME_EMAIL}" > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
}
|
||||
|
||||
configure_reverse_proxy() {
|
||||
# Short-circuit: agent-network preset locks every reverse-proxy /
|
||||
# proxy / CrowdSec choice and bypasses the interactive prompts.
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
apply_agent_network_preset
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Prompt for reverse proxy type
|
||||
REVERSE_PROXY_TYPE=$(read_reverse_proxy_type)
|
||||
|
||||
@@ -910,6 +961,15 @@ NGINX_SSL_PORT=443
|
||||
# Letsencrypt
|
||||
LETSENCRYPT_DOMAIN=none
|
||||
EOF
|
||||
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
cat <<EOF
|
||||
# Agent-network preset: dashboard hides the standard NetBird surfaces
|
||||
# and exposes only the AI Observability + agent-network configuration
|
||||
# pages. Paired with NB_PROXY_PRIVATE=true on the proxy side.
|
||||
NETBIRD_AGENT_NETWORK_ONLY=true
|
||||
EOF
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -946,6 +1006,17 @@ NB_PROXY_PROXY_PROTOCOL=true
|
||||
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
|
||||
EOF
|
||||
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
cat <<EOF
|
||||
# Agent-network preset: turn the proxy into the private reverse-proxy
|
||||
# ingress for agent-network synth services. Disables the public-facing
|
||||
# surface so the proxy serves only synth-generated routes (the
|
||||
# llm_router-driven LLM endpoints) and the per-account inbound
|
||||
# listeners on the embedded netstack.
|
||||
NB_PROXY_PRIVATE=true
|
||||
EOF
|
||||
fi
|
||||
|
||||
if [[ "$ENABLE_CROWDSEC" == "true" && -n "$CROWDSEC_BOUNCER_KEY" ]]; then
|
||||
cat <<EOF
|
||||
NB_PROXY_CROWDSEC_API_URL=http://crowdsec:8080
|
||||
@@ -1326,12 +1397,20 @@ print_builtin_traefik_instructions() {
|
||||
echo " - 51820/udp (WIREGUARD - (optional) for P2P proxy connections)"
|
||||
fi
|
||||
echo ""
|
||||
echo "This setup is ideal for homelabs and smaller organization deployments."
|
||||
echo "For enterprise environments requiring high availability and advanced integrations,"
|
||||
echo "consider a commercial on-prem license or scaling your open source deployment:"
|
||||
echo ""
|
||||
echo " Commercial license: https://netbird.io/pricing#on-prem"
|
||||
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
echo "For enterprise environments requiring high availability and advanced integrations,"
|
||||
echo "consider a commercial on-prem license:"
|
||||
echo ""
|
||||
echo " Commercial license: https://netbird.ai/pricing"
|
||||
echo " Documentation: https://docs.netbird.io/agent-network"
|
||||
else
|
||||
echo "This setup is ideal for homelabs and smaller organization deployments."
|
||||
echo "For enterprise environments requiring high availability and advanced integrations,"
|
||||
echo "consider a commercial on-prem license or scaling your open source deployment:"
|
||||
echo ""
|
||||
echo " Commercial license: https://netbird.io/pricing#on-prem"
|
||||
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
|
||||
fi
|
||||
echo ""
|
||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||
echo "NetBird Proxy:"
|
||||
@@ -1354,6 +1433,11 @@ print_builtin_traefik_instructions() {
|
||||
echo ""
|
||||
fi
|
||||
fi
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
echo "Note: The public domain is only for setting up secure connections."
|
||||
echo "Your APIs and agent services remain private and are never exposed publicly."
|
||||
echo ""
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
638
infrastructure_files/migrate-to-enterprise.sh
Executable file
638
infrastructure_files/migrate-to-enterprise.sh
Executable file
@@ -0,0 +1,638 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# NetBird — community combined → Enterprise combined migration
|
||||
#
|
||||
# Non-destructive migration: produces docker-compose.override.yml (auto-loaded
|
||||
# by docker compose) and config.yaml.enterprise alongside the operator's
|
||||
# existing files. Original docker-compose.yml and config.yaml are never
|
||||
# modified.
|
||||
#
|
||||
# Steps (all optional, asked interactively):
|
||||
# 1. Image swap — replace community images with enterprise cloud images.
|
||||
# 2. Postgres migration — add Postgres, migrate SQLite data via migrate-store.
|
||||
# 3. Traffic flow — add NATS + flow-enricher + flow-receiver.
|
||||
#
|
||||
# To revert:
|
||||
# docker compose down
|
||||
# rm -f docker-compose.override.yml config.yaml.enterprise
|
||||
# # If Postgres migration was done, also restore the SQLite backup printed
|
||||
# # at the end of this script's run.
|
||||
# docker compose up -d
|
||||
|
||||
OVERRIDE_FILE="docker-compose.override.yml"
|
||||
ENTERPRISE_CONFIG_FILE="config.yaml.enterprise"
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null; then
|
||||
echo "docker-compose"
|
||||
return
|
||||
fi
|
||||
if docker compose --help &> /dev/null; then
|
||||
echo "docker compose"
|
||||
return
|
||||
fi
|
||||
echo "docker-compose is not installed or not in PATH." > /dev/stderr
|
||||
exit 1
|
||||
}
|
||||
|
||||
check_yq() {
|
||||
if ! command -v yq &> /dev/null; then
|
||||
cat > /dev/stderr <<'EOF'
|
||||
yq is required to parse and update YAML safely.
|
||||
|
||||
macOS: brew install yq
|
||||
Linux: https://github.com/mikefarah/yq/releases (download binary into PATH)
|
||||
Debian: apt-get install yq (Note: must be the mikefarah Go yq, not the Python wrapper.)
|
||||
|
||||
EOF
|
||||
exit 1
|
||||
fi
|
||||
if ! yq --version 2>&1 | grep -q "mikefarah"; then
|
||||
echo "yq is present but appears to be the wrong implementation. The mikefarah Go-based yq is required (https://github.com/mikefarah/yq)." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
check_openssl() {
|
||||
if ! command -v openssl &> /dev/null; then
|
||||
echo "openssl is not installed or not in PATH." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
rand_password() {
|
||||
openssl rand -hex 32
|
||||
}
|
||||
|
||||
read_required() {
|
||||
local prompt="$1"
|
||||
local value=""
|
||||
while [[ -z "$value" ]]; do
|
||||
echo -n "$prompt: " > /dev/stderr
|
||||
read -r value < /dev/tty
|
||||
if [[ -z "$value" ]]; then
|
||||
echo "Value cannot be empty." > /dev/stderr
|
||||
fi
|
||||
done
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
read_secret() {
|
||||
local prompt="$1"
|
||||
local value=""
|
||||
while [[ -z "$value" ]]; do
|
||||
echo -n "$prompt: " > /dev/stderr
|
||||
read -rs value < /dev/tty
|
||||
echo "" > /dev/stderr
|
||||
if [[ -z "$value" ]]; then
|
||||
echo "Value cannot be empty." > /dev/stderr
|
||||
fi
|
||||
done
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
read_yes_no() {
|
||||
local prompt="$1"
|
||||
local default="${2:-n}"
|
||||
local hint
|
||||
if [[ "$default" == "y" ]]; then
|
||||
hint="[Y/n]"
|
||||
else
|
||||
hint="[y/N]"
|
||||
fi
|
||||
echo -n "${prompt} ${hint}: " > /dev/stderr
|
||||
local ans=""
|
||||
read -r ans < /dev/tty
|
||||
if [[ -z "$ans" ]]; then
|
||||
ans="$default"
|
||||
fi
|
||||
case "$ans" in
|
||||
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
|
||||
*) echo "no" ;;
|
||||
esac
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection — read the operator's existing compose to find service names and
|
||||
# paths we need to override. Bail loudly if shape isn't recognised.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
detect_combined_service() {
|
||||
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/netbird-server"))) | .[0].key // ""' "$COMPOSE_FILE"
|
||||
}
|
||||
|
||||
detect_dashboard_service() {
|
||||
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/dashboard"))) | .[0].key // ""' "$COMPOSE_FILE"
|
||||
}
|
||||
|
||||
detect_config_yaml_host_path() {
|
||||
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/etc/netbird/config.yaml\")) | sub(\":/etc/netbird/config.yaml.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
|
||||
}
|
||||
|
||||
detect_data_volume() {
|
||||
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/var/lib/netbird\")) | sub(\":/var/lib/netbird.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
|
||||
}
|
||||
|
||||
detect_exposed_address() {
|
||||
yq eval '.server.exposedAddress // ""' "$CONFIG_YAML_HOST"
|
||||
}
|
||||
|
||||
detect_compose_network() {
|
||||
local tag
|
||||
tag=$(yq eval ".services[\"$COMBINED_SERVICE\"].networks | tag" "$COMPOSE_FILE" 2>/dev/null)
|
||||
case "$tag" in
|
||||
"!!seq")
|
||||
yq eval ".services[\"$COMBINED_SERVICE\"].networks[0]" "$COMPOSE_FILE"
|
||||
;;
|
||||
"!!map")
|
||||
yq eval ".services[\"$COMBINED_SERVICE\"].networks | keys | .[0]" "$COMPOSE_FILE"
|
||||
;;
|
||||
*)
|
||||
echo "default"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Renderers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Build docker-compose.override.yml from the steps the operator selected.
|
||||
# Service names match what we detected on the operator's side.
|
||||
render_override() {
|
||||
cat <<EOF
|
||||
# Generated by migrate-to-enterprise.sh. Mode 644.
|
||||
# Merged with docker-compose.yml automatically by Docker Compose.
|
||||
# Remove this file (and config.yaml.enterprise if present) to revert.
|
||||
|
||||
services:
|
||||
${DASHBOARD_SERVICE}:
|
||||
image: \${NETBIRD_DASHBOARD_IMAGE:-ghcr.io/netbirdio/dashboard-cloud:latest}
|
||||
|
||||
${COMBINED_SERVICE}:
|
||||
image: \${NETBIRD_SERVER_IMAGE:-ghcr.io/netbirdio/netbird-server-cloud:latest}
|
||||
environment:
|
||||
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
EOF
|
||||
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- ./${ENTERPRISE_CONFIG_FILE}:/etc/netbird/config.yaml.enterprise:ro
|
||||
command: ["--config", "/etc/netbird/config.yaml.enterprise"]
|
||||
|
||||
postgres:
|
||||
image: postgres:17
|
||||
container_name: netbird-postgres
|
||||
restart: unless-stopped
|
||||
networks: [${COMPOSE_NETWORK}]
|
||||
environment:
|
||||
POSTGRES_USER: netbird
|
||||
POSTGRES_PASSWORD: \${POSTGRES_PASSWORD}
|
||||
POSTGRES_DB: netbird
|
||||
volumes:
|
||||
- netbird_postgres:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U netbird -d netbird"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 20
|
||||
EOF
|
||||
fi
|
||||
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
|
||||
nats:
|
||||
image: nats:2
|
||||
container_name: netbird-nats
|
||||
restart: unless-stopped
|
||||
networks: [${COMPOSE_NETWORK}]
|
||||
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
|
||||
volumes:
|
||||
- netbird_nats_data:/data
|
||||
|
||||
flow-enricher:
|
||||
image: ghcr.io/netbirdio/flow-enricher-cloud:latest
|
||||
container_name: netbird-flow-enricher
|
||||
restart: unless-stopped
|
||||
networks: [${COMPOSE_NETWORK}]
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
nats:
|
||||
condition: service_started
|
||||
environment:
|
||||
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
NB_DATADIR: /var/lib/netbird
|
||||
NB_MANAGEMENT_STORE_ENGINE: postgres
|
||||
NB_MANAGEMENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||
NB_STORE_ENGINE_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||
NB_TRAFFIC_EVENT_STORE_ENGINE: postgres
|
||||
NB_TRAFFIC_EVENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||
NB_MANAGEMENT_STORE_KEY: \${NETBIRD_ENCRYPTION_KEY}
|
||||
NB_FLOW_ADAPTER_TYPE: nats
|
||||
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
|
||||
NB_FLOW_NATS_STREAM: traffic-events
|
||||
NB_METRICS_PORT: 9091
|
||||
NB_PERSISTENCE_RETENTION_PERIOD: 168h
|
||||
|
||||
flow-receiver:
|
||||
image: ghcr.io/netbirdio/flow-receiver-cloud:latest
|
||||
container_name: netbird-flow-receiver
|
||||
restart: unless-stopped
|
||||
networks: [${COMPOSE_NETWORK}]
|
||||
depends_on:
|
||||
nats:
|
||||
condition: service_started
|
||||
environment:
|
||||
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
NB_FLOW_LISTEN_PORT: 80
|
||||
NB_FLOW_ADAPTER_TYPE: nats
|
||||
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
|
||||
NB_FLOW_NATS_STREAM: traffic-events
|
||||
NB_FLOW_AUTH_SECRET: \${NB_FLOW_AUTH_SECRET}
|
||||
labels:
|
||||
- traefik.enable=true
|
||||
- traefik.http.routers.netbird-flow.rule=Host(\`${NETBIRD_HOSTNAME}\`) && PathPrefix(\`/flow.FlowService/\`)
|
||||
- traefik.http.routers.netbird-flow.entrypoints=websecure
|
||||
- traefik.http.routers.netbird-flow.tls=true
|
||||
- traefik.http.routers.netbird-flow.tls.certresolver=letsencrypt
|
||||
- traefik.http.routers.netbird-flow.service=netbird-flow-h2c
|
||||
- traefik.http.routers.netbird-flow.priority=100
|
||||
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.port=80
|
||||
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.scheme=h2c
|
||||
EOF
|
||||
fi
|
||||
|
||||
# Volume declarations for anything new the override introduced
|
||||
local has_volumes="no"
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]] || [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
has_volumes="yes"
|
||||
fi
|
||||
|
||||
if [[ "$has_volumes" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
|
||||
volumes:
|
||||
EOF
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo " netbird_postgres:"
|
||||
fi
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
echo " netbird_nats_data:"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# Build config.yaml.enterprise by yq-editing the operator's existing
|
||||
# config.yaml. We don't touch the original file.
|
||||
render_enterprise_config() {
|
||||
local pg_dsn="host=postgres user=netbird password=${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||
|
||||
yq eval "
|
||||
.server.store.engine = \"postgres\" |
|
||||
.server.store.dsn = \"$pg_dsn\" |
|
||||
.server.activityStore.engine = \"postgres\" |
|
||||
.server.activityStore.dsn = \"$pg_dsn\" |
|
||||
.server.authStore.engine = \"postgres\" |
|
||||
.server.authStore.dsn = \"$pg_dsn\"
|
||||
" "$CONFIG_YAML_HOST" > "$ENTERPRISE_CONFIG_FILE"
|
||||
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
local flow_addr="${NETBIRD_DOMAIN}"
|
||||
yq eval -i "
|
||||
.server.trafficFlow.enabled = true |
|
||||
.server.trafficFlow.address = \"$flow_addr\" |
|
||||
.server.trafficFlow.interval = \"60s\"
|
||||
" "$ENTERPRISE_CONFIG_FILE"
|
||||
fi
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Execution steps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
resolve_data_volume() {
|
||||
local short="$1"
|
||||
local actual
|
||||
# Resolve project-prefixed volume name from Docker Compose config first.
|
||||
actual=$($DOCKER_COMPOSE_COMMAND config 2>/dev/null | yq eval ".volumes.\"$short\".name" - 2>/dev/null)
|
||||
if [[ -n "$actual" && "$actual" != "null" ]]; then
|
||||
echo "$actual"
|
||||
return
|
||||
fi
|
||||
# Relative bind mount: docker-compose resolves it against the compose
|
||||
# file's directory, but `docker run -v` resolves it against the current
|
||||
# working directory. Normalize to an absolute path so both interpretations
|
||||
# agree (and the printed revert command works from any CWD).
|
||||
if [[ "$short" == ./* || "$short" == ../* ]]; then
|
||||
local compose_dir
|
||||
compose_dir="$(cd "$(dirname "$COMPOSE_FILE")" && pwd)"
|
||||
(
|
||||
cd "$compose_dir"
|
||||
cd "$(dirname "$short")"
|
||||
printf '%s/%s\n' "$(pwd)" "$(basename "$short")"
|
||||
)
|
||||
return
|
||||
fi
|
||||
# Not a named volume (e.g. an absolute bind-mount path) — use it as-is.
|
||||
echo "$short"
|
||||
}
|
||||
|
||||
backup_sqlite() {
|
||||
BACKUP_DIR="$(pwd)/backups/sqlite-pre-enterprise-$(date +%Y%m%d-%H%M%S)"
|
||||
mkdir -p "$BACKUP_DIR"
|
||||
local data_volume_actual
|
||||
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
|
||||
echo "Backing up SQLite store from volume '$data_volume_actual' to $BACKUP_DIR ..."
|
||||
docker run --rm \
|
||||
-v "${data_volume_actual}:/var/lib/netbird:ro" \
|
||||
-v "${BACKUP_DIR}:/backup" \
|
||||
busybox \
|
||||
sh -c 'cp -a /var/lib/netbird/. /backup/ 2>/dev/null || true'
|
||||
local copied
|
||||
copied=$(find "$BACKUP_DIR" -mindepth 1 | head -1)
|
||||
if [[ -z "$copied" ]]; then
|
||||
echo " ⚠ Backup directory is empty — the volume '$data_volume_actual' didn't contain data. Aborting." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
echo " done"
|
||||
}
|
||||
|
||||
run_migrate_store() {
|
||||
echo "Running migrate-store (SQLite → Postgres) ..."
|
||||
$DOCKER_COMPOSE_COMMAND run --rm "$COMBINED_SERVICE" migrate-store --config /etc/netbird/config.yaml.enterprise --verify
|
||||
echo " done"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
init_migration() {
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
check_yq
|
||||
check_openssl
|
||||
|
||||
COMPOSE_FILE="${COMPOSE_FILE:-docker-compose.yml}"
|
||||
|
||||
if [[ ! -f "$COMPOSE_FILE" ]]; then
|
||||
echo "$COMPOSE_FILE not found in $(pwd)." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ -f "$OVERRIDE_FILE" ]] || [[ -f "$ENTERPRISE_CONFIG_FILE" ]]; then
|
||||
echo "Migration artifacts already exist in $(pwd):"
|
||||
[[ -f "$OVERRIDE_FILE" ]] && echo " $OVERRIDE_FILE"
|
||||
[[ -f "$ENTERPRISE_CONFIG_FILE" ]] && echo " $ENTERPRISE_CONFIG_FILE"
|
||||
echo ""
|
||||
echo "Either you've already migrated, or a previous run was interrupted."
|
||||
echo "To re-run cleanly: rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
COMBINED_SERVICE=$(detect_combined_service)
|
||||
DASHBOARD_SERVICE=$(detect_dashboard_service)
|
||||
CONFIG_YAML_HOST=$(detect_config_yaml_host_path)
|
||||
DATA_VOLUME=$(detect_data_volume)
|
||||
COMPOSE_NETWORK=$(detect_compose_network)
|
||||
|
||||
if [[ -z "$COMBINED_SERVICE" ]]; then
|
||||
echo "Could not find a service running netbirdio/netbird-server* in $COMPOSE_FILE." > /dev/stderr
|
||||
echo "This script targets the community combined-server deployment." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$DASHBOARD_SERVICE" ]]; then
|
||||
echo "Could not find a service running netbirdio/dashboard* in $COMPOSE_FILE." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$CONFIG_YAML_HOST" ]]; then
|
||||
echo "Could not find a config.yaml mount on $COMBINED_SERVICE (expected to bind-mount to /etc/netbird/config.yaml)." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ ! -f "$CONFIG_YAML_HOST" ]]; then
|
||||
echo "config.yaml host file not found at $CONFIG_YAML_HOST." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$DATA_VOLUME" ]]; then
|
||||
echo "Could not find a volume mounted at /var/lib/netbird on $COMBINED_SERVICE." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Detected existing deployment:"
|
||||
echo " Combined service: $COMBINED_SERVICE"
|
||||
echo " Dashboard: $DASHBOARD_SERVICE"
|
||||
echo " config.yaml: $CONFIG_YAML_HOST"
|
||||
echo " Data volume: $DATA_VOLUME"
|
||||
echo " Network: $COMPOSE_NETWORK"
|
||||
echo ""
|
||||
|
||||
local proceed
|
||||
proceed=$(read_yes_no "Proceed with migration?" "y")
|
||||
if [[ "$proceed" != "yes" ]]; then
|
||||
echo "Aborted."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Step 1 — always (this is the point of the script)
|
||||
MIGRATE_IMAGES="yes"
|
||||
echo ""
|
||||
echo "Step 1: Image swap (community → Enterprise). License key required."
|
||||
NB_LICENSE_KEY=$(read_secret " License key")
|
||||
GHCR_USERNAME="netbirdExtAccess1"
|
||||
GHCR_TOKEN=$(read_secret " GHCR token (input hidden)")
|
||||
|
||||
# Step 2 — optional
|
||||
echo ""
|
||||
MIGRATE_POSTGRES=$(read_yes_no "Step 2: Migrate storage from SQLite to Postgres? (recommended)" "n")
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo ""
|
||||
echo " ⚠ Data will be migrated from SQLite to Postgres. The SQLite store"
|
||||
echo " will be backed up automatically. To fully revert later, restore"
|
||||
echo " that backup and delete docker-compose.override.yml +"
|
||||
echo " config.yaml.enterprise."
|
||||
local confirm
|
||||
confirm=$(read_yes_no " Continue?" "y")
|
||||
if [[ "$confirm" != "yes" ]]; then
|
||||
MIGRATE_POSTGRES="no"
|
||||
echo " Skipping Postgres migration."
|
||||
else
|
||||
POSTGRES_PASSWORD=$(rand_password)
|
||||
fi
|
||||
fi
|
||||
|
||||
# Step 3 — optional, only if Postgres is on (flow requires Postgres)
|
||||
echo ""
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
ENABLE_FLOW=$(read_yes_no "Step 3: Enable traffic flow? (requires Postgres)" "n")
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
# Auth secret MUST match server.authSecret from config.yaml
|
||||
NB_FLOW_AUTH_SECRET=$(yq eval '.server.authSecret // ""' "$CONFIG_YAML_HOST")
|
||||
if [[ -z "$NB_FLOW_AUTH_SECRET" ]] || [[ "$NB_FLOW_AUTH_SECRET" == "null" ]]; then
|
||||
echo "Could not read server.authSecret from $CONFIG_YAML_HOST." > /dev/stderr
|
||||
echo "Flow receiver auth must match the combined server's authSecret." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
|
||||
NETBIRD_DOMAIN=$(detect_exposed_address)
|
||||
if [[ -z "$NETBIRD_DOMAIN" ]] || [[ "$NETBIRD_DOMAIN" == "null" ]]; then
|
||||
NETBIRD_DOMAIN=$(read_required " Public NetBird URL (e.g. https://netbird.example.com)")
|
||||
fi
|
||||
# Strip protocol + port to leave just the hostname for the Traefik Host() rule.
|
||||
NETBIRD_HOSTNAME=$(echo "$NETBIRD_DOMAIN" | sed -E 's,^https?://,,' | sed 's,:.*,,' | sed 's,/.*,,')
|
||||
|
||||
# We need the encryption key from the existing config.yaml for the enricher
|
||||
NETBIRD_ENCRYPTION_KEY=$(yq eval '.server.store.encryptionKey // ""' "$CONFIG_YAML_HOST")
|
||||
if [[ -z "$NETBIRD_ENCRYPTION_KEY" ]] || [[ "$NETBIRD_ENCRYPTION_KEY" == "null" ]]; then
|
||||
echo "Could not read server.store.encryptionKey from $CONFIG_YAML_HOST." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
else
|
||||
ENABLE_FLOW="no"
|
||||
echo "Step 3 (traffic flow) skipped — requires Postgres."
|
||||
fi
|
||||
}
|
||||
|
||||
apply_changes() {
|
||||
echo ""
|
||||
echo "Writing $OVERRIDE_FILE ..."
|
||||
install -m 644 /dev/null "$OVERRIDE_FILE"
|
||||
render_override > "$OVERRIDE_FILE"
|
||||
|
||||
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' "$OVERRIDE_FILE" && rm -f "$OVERRIDE_FILE.bak"
|
||||
fi
|
||||
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo "Writing $ENTERPRISE_CONFIG_FILE ..."
|
||||
install -m 600 /dev/null "$ENTERPRISE_CONFIG_FILE"
|
||||
render_enterprise_config
|
||||
fi
|
||||
|
||||
# Persist secrets that the override file references via env interpolation.
|
||||
# We write them to a .env file in the current directory; docker compose
|
||||
# picks it up automatically.
|
||||
echo "Writing .env additions (mode 600) ..."
|
||||
local ENV_FILE=".env"
|
||||
touch "$ENV_FILE"
|
||||
chmod 600 "$ENV_FILE"
|
||||
{
|
||||
echo ""
|
||||
echo "# Added by migrate-to-enterprise.sh on $(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||
echo "NB_LICENSE_KEY=${NB_LICENSE_KEY}"
|
||||
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
echo "NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}"
|
||||
fi
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo "POSTGRES_PASSWORD=${POSTGRES_PASSWORD}"
|
||||
fi
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
echo "NB_FLOW_AUTH_SECRET=${NB_FLOW_AUTH_SECRET}"
|
||||
echo "NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}"
|
||||
fi
|
||||
} >> "$ENV_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Logging in to ghcr.io ..."
|
||||
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
|
||||
unset GHCR_TOKEN
|
||||
|
||||
echo ""
|
||||
echo "Pulling enterprise images ..."
|
||||
$DOCKER_COMPOSE_COMMAND pull
|
||||
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo ""
|
||||
echo "Stopping existing services (volumes preserved) ..."
|
||||
$DOCKER_COMPOSE_COMMAND down
|
||||
|
||||
backup_sqlite
|
||||
|
||||
echo ""
|
||||
echo "Starting Postgres ..."
|
||||
$DOCKER_COMPOSE_COMMAND up -d postgres
|
||||
|
||||
# Wait for healthy
|
||||
local counter=0
|
||||
echo -n "Waiting for Postgres to become ready"
|
||||
while ! $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U netbird -d netbird &> /dev/null; do
|
||||
echo -n " ."
|
||||
sleep 2
|
||||
counter=$((counter + 1))
|
||||
if [[ $counter -ge 60 ]]; then
|
||||
echo ""
|
||||
echo "Postgres did not become ready in 120s. Recent logs:"
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
echo " done"
|
||||
|
||||
run_migrate_store
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Bringing up all services ..."
|
||||
$DOCKER_COMPOSE_COMMAND up -d
|
||||
|
||||
echo ""
|
||||
echo "Migration complete."
|
||||
}
|
||||
|
||||
print_summary() {
|
||||
echo ""
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
echo " Summary"
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
echo " Images: swapped to enterprise"
|
||||
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " Storage: Postgres (data migrated from SQLite)"
|
||||
[[ "$MIGRATE_POSTGRES" != "yes" ]] && echo " Storage: SQLite (unchanged)"
|
||||
[[ "$ENABLE_FLOW" == "yes" ]] && echo " Traffic flow: enabled"
|
||||
[[ "$ENABLE_FLOW" != "yes" ]] && echo " Traffic flow: disabled"
|
||||
echo ""
|
||||
echo " Generated files (next to your docker-compose.yml):"
|
||||
echo " $OVERRIDE_FILE"
|
||||
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " $ENTERPRISE_CONFIG_FILE"
|
||||
echo " .env (license key + secrets, mode 600)"
|
||||
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " backups/sqlite-pre-enterprise-*/ (SQLite backup)"
|
||||
echo ""
|
||||
echo " Tail logs:"
|
||||
echo " $DOCKER_COMPOSE_COMMAND logs -f $COMBINED_SERVICE"
|
||||
echo ""
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
echo " To revert"
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
echo " $DOCKER_COMPOSE_COMMAND down"
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
# Resolve project-prefixed volume names now (before override is removed).
|
||||
local pg_volume data_volume_actual
|
||||
pg_volume=$(resolve_data_volume "netbird_postgres")
|
||||
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
|
||||
echo " # Remove the Postgres volume FIRST, before deleting the override file:"
|
||||
echo " docker volume rm $pg_volume"
|
||||
echo " # Restore SQLite from the backup created during this run:"
|
||||
echo " docker run --rm -v ${data_volume_actual}:/var/lib/netbird -v ${BACKUP_DIR}:/backup busybox sh -c 'cp -a /backup/. /var/lib/netbird/'"
|
||||
fi
|
||||
echo " rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
|
||||
echo " # Remove migrate-to-enterprise.sh additions from .env (search for the timestamp marker)"
|
||||
echo " $DOCKER_COMPOSE_COMMAND up -d"
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
init_migration
|
||||
apply_changes
|
||||
print_summary
|
||||
@@ -116,6 +116,24 @@ func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, p
|
||||
c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
// injectAllProxyPolicies prepares an account for the per-peer network-map
|
||||
// computation. It prepends the in-memory agent-network services synthesised
|
||||
// from the account's current provider/policy state to account.Services so
|
||||
// the existing InjectProxyPolicies + injectPrivateServicePolicies walks pick
|
||||
// them up alongside persisted reverse-proxy services. Synthesised services
|
||||
// are never persisted; the account is loaded fresh per cycle so re-prepending
|
||||
// is safe and idempotent. Accounts without agent-network providers get an
|
||||
// empty synth slice — no behaviour change.
|
||||
func (c *Controller) injectAllProxyPolicies(ctx context.Context, account *types.Account) {
|
||||
synth, err := c.repo.SynthesizeAgentNetworkServices(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("synthesise agent-network services for account %s: %v", account.Id, err)
|
||||
} else if len(synth) > 0 {
|
||||
account.Services = append(synth, account.Services...)
|
||||
}
|
||||
account.InjectProxyPolicies(ctx)
|
||||
}
|
||||
|
||||
func (c *Controller) CountStreams() int {
|
||||
return c.peersUpdateManager.CountStreams()
|
||||
}
|
||||
@@ -150,7 +168,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
@@ -281,7 +299,15 @@ func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID s
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
// The affected-peer path MUST mirror sendUpdateAccountPeers (line 171)
|
||||
// here: injectAllProxyPolicies prepends the synthesised agent-network
|
||||
// services BEFORE InjectProxyPolicies + private-service policies run.
|
||||
// Previously this path called only account.InjectProxyPolicies, which
|
||||
// skipped the synth-services prepend — so peer-level changes
|
||||
// (proxy restart, embedded peer connect/disconnect) propagated a
|
||||
// network map that omitted the synth DNS zone, and the agent kept
|
||||
// resolving against the stale or absent record.
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
@@ -399,7 +425,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return fmt.Errorf("failed to get validated peers: %v", err)
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
@@ -497,7 +523,7 @@ func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID st
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s with reason %s/%s", len(peerIDs), accountID, util.GetCallerName(), reason.Operation, reason.Resource)
|
||||
|
||||
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
|
||||
peerIDs: make(map[string]struct{}),
|
||||
@@ -603,19 +629,17 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
startPosture := time.Now()
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
@@ -876,7 +900,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
@@ -3,7 +3,9 @@ package controller
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -16,6 +18,10 @@ type Repository interface {
|
||||
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
|
||||
GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error)
|
||||
// SynthesizeAgentNetworkServices returns the in-memory reverse-proxy
|
||||
// services synthesised from the account's agent-network provider/policy
|
||||
// state. Empty for accounts without agent-network providers.
|
||||
SynthesizeAgentNetworkServices(ctx context.Context, accountID string) ([]*service.Service, error)
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
@@ -50,6 +56,10 @@ func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID s
|
||||
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
func (r *repository) SynthesizeAgentNetworkServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
return agentnetwork.SynthesizeServices(ctx, r.store, accountID)
|
||||
}
|
||||
|
||||
func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) {
|
||||
return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package agentnetwork
|
||||
|
||||
import "github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
|
||||
// init registers the agent-network service synthesiser with the affectedpeers
|
||||
// resolver. Agent-network reverse-proxy services are synthesised on demand and
|
||||
// never persisted, so the resolver can't load them from the store; without them
|
||||
// it can't fold the embedded proxy peer into the affected set on a client
|
||||
// group/peer change, and the proxy never learns a newly authorised client until
|
||||
// it reconnects. Registered here (rather than via a direct
|
||||
// affectedpeers→agentnetwork import) to avoid an import cycle
|
||||
// (agentnetwork → account → affectedpeers).
|
||||
func init() {
|
||||
affectedpeers.SetAgentNetworkSynthesizer(SynthesizeServices)
|
||||
}
|
||||
749
management/internals/modules/agentnetwork/catalog/catalog.go
Normal file
749
management/internals/modules/agentnetwork/catalog/catalog.go
Normal file
@@ -0,0 +1,749 @@
|
||||
// Package catalog defines the static set of Agent Network providers
|
||||
// recognized by the management server. The catalog is consulted both to
|
||||
// validate provider_id on create/update and to surface the available
|
||||
// providers (and their models) to the dashboard.
|
||||
package catalog
|
||||
|
||||
import "github.com/netbirdio/netbird/shared/management/http/api"
|
||||
|
||||
// Model is the in-memory representation of a catalog model.
|
||||
type Model struct {
|
||||
ID string
|
||||
Label string
|
||||
InputPer1k float64
|
||||
OutputPer1k float64
|
||||
ContextWindow int
|
||||
}
|
||||
|
||||
// ProviderKind groups catalog entries for UI presentation. The split
|
||||
// is semantic, not technical:
|
||||
// - KindProvider: the upstream is a vendor's first-party API (OpenAI,
|
||||
// Anthropic, Mistral, Bedrock, etc.) — NetBird talks straight to
|
||||
// the model provider.
|
||||
// - KindGateway: the upstream is itself a routing / aggregation layer
|
||||
// in front of multiple providers (LiteLLM, Portkey, Helicone, …).
|
||||
// These typically need NetBird identity stamped onto upstream
|
||||
// requests so the gateway's analytics and budgets attribute to the
|
||||
// real caller; that's what IdentityInjection is for.
|
||||
// - KindCustom: the catch-all "OpenAI-compatible self-hosted endpoint"
|
||||
// entry (vLLM, Ollama, custom inference servers).
|
||||
//
|
||||
// Frontend uses Kind to group the provider Select in the modal so an
|
||||
// operator can spot at a glance which catalog entries proxy other
|
||||
// providers vs. talk straight to one. Backend doesn't dispatch on Kind
|
||||
// today; it's purely a presentation hint.
|
||||
type ProviderKind string
|
||||
|
||||
const (
|
||||
KindProvider ProviderKind = "provider"
|
||||
KindGateway ProviderKind = "gateway"
|
||||
KindCustom ProviderKind = "custom"
|
||||
)
|
||||
|
||||
// Provider is the in-memory representation of a catalog provider.
|
||||
type Provider struct {
|
||||
ID string
|
||||
Name string
|
||||
Description string
|
||||
DefaultHost string
|
||||
// Kind groups this entry for UI presentation; see ProviderKind.
|
||||
Kind ProviderKind
|
||||
// AuthHeaderName is the HTTP header the provider's API expects
|
||||
// the credential under (e.g. "Authorization" for OpenAI,
|
||||
// "x-api-key" for Anthropic). Combined with AuthHeaderTemplate
|
||||
// at synthesis time to inject the auth header on every upstream
|
||||
// request.
|
||||
AuthHeaderName string
|
||||
AuthHeaderTemplate string
|
||||
DefaultContentType string
|
||||
BrandColor string
|
||||
// ParserID names the proxy LLM parser surface this provider
|
||||
// speaks (matches llm.Parser.ProviderName: "openai",
|
||||
// "anthropic"). Multiple catalog ids may share a parser surface
|
||||
// (e.g. azure_openai_api and mistral_api both speak the OpenAI
|
||||
// shape). Empty when no parser is yet implemented for the
|
||||
// surface — the proxy middleware then falls back to URL sniffing
|
||||
// or skips request-side enrichment.
|
||||
ParserID string
|
||||
// IdentityInjection, when non-nil, instructs the proxy to stamp
|
||||
// the caller's NetBird identity onto upstream requests under the
|
||||
// configured header names. Used for gateways like LiteLLM that
|
||||
// key budgets and attribution off request headers (the gateway
|
||||
// otherwise has no way to learn which user / group made the call).
|
||||
// The proxy strips the same header names from the inbound request
|
||||
// before stamping ours, so an app can't spoof identity by setting
|
||||
// these headers itself.
|
||||
IdentityInjection *IdentityInjection
|
||||
// ExtraHeaders is a catalog-declared list of additional per-
|
||||
// provider routing/config headers the proxy stamps on every
|
||||
// upstream request. Distinct from AuthHeaderName/Template (which
|
||||
// always carries the API_KEY) and from IdentityInjection (caller
|
||||
// identity). Each entry surfaces an optional input on the
|
||||
// dashboard's provider modal whose value lives on the provider
|
||||
// record's ExtraValues map (keyed by ExtraHeader.Name). Empty
|
||||
// list = no extra inputs rendered. Used today by Portkey for
|
||||
// "x-portkey-config: pc-..." (a saved-config id that resolves
|
||||
// upstream provider + credentials on Portkey's hosted side).
|
||||
ExtraHeaders []ExtraHeader
|
||||
Models []Model
|
||||
}
|
||||
|
||||
// ExtraHeader names a single optional per-provider routing/config
|
||||
// header. Catalog declares N of these per provider type; the operator
|
||||
// fills any subset on the provider record (see Provider.ExtraValues).
|
||||
// At synth time, only entries with a non-empty operator value are
|
||||
// stamped; the proxy's identity-inject middleware applies anti-spoof
|
||||
// (Remove + Add) so a client can't supply these headers themselves.
|
||||
//
|
||||
// UI copy (label / help text / tooltip) for each known Name lives on
|
||||
// the dashboard, not here — the backend's job is just to declare
|
||||
// which wire headers are accepted. New provider needs an extra
|
||||
// header? Add the Name here AND the matching UI copy on the dashboard.
|
||||
type ExtraHeader struct {
|
||||
// Name is the wire header name, e.g. "x-portkey-config".
|
||||
Name string
|
||||
}
|
||||
|
||||
// IdentityInjection describes how the proxy stamps NetBird identity onto
|
||||
// upstream gateway requests. Exactly one shape must be set — they're
|
||||
// mutually exclusive and dispatched by the inject middleware.
|
||||
//
|
||||
// Shape choice tracks the wire convention the upstream gateway uses,
|
||||
// not the vendor name. New gateways with a known shape become a catalog
|
||||
// entry, not a new code path.
|
||||
type IdentityInjection struct {
|
||||
// HeaderPair emits separate headers per identity dimension
|
||||
// (end-user id, tags as CSV). LiteLLM and OpenAI-compatible
|
||||
// self-hosted gateways that read identity from dedicated headers.
|
||||
HeaderPair *HeaderPairInjection
|
||||
// JSONMetadata emits a single header carrying a JSON object with
|
||||
// reserved keys for user / groups / etc. Portkey, Helicone-style
|
||||
// metadata headers, anything that wants a structured envelope.
|
||||
JSONMetadata *JSONMetadataInjection
|
||||
}
|
||||
|
||||
// HeaderPairInjection is the LiteLLM-style wire convention.
|
||||
type HeaderPairInjection struct {
|
||||
// Customizable, when true, marks the wire header names as
|
||||
// operator-overridable: the dashboard surfaces EndUserIDHeader
|
||||
// and TagsHeader as editable inputs (defaults shown as
|
||||
// placeholders) and the synthesizer pulls the actual values from
|
||||
// the provider record's IdentityHeader* fields rather than from
|
||||
// these defaults. An empty operator value disables stamping for
|
||||
// that dimension. Used today for Bifrost, whose log-metadata /
|
||||
// telemetry header prefix (x-bf-lh-* vs x-bf-dim-*) is a
|
||||
// per-operator choice; LiteLLM and similar gateways with a fixed
|
||||
// wire protocol leave this false so the catalog defaults are
|
||||
// authoritative.
|
||||
Customizable bool
|
||||
// EndUserIDHeader receives the caller's display identity (user
|
||||
// email when the peer is attached to a user, else peer.Name),
|
||||
// e.g. "x-litellm-end-user-id".
|
||||
EndUserIDHeader string
|
||||
// TagsHeader receives the caller's NetBird group display names
|
||||
// as a CSV, e.g. "x-litellm-tags".
|
||||
TagsHeader string
|
||||
// TagsInBody, when true, additionally writes the tag list into
|
||||
// the request body's metadata.tags array (a JSON path the
|
||||
// gateway parses for budget enforcement). LiteLLM only honours
|
||||
// metadata.tags for tag-budget gating — its x-litellm-tags
|
||||
// header path feeds spend tracking but bypasses
|
||||
// _tag_max_budget_check entirely. Body inject is skipped when
|
||||
// the request body is empty, truncated, non-JSON, or when an
|
||||
// existing metadata field is a non-object value (defensive: we
|
||||
// never clobber a client-supplied non-object). The header path
|
||||
// remains a robust fallback for spend tracking in those cases.
|
||||
TagsInBody bool
|
||||
// EndUserIDInBody, when true, additionally writes the display
|
||||
// identity into the request body's top-level "user" field (the
|
||||
// OpenAI-standard end-user identifier). LiteLLM resolves the end
|
||||
// user id from headers first then body, so for LiteLLM this is
|
||||
// belt-and-suspenders. It matters when an OpenAI-compatible
|
||||
// gateway downstream of LiteLLM (or OpenAI direct, bypassing
|
||||
// LiteLLM) only reads the body, and as anti-spoof: client-
|
||||
// supplied "user" values are overwritten with our trusted
|
||||
// identity. Same skip rules as TagsInBody.
|
||||
EndUserIDInBody bool
|
||||
}
|
||||
|
||||
// JSONMetadataInjection is the Portkey-style wire convention: a single
|
||||
// header carrying a JSON object. NetBird identity fields land under the
|
||||
// configured reserved keys; missing keys (empty string) are skipped at
|
||||
// emit time.
|
||||
type JSONMetadataInjection struct {
|
||||
// Customizable, when true, marks the JSON keys as operator-
|
||||
// overridable. The dashboard surfaces UserKey and GroupsKey as
|
||||
// editable inputs (the catalog values shown as placeholders) and
|
||||
// the synthesizer pulls the actual JSON-key names from the
|
||||
// provider record's IdentityHeader* fields. Same field reuse as
|
||||
// HeaderPair's customizable path — the dimensions (user identity,
|
||||
// groups) are the same, only the wire encoding differs (JSON key
|
||||
// vs HTTP header name). An empty operator value disables emission
|
||||
// for that dimension. Used today for Cloudflare AI Gateway, whose
|
||||
// cf-aig-metadata header accepts arbitrary JSON keys; Portkey
|
||||
// leaves this false because its keys are reserved by the Portkey
|
||||
// schema.
|
||||
Customizable bool
|
||||
// Header is the wire header name carrying the JSON payload, e.g.
|
||||
// "x-portkey-metadata".
|
||||
Header string
|
||||
// UserKey is the JSON key for the caller's display identity.
|
||||
// Portkey reserves "_user" for this dimension.
|
||||
UserKey string
|
||||
// GroupsKey is the JSON key for the caller's NetBird groups,
|
||||
// emitted as a CSV string value (Portkey requires string values).
|
||||
GroupsKey string
|
||||
// MaxValueLength caps each emitted JSON value, in bytes. Portkey
|
||||
// enforces a 128-char limit per value; oversized values are
|
||||
// truncated rather than failing the request. 0 disables the cap.
|
||||
MaxValueLength int
|
||||
}
|
||||
|
||||
// providers is the canonical list of supported Agent Network providers.
|
||||
// Update this list together with the dashboard's PROVIDER_CATALOG.
|
||||
var providers = []Provider{
|
||||
{
|
||||
ID: "openai_api",
|
||||
Kind: KindProvider,
|
||||
Name: "OpenAI API",
|
||||
Description: "GPT, Responses API, and Embeddings",
|
||||
DefaultHost: "api.openai.com",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#10A37F",
|
||||
ParserID: "openai",
|
||||
// Pricing + context windows cross-checked against LiteLLM's
|
||||
// model_prices_and_context_window.json. Notable corrections from
|
||||
// earlier values: o4-mini repriced from $4/$16 to $1.10/$4.40
|
||||
// per MTok, gpt-4o from $5/$15 to $2.50/$10, and the GPT-5
|
||||
// family context windows split between 1.05M for full-size
|
||||
// models and 272K for mini/nano/codex variants.
|
||||
Models: []Model{
|
||||
{ID: "gpt-5.5", Label: "GPT-5.5", InputPer1k: 0.005, OutputPer1k: 0.030, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.5-pro", Label: "GPT-5.5 Pro", InputPer1k: 0.030, OutputPer1k: 0.180, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4", Label: "GPT-5.4", InputPer1k: 0.0025, OutputPer1k: 0.015, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4-pro", Label: "GPT-5.4 Pro", InputPer1k: 0.030, OutputPer1k: 0.180, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4-mini", Label: "GPT-5.4 Mini", InputPer1k: 0.00075, OutputPer1k: 0.0045, ContextWindow: 272000},
|
||||
{ID: "gpt-5.4-nano", Label: "GPT-5.4 Nano", InputPer1k: 0.0002, OutputPer1k: 0.00125, ContextWindow: 272000},
|
||||
{ID: "gpt-5.3-codex", Label: "GPT-5.3 Codex", InputPer1k: 0.00175, OutputPer1k: 0.014, ContextWindow: 272000},
|
||||
{ID: "gpt-5.3-chat-latest", Label: "GPT-5.3 Chat", InputPer1k: 0.00175, OutputPer1k: 0.014, ContextWindow: 128000},
|
||||
{ID: "o4-mini", Label: "o4-mini", InputPer1k: 0.0011, OutputPer1k: 0.0044, ContextWindow: 200000},
|
||||
{ID: "gpt-4.1", Label: "GPT-4.1", InputPer1k: 0.002, OutputPer1k: 0.008, ContextWindow: 1047576},
|
||||
{ID: "gpt-4.1-mini", Label: "GPT-4.1 mini", InputPer1k: 0.0004, OutputPer1k: 0.0016, ContextWindow: 1047576},
|
||||
{ID: "gpt-4.1-nano", Label: "GPT-4.1 nano", InputPer1k: 0.0001, OutputPer1k: 0.0004, ContextWindow: 1047576},
|
||||
{ID: "gpt-4o", Label: "GPT-4o", InputPer1k: 0.0025, OutputPer1k: 0.010, ContextWindow: 128000},
|
||||
{ID: "gpt-4o-mini", Label: "GPT-4o mini", InputPer1k: 0.00015, OutputPer1k: 0.0006, ContextWindow: 128000},
|
||||
{ID: "gpt-4-turbo", Label: "GPT-4 Turbo", InputPer1k: 0.01, OutputPer1k: 0.03, ContextWindow: 128000},
|
||||
{ID: "gpt-3.5-turbo", Label: "GPT-3.5 Turbo", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 16385},
|
||||
{ID: "text-embedding-3-large", Label: "text-embedding-3-large", InputPer1k: 0.00013, OutputPer1k: 0, ContextWindow: 8191},
|
||||
{ID: "text-embedding-3-small", Label: "text-embedding-3-small", InputPer1k: 0.00002, OutputPer1k: 0, ContextWindow: 8191},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "anthropic_api",
|
||||
Kind: KindProvider,
|
||||
Name: "Anthropic API",
|
||||
Description: "Claude Messages API",
|
||||
DefaultHost: "api.anthropic.com",
|
||||
AuthHeaderName: "x-api-key",
|
||||
AuthHeaderTemplate: "${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#D97757",
|
||||
ParserID: "anthropic",
|
||||
// Per Anthropic's current model lineup. Pricing in USD per 1k
|
||||
// tokens. Context windows: 4.6+ family is 1M; Haiku 4.5 stays at
|
||||
// 200K. claude-3-7-sonnet and claude-3-5-haiku retired
|
||||
// 2026-02-19 — dropped from the catalog. claude-opus-4-1
|
||||
// deprecated, retires 2026-08-05 — kept until the cutover.
|
||||
// claude-mythos-5 omitted: Project Glasswing access only, not a
|
||||
// general-availability target. claude-fable-5 requires the
|
||||
// account to be on >= 30-day data retention or all requests
|
||||
// 400.
|
||||
Models: []Model{
|
||||
{ID: "claude-fable-5", Label: "Claude Fable 5", InputPer1k: 0.010, OutputPer1k: 0.050, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-8", Label: "Claude Opus 4.8", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-7", Label: "Claude Opus 4.7", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-6", Label: "Claude Opus 4.6", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-1", Label: "Claude Opus 4.1 (deprecated, retires 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000},
|
||||
{ID: "claude-sonnet-4-6", Label: "Claude Sonnet 4.6", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000},
|
||||
{ID: "claude-sonnet-4-5", Label: "Claude Sonnet 4.5", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000},
|
||||
{ID: "claude-haiku-4-5", Label: "Claude Haiku 4.5", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "azure_openai_api",
|
||||
Kind: KindProvider,
|
||||
Name: "Azure OpenAI API",
|
||||
Description: "Azure-hosted OpenAI deployments",
|
||||
DefaultHost: "<resource>.openai.azure.com",
|
||||
AuthHeaderName: "api-key",
|
||||
AuthHeaderTemplate: "${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#0078D4",
|
||||
ParserID: "openai",
|
||||
// Mirrors openai_api pricing — Azure resells OpenAI models at the
|
||||
// same per-token rates, just under different deployment names.
|
||||
Models: []Model{
|
||||
{ID: "gpt-5.5", Label: "GPT-5.5 (Azure)", InputPer1k: 0.005, OutputPer1k: 0.030, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4", Label: "GPT-5.4 (Azure)", InputPer1k: 0.0025, OutputPer1k: 0.015, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4-mini", Label: "GPT-5.4 Mini (Azure)", InputPer1k: 0.00075, OutputPer1k: 0.0045, ContextWindow: 272000},
|
||||
{ID: "gpt-5.4-nano", Label: "GPT-5.4 Nano (Azure)", InputPer1k: 0.0002, OutputPer1k: 0.00125, ContextWindow: 272000},
|
||||
{ID: "o4-mini", Label: "o4-mini (Azure)", InputPer1k: 0.0011, OutputPer1k: 0.0044, ContextWindow: 200000},
|
||||
{ID: "gpt-4.1", Label: "GPT-4.1 (Azure)", InputPer1k: 0.002, OutputPer1k: 0.008, ContextWindow: 1047576},
|
||||
{ID: "gpt-4.1-mini", Label: "GPT-4.1 mini (Azure)", InputPer1k: 0.0004, OutputPer1k: 0.0016, ContextWindow: 1047576},
|
||||
{ID: "gpt-4o", Label: "GPT-4o (Azure)", InputPer1k: 0.0025, OutputPer1k: 0.010, ContextWindow: 128000},
|
||||
{ID: "gpt-4o-mini", Label: "GPT-4o mini (Azure)", InputPer1k: 0.00015, OutputPer1k: 0.0006, ContextWindow: 128000},
|
||||
{ID: "gpt-35-turbo", Label: "GPT-3.5 Turbo (Azure)", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 16385},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "bedrock_api",
|
||||
Kind: KindProvider,
|
||||
Name: "AWS Bedrock API",
|
||||
Description: "Anthropic, Meta, Cohere via Bedrock",
|
||||
DefaultHost: "bedrock-runtime.<region>.amazonaws.com",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#FF9900",
|
||||
// Anthropic models on Bedrock take the anthropic.* prefix and
|
||||
// follow the same lineup / pricing as the first-party Anthropic
|
||||
// catalog entry above. claude-3-7-sonnet and claude-3-5-haiku
|
||||
// were retired upstream on 2026-02-19 — dropped from the
|
||||
// Bedrock list too. Amazon Nova entries cross-checked against
|
||||
// LiteLLM (added Nova Micro + the new Nova 2 Lite preview).
|
||||
// Llama 3.3 70B entry kept unchanged — LiteLLM tracks only
|
||||
// per-region Llama 3 entries; standalone 3.3 not yet listed.
|
||||
Models: []Model{
|
||||
{ID: "anthropic.claude-opus-4-8", Label: "Claude Opus 4.8 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "anthropic.claude-opus-4-7", Label: "Claude Opus 4.7 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "anthropic.claude-opus-4-6", Label: "Claude Opus 4.6 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "anthropic.claude-opus-4-1", Label: "Claude Opus 4.1 (Bedrock, deprecated 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000},
|
||||
{ID: "anthropic.claude-sonnet-4-6", Label: "Claude Sonnet 4.6 (Bedrock)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000},
|
||||
{ID: "anthropic.claude-sonnet-4-5", Label: "Claude Sonnet 4.5 (Bedrock)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000},
|
||||
{ID: "anthropic.claude-haiku-4-5", Label: "Claude Haiku 4.5 (Bedrock)", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000},
|
||||
{ID: "meta.llama3-3-70b-instruct", Label: "Llama 3.3 70B (Bedrock)", InputPer1k: 0.00072, OutputPer1k: 0.00072, ContextWindow: 128000},
|
||||
{ID: "amazon.nova-2-lite", Label: "Amazon Nova 2 Lite (Bedrock, preview)", InputPer1k: 0.0003, OutputPer1k: 0.0025, ContextWindow: 1000000},
|
||||
{ID: "amazon.nova-pro", Label: "Amazon Nova Pro (Bedrock)", InputPer1k: 0.0008, OutputPer1k: 0.0032, ContextWindow: 300000},
|
||||
{ID: "amazon.nova-lite", Label: "Amazon Nova Lite (Bedrock)", InputPer1k: 0.00006, OutputPer1k: 0.00024, ContextWindow: 300000},
|
||||
{ID: "amazon.nova-micro", Label: "Amazon Nova Micro (Bedrock)", InputPer1k: 0.000035, OutputPer1k: 0.00014, ContextWindow: 128000},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "vertex_ai_api",
|
||||
Kind: KindProvider,
|
||||
Name: "Google Vertex AI API",
|
||||
Description: "Anthropic Claude models hosted on Vertex AI",
|
||||
DefaultHost: "<region>-aiplatform.googleapis.com",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#4285F4",
|
||||
// Vertex carries the model in the URL path and authenticates with a
|
||||
// service-account-minted OAuth token (api_key = "keyfile::<base64 SA>").
|
||||
// Only Anthropic-on-Vertex is metered today: the request parser maps the
|
||||
// anthropic publisher to the Anthropic parser, so the lineup + prices
|
||||
// mirror the first-party Anthropic catalog (LiteLLM vertex_ai/claude-*
|
||||
// confirms the same per-token rates; cross-region profiles in eu/apac
|
||||
// carry a ~10% premium that base pricing does not model). Gemini (the
|
||||
// google publisher) is intentionally omitted until a Gemini parser
|
||||
// exists — the router denies unmeterable publishers rather than forward
|
||||
// them uncounted.
|
||||
Models: []Model{
|
||||
{ID: "claude-fable-5", Label: "Claude Fable 5 (Vertex)", InputPer1k: 0.010, OutputPer1k: 0.050, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-8", Label: "Claude Opus 4.8 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-7", Label: "Claude Opus 4.7 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-6", Label: "Claude Opus 4.6 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-1", Label: "Claude Opus 4.1 (Vertex, deprecated 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000},
|
||||
{ID: "claude-sonnet-4-6", Label: "Claude Sonnet 4.6 (Vertex)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000},
|
||||
{ID: "claude-sonnet-4-5", Label: "Claude Sonnet 4.5 (Vertex)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000},
|
||||
{ID: "claude-haiku-4-5", Label: "Claude Haiku 4.5 (Vertex)", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "mistral_api",
|
||||
Kind: KindProvider,
|
||||
Name: "Mistral API",
|
||||
Description: "Mistral cloud API",
|
||||
DefaultHost: "api.mistral.ai",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#FF7000",
|
||||
ParserID: "openai",
|
||||
// Pricing + context windows cross-checked against LiteLLM. Key
|
||||
// gotchas the marketing page hides:
|
||||
// - `mistral-medium-latest` aliases to Medium 3.1 ($0.40/$2),
|
||||
// NOT Medium 3.5 ($1.50/$7.50). Catalog exposes both.
|
||||
// - `mistral-large-latest` aliases to Large 3 — 262K context,
|
||||
// cheaper than Medium 3.5.
|
||||
// - Magistral models are tuned for reasoning but cap context
|
||||
// at only 40K (vs 128K-262K elsewhere).
|
||||
// - `codestral-latest` still routes to the old 2405 build
|
||||
// ($1/$3) per LiteLLM; the newer codestral-2508 is both
|
||||
// cheaper and longer-context. Both exposed.
|
||||
// - Pixtral was folded into the main Large/Medium series; no
|
||||
// standalone vision entry.
|
||||
Models: []Model{
|
||||
{ID: "mistral-large-latest", Label: "Mistral Large 3", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 262144},
|
||||
{ID: "mistral-medium-latest", Label: "Mistral Medium 3.1", InputPer1k: 0.0004, OutputPer1k: 0.002, ContextWindow: 131072},
|
||||
{ID: "mistral-medium-3-5", Label: "Mistral Medium 3.5", InputPer1k: 0.0015, OutputPer1k: 0.0075, ContextWindow: 262144},
|
||||
{ID: "mistral-small-latest", Label: "Mistral Small 3.2", InputPer1k: 0.00006, OutputPer1k: 0.00018, ContextWindow: 131072},
|
||||
{ID: "magistral-medium-latest", Label: "Magistral Medium (reasoning)", InputPer1k: 0.002, OutputPer1k: 0.005, ContextWindow: 40000},
|
||||
{ID: "magistral-small-latest", Label: "Magistral Small (reasoning)", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 40000},
|
||||
{ID: "devstral-medium-latest", Label: "Devstral Medium 2 (coding)", InputPer1k: 0.0004, OutputPer1k: 0.002, ContextWindow: 256000},
|
||||
{ID: "devstral-small-latest", Label: "Devstral Small 2 (coding)", InputPer1k: 0.0001, OutputPer1k: 0.0003, ContextWindow: 256000},
|
||||
{ID: "codestral-2508", Label: "Codestral 2508", InputPer1k: 0.0003, OutputPer1k: 0.0009, ContextWindow: 256000},
|
||||
{ID: "codestral-latest", Label: "Codestral (legacy 2405)", InputPer1k: 0.001, OutputPer1k: 0.003, ContextWindow: 32000},
|
||||
{ID: "ministral-3-14b-2512", Label: "Ministral 3 14B", InputPer1k: 0.0002, OutputPer1k: 0.0002, ContextWindow: 262144},
|
||||
{ID: "ministral-8b-latest", Label: "Ministral 8B", InputPer1k: 0.00015, OutputPer1k: 0.00015, ContextWindow: 262144},
|
||||
{ID: "ministral-3-3b-2512", Label: "Ministral 3 3B", InputPer1k: 0.0001, OutputPer1k: 0.0001, ContextWindow: 131072},
|
||||
{ID: "mistral-embed", Label: "Mistral Embed", InputPer1k: 0.0001, OutputPer1k: 0, ContextWindow: 8192},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "litellm_proxy",
|
||||
Kind: KindGateway,
|
||||
Name: "LiteLLM Proxy",
|
||||
Description: "Bring your own LiteLLM proxy with NetBird identity stamped on every request",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#0EA5E9",
|
||||
ParserID: "openai",
|
||||
// IdentityInjection requires a LiteLLM virtual key minted with
|
||||
// metadata.allow_client_tags=true; the master key silently drops
|
||||
// caller tags. Tags go out via both the x-litellm-tags header and
|
||||
// body metadata.tags: LiteLLM enforces budgets from the body only,
|
||||
// so the header is the spend-tracking fallback when body injection
|
||||
// can't run. See the Agent Network provider docs for key setup.
|
||||
IdentityInjection: &IdentityInjection{
|
||||
HeaderPair: &HeaderPairInjection{
|
||||
EndUserIDHeader: "x-litellm-end-user-id",
|
||||
TagsHeader: "x-litellm-tags",
|
||||
TagsInBody: true,
|
||||
EndUserIDInBody: true,
|
||||
},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "portkey",
|
||||
Kind: KindGateway,
|
||||
Name: "Portkey AI Gateway",
|
||||
Description: "Portkey AI Gateway with NetBird identity stamped via x-portkey-metadata",
|
||||
DefaultHost: "api.portkey.ai",
|
||||
// Portkey hosted requires x-portkey-api-key (account key)
|
||||
// plus a routing decision per request. The simplest routing
|
||||
// path is a saved Portkey config id stamped via
|
||||
// x-portkey-config — operators paste the pc-... id once and
|
||||
// Portkey resolves the upstream provider + virtual key from
|
||||
// it. ExtraHeaders below surfaces the input. Alternative:
|
||||
// callers author "@org/model" in the body; both flows
|
||||
// coexist (per-request authoring still works without a
|
||||
// configured value).
|
||||
AuthHeaderName: "x-portkey-api-key",
|
||||
AuthHeaderTemplate: "${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#FF5C00",
|
||||
ParserID: "openai",
|
||||
IdentityInjection: &IdentityInjection{
|
||||
JSONMetadata: &JSONMetadataInjection{
|
||||
Header: "x-portkey-metadata",
|
||||
UserKey: "_user",
|
||||
GroupsKey: "groups",
|
||||
MaxValueLength: 128,
|
||||
},
|
||||
},
|
||||
ExtraHeaders: []ExtraHeader{
|
||||
{Name: "x-portkey-config"},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "bifrost",
|
||||
Kind: KindGateway,
|
||||
Name: "Bifrost",
|
||||
Description: "Maxim AI's Bifrost gateway. Point upstream URL at /openai/v1 or /anthropic/v1 on your Bifrost host depending on which body shape your apps use.",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#7C3AED",
|
||||
// ParserID empty: the proxy's request parser sniffs the URL
|
||||
// path. Bifrost's /openai/v1/... contains "/v1/chat/completions"
|
||||
// (matches OpenAIParser.DetectFromURL); /anthropic/v1/messages
|
||||
// contains "/v1/messages" (matches AnthropicParser). Operators
|
||||
// who paste a different prefix get no usage parsing and the
|
||||
// cost meter skips with skipMissingProvider — degraded but
|
||||
// non-fatal.
|
||||
ParserID: "",
|
||||
// Identity-injection headers are operator-customisable. The
|
||||
// HeaderPair values below are PLACEHOLDERS surfaced by the
|
||||
// dashboard; the actual values stamped on the wire come from
|
||||
// the provider record's IdentityHeaderUserID /
|
||||
// IdentityHeaderGroups fields. An empty operator value
|
||||
// disables stamping for that dimension (the inject middleware
|
||||
// already no-ops on empty header names). Defaulting to the
|
||||
// x-bf-dim- family so the values land in Bifrost's
|
||||
// Prometheus/OTEL pipelines when the operator declares the
|
||||
// label names in their client.prometheus_labels config — see
|
||||
// docs.getbifrost.ai/features/telemetry. Operators who use
|
||||
// the always-on x-bf-lh- log-metadata family (no Bifrost-side
|
||||
// declaration required) just edit the inputs.
|
||||
//
|
||||
// Bifrost virtual keys (sk-bf-*) ride Authorization: Bearer.
|
||||
// Operators provision the VK on their Bifrost (UI /
|
||||
// config.json / POST /api/governance/virtual-keys) and paste
|
||||
// the returned sk-bf-... as ${API_KEY}. Pin v1.4+ to avoid
|
||||
// the v1.3.0 x-bf-vk regression (maximhq/bifrost#632).
|
||||
IdentityInjection: &IdentityInjection{
|
||||
HeaderPair: &HeaderPairInjection{
|
||||
EndUserIDHeader: "x-bf-dim-netbird_user_id",
|
||||
TagsHeader: "x-bf-dim-netbird_groups",
|
||||
Customizable: true,
|
||||
},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "cloudflare_ai_gateway",
|
||||
Kind: KindGateway,
|
||||
Name: "Cloudflare AI Gateway",
|
||||
Description: "Cloudflare AI Gateway. Operator pastes the gateway URL (with the upstream provider slug like /openai or /anthropic so the URL sniffer dispatches to the right parser) and a per-gateway authentication token. Recommended setup is BYOK / Stored Keys: Cloudflare manages the upstream provider credential and the gateway token is the only secret NetBird needs.",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "cf-aig-authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#F38020",
|
||||
// ParserID empty: like Bifrost, the proxy's parser-detect
|
||||
// sniffs the URL path. /openai/... contains the OpenAI hint
|
||||
// substrings; /anthropic/v1/messages contains /v1/messages
|
||||
// (matches AnthropicParser). The /compat universal endpoint
|
||||
// also speaks OpenAI shape so OpenAIParser handles it.
|
||||
// Operators who paste a different prefix degrade to no-cost
|
||||
// (skipMissingProvider) but the request still flows.
|
||||
ParserID: "",
|
||||
// cf-aig-metadata is a single header carrying a JSON object;
|
||||
// up to five string/number/boolean values per request. NetBird
|
||||
// occupies two slots (user id + groups CSV) and leaves three
|
||||
// for operator-added context. JSON keys are operator-
|
||||
// customisable so Cloudflare-side log filters can use the
|
||||
// operator's existing label conventions instead of NetBird's
|
||||
// defaults — hence Customizable=true. The dashboard surfaces
|
||||
// the catalog values as placeholders; only the values stored
|
||||
// on the provider record's IdentityHeader* fields land on the
|
||||
// wire (empty operator value = key is omitted from the JSON,
|
||||
// since applyJSONMetadata already skips empty keys).
|
||||
IdentityInjection: &IdentityInjection{
|
||||
JSONMetadata: &JSONMetadataInjection{
|
||||
Header: "cf-aig-metadata",
|
||||
UserKey: "netbird_user_id",
|
||||
GroupsKey: "netbird_groups",
|
||||
Customizable: true,
|
||||
// Cloudflare's docs don't specify a per-value cap;
|
||||
// leaving 0 disables the truncate path. Header-level
|
||||
// constraint is "5 entries max" rather than length.
|
||||
MaxValueLength: 0,
|
||||
},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "vercel_ai_gateway",
|
||||
Kind: KindGateway,
|
||||
Name: "Vercel AI Gateway",
|
||||
Description: "Vercel's unified API for hundreds of models. Single endpoint, OpenAI-compatible body, model dispatch via prefix (openai/..., anthropic/..., google/..., xai/...). Per-user / per-tag attribution lands in Vercel's Custom Reporting API and observability dashboard.",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#000000",
|
||||
// Vercel always speaks OpenAI shape on /v1/chat/completions —
|
||||
// the model prefix in the body picks the upstream provider.
|
||||
// No URL sniffing needed; pin the parser directly.
|
||||
ParserID: "openai",
|
||||
// HeaderPair shape with fixed wire names dictated by Vercel's
|
||||
// Custom Reporting API contract. Customizable=false because
|
||||
// renaming the headers makes Vercel silently stop attributing
|
||||
// — the gateway's reporting endpoint only matches its own
|
||||
// header names. Same fixed-protocol position as LiteLLM.
|
||||
//
|
||||
// Caveats operators should know:
|
||||
// - up to 10 tags total per request (deduped); 11+ → HTTP 400
|
||||
// - each tag must be 1-64 chars
|
||||
// - user up to 256 chars (NetBird user emails fit)
|
||||
// - $0.075 per 1k unique user/tag values written
|
||||
// We don't enforce the caps in the inject middleware today;
|
||||
// operators in groups beyond the 10-tag limit will see Vercel
|
||||
// 400s and need to re-scope their group memberships.
|
||||
IdentityInjection: &IdentityInjection{
|
||||
HeaderPair: &HeaderPairInjection{
|
||||
EndUserIDHeader: "ai-reporting-user",
|
||||
TagsHeader: "ai-reporting-tags",
|
||||
},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "openrouter",
|
||||
Kind: KindGateway,
|
||||
Name: "OpenRouter",
|
||||
Description: "OpenRouter's unified API for hundreds of models. Single endpoint at openrouter.ai/api/v1, OpenAI-compatible body, model dispatch via prefix (anthropic/claude-..., openai/gpt-..., google/gemini-..., etc.). Per-user attribution lands in OpenRouter's analytics via the OpenAI-standard `user` body field; OpenRouter has no groups / tags dimension at request time.",
|
||||
DefaultHost: "openrouter.ai/api/v1",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#6F4FF2",
|
||||
// OpenRouter is single-endpoint OpenAI-shape on /api/v1/chat/completions —
|
||||
// model prefix in the body picks the upstream provider.
|
||||
// Pinning the parser saves URL sniffing.
|
||||
ParserID: "openai",
|
||||
// HeaderPair shape with EndUserIDInBody as the only active
|
||||
// dimension. OpenRouter's per-user attribution is the
|
||||
// OpenAI-standard `user` body field, not a header — and
|
||||
// OpenRouter offers no per-request groups / tags dimension at
|
||||
// all. Customizable=false because the field name is locked by
|
||||
// OpenAI's spec; renaming would just defeat the inject.
|
||||
IdentityInjection: &IdentityInjection{
|
||||
HeaderPair: &HeaderPairInjection{
|
||||
EndUserIDInBody: true,
|
||||
},
|
||||
},
|
||||
// HTTP-Referer + X-OpenRouter-Title surface in OpenRouter's
|
||||
// app rankings and per-app analytics. Operators paste their
|
||||
// own app URL + display name on the provider record so their
|
||||
// requests show under their brand instead of "no app". Both
|
||||
// are static per-deployment, not per-request, hence the
|
||||
// ExtraHeaders mechanism (operator-typed value, stamped on
|
||||
// every request to this provider). Skip X-OpenRouter-Categories
|
||||
// for now — the marketplace-categories dimension is
|
||||
// niche-enough that we'd add it on demand.
|
||||
ExtraHeaders: []ExtraHeader{
|
||||
{Name: "HTTP-Referer"},
|
||||
{Name: "X-OpenRouter-Title"},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "custom",
|
||||
Kind: KindCustom,
|
||||
Name: "Custom / Self-hosted",
|
||||
Description: "OpenAI-compatible endpoint (vLLM, Ollama, …)",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#9CA3AF",
|
||||
Models: []Model{},
|
||||
},
|
||||
}
|
||||
|
||||
// All returns a copy of the full catalog.
|
||||
func All() []Provider {
|
||||
out := make([]Provider, len(providers))
|
||||
copy(out, providers)
|
||||
return out
|
||||
}
|
||||
|
||||
// Lookup returns the catalog entry with the given id, if any.
|
||||
func Lookup(id string) (Provider, bool) {
|
||||
for _, p := range providers {
|
||||
if p.ID == id {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
return Provider{}, false
|
||||
}
|
||||
|
||||
// IsKnown reports whether the given id refers to a catalog entry.
|
||||
func IsKnown(id string) bool {
|
||||
_, ok := Lookup(id)
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsVertexPathStyle reports whether a provider uses the Google Vertex AI
|
||||
// request shape — the model is carried in the URL path
|
||||
// (/v1/projects/{p}/locations/{r}/publishers/{pub}/models/{model}:{action})
|
||||
// rather than the body, so the proxy routes it by path instead of by model.
|
||||
func IsVertexPathStyle(providerID string) bool {
|
||||
return providerID == "vertex_ai_api"
|
||||
}
|
||||
|
||||
// IsBedrockPathStyle reports whether a provider uses the AWS Bedrock request
|
||||
// shape — the model is carried in the URL path (/model/{modelId}/{action},
|
||||
// action being invoke, invoke-with-response-stream, converse, or
|
||||
// converse-stream) rather than the body, so the proxy routes it by path.
|
||||
func IsBedrockPathStyle(providerID string) bool {
|
||||
return providerID == "bedrock_api"
|
||||
}
|
||||
|
||||
// ToAPIResponse renders a catalog provider as the API representation.
|
||||
func (p Provider) ToAPIResponse() api.AgentNetworkCatalogProvider {
|
||||
models := make([]api.AgentNetworkCatalogModel, 0, len(p.Models))
|
||||
for _, m := range p.Models {
|
||||
models = append(models, api.AgentNetworkCatalogModel{
|
||||
Id: m.ID,
|
||||
Label: m.Label,
|
||||
InputPer1k: m.InputPer1k,
|
||||
OutputPer1k: m.OutputPer1k,
|
||||
ContextWindow: m.ContextWindow,
|
||||
})
|
||||
}
|
||||
kind := api.AgentNetworkCatalogProviderKindProvider
|
||||
switch p.Kind {
|
||||
case KindGateway:
|
||||
kind = api.AgentNetworkCatalogProviderKindGateway
|
||||
case KindCustom:
|
||||
kind = api.AgentNetworkCatalogProviderKindCustom
|
||||
}
|
||||
resp := api.AgentNetworkCatalogProvider{
|
||||
Id: p.ID,
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
DefaultHost: p.DefaultHost,
|
||||
Kind: kind,
|
||||
AuthHeaderTemplate: p.AuthHeaderTemplate,
|
||||
DefaultContentType: p.DefaultContentType,
|
||||
BrandColor: p.BrandColor,
|
||||
Models: models,
|
||||
}
|
||||
if len(p.ExtraHeaders) > 0 {
|
||||
extras := make([]api.AgentNetworkCatalogExtraHeader, 0, len(p.ExtraHeaders))
|
||||
for _, h := range p.ExtraHeaders {
|
||||
extras = append(extras, api.AgentNetworkCatalogExtraHeader{
|
||||
Name: h.Name,
|
||||
})
|
||||
}
|
||||
resp.ExtraHeaders = &extras
|
||||
}
|
||||
// Surface IdentityInjection so the dashboard can decide whether
|
||||
// to render editable inputs vs. a read-only mappings strip per
|
||||
// shape's customizable flag. HeaderPair (Bifrost) and
|
||||
// JSONMetadata (Cloudflare, Portkey) are mutually exclusive on a
|
||||
// given catalog entry; emit whichever shape is set.
|
||||
if p.IdentityInjection != nil {
|
||||
injection := &api.AgentNetworkCatalogIdentityInjection{}
|
||||
if hp := p.IdentityInjection.HeaderPair; hp != nil {
|
||||
injection.HeaderPair = &api.AgentNetworkCatalogHeaderPairInjection{
|
||||
Customizable: hp.Customizable,
|
||||
EndUserIdHeader: hp.EndUserIDHeader,
|
||||
TagsHeader: hp.TagsHeader,
|
||||
}
|
||||
}
|
||||
if jm := p.IdentityInjection.JSONMetadata; jm != nil {
|
||||
injection.JsonMetadata = &api.AgentNetworkCatalogJSONMetadataInjection{
|
||||
Customizable: jm.Customizable,
|
||||
Header: jm.Header,
|
||||
UserKey: jm.UserKey,
|
||||
GroupsKey: jm.GroupsKey,
|
||||
}
|
||||
}
|
||||
if injection.HeaderPair != nil || injection.JsonMetadata != nil {
|
||||
resp.IdentityInjection = injection
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
// addAccessLogEndpoints registers the read-only, server-side-filtered
|
||||
// agent-network access-log listing and the aggregated usage overview.
|
||||
func (h *handler) addAccessLogEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/access-logs", h.listAccessLogs).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/usage/overview", h.getUsageOverview).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getUsageOverview(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
// Reuse the access-log filter for the shared date/user/group/provider/model
|
||||
// params; pagination/sort/search are irrelevant for an aggregate.
|
||||
var filter types.AgentNetworkAccessLogFilter
|
||||
if err := filter.ParseFromRequest(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
// Bound the aggregation window so an unbounded or over-wide query can't load
|
||||
// an account's entire usage history into memory.
|
||||
filter.ApplyUsageOverviewBounds(time.Now())
|
||||
granularity := types.ParseUsageGranularity(r.URL.Query().Get("granularity"))
|
||||
|
||||
buckets, err := h.manager.GetUsageOverview(r.Context(), userAuth.AccountId, userAuth.UserId, filter, granularity)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]api.AgentNetworkUsageBucket, 0, len(buckets))
|
||||
for _, b := range buckets {
|
||||
out = append(out, b.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) listAccessLogs(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var filter types.AgentNetworkAccessLogFilter
|
||||
if err := filter.ParseFromRequest(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rows, total, err := h.manager.ListAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId, filter)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
data := make([]api.AgentNetworkAccessLog, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
data = append(data, row.ToAPIResponse())
|
||||
}
|
||||
|
||||
pageSize := filter.GetLimit()
|
||||
totalPages := 0
|
||||
if pageSize > 0 {
|
||||
totalPages = int((total + int64(pageSize) - 1) / int64(pageSize))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, api.AgentNetworkAccessLogsResponse{
|
||||
Data: data,
|
||||
Page: filter.Page,
|
||||
PageSize: pageSize,
|
||||
TotalRecords: int(total),
|
||||
TotalPages: totalPages,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// addBudgetRuleEndpoints registers the account-level budget rule routes.
|
||||
func (h *handler) addBudgetRuleEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/budget-rules", h.getAllBudgetRules).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/budget-rules", h.createBudgetRule).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.getBudgetRule).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.updateBudgetRule).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.deleteBudgetRule).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllBudgetRules(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rules, err := h.manager.GetAllBudgetRules(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*api.AgentNetworkBudgetRule, 0, len(rules))
|
||||
for _, rule := range rules {
|
||||
out = append(out, rule.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getBudgetRule(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ruleID := mux.Vars(r)["ruleId"]
|
||||
if ruleID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
rule, err := h.manager.GetBudgetRule(r.Context(), userAuth.AccountId, userAuth.UserId, ruleID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, rule.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) createBudgetRule(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkBudgetRuleRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateBudgetRule(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rule := types.NewAccountBudgetRule(userAuth.AccountId)
|
||||
rule.FromAPIRequest(&req)
|
||||
|
||||
created, err := h.manager.CreateBudgetRule(r.Context(), userAuth.UserId, rule)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateBudgetRule(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ruleID := mux.Vars(r)["ruleId"]
|
||||
if ruleID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkBudgetRuleRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateBudgetRule(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rule := &types.AccountBudgetRule{ID: ruleID, AccountID: userAuth.AccountId}
|
||||
rule.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdateBudgetRule(r.Context(), userAuth.UserId, rule)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteBudgetRule(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ruleID := mux.Vars(r)["ruleId"]
|
||||
if ruleID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteBudgetRule(r.Context(), userAuth.AccountId, userAuth.UserId, ruleID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// validateBudgetRule rejects malformed budget rules. It reuses the policy limit
|
||||
// validation since the cap shape is identical, and rejects empty target entries.
|
||||
func validateBudgetRule(req *api.AgentNetworkBudgetRuleRequest) error {
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
if req.TargetGroups != nil {
|
||||
for _, id := range *req.TargetGroups {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "target_groups must not contain empty entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.TargetUsers != nil {
|
||||
for _, id := range *req.TargetUsers {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "target_users must not contain empty entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
return validatePolicyLimits(req.Limits)
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// TestBudgetRuleHandler_RoundTrip seeds a budget rule via the store and asserts
|
||||
// the GET wire shape carries targets and the reused PolicyLimits cap shape. The
|
||||
// create/update/delete success paths go through accountManager.StoreEvent which
|
||||
// this fixture doesn't wire — they are covered by the manager-level no-mock
|
||||
// test (TestAgentNetwork_BudgetRuleCRUD_RealManager).
|
||||
func TestBudgetRuleHandler_RoundTrip(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
rule := &agentNetworkTypes.AccountBudgetRule{
|
||||
ID: "ainbud_test",
|
||||
AccountID: testAccountID,
|
||||
Name: "org-monthly",
|
||||
Enabled: true,
|
||||
TargetGroups: []string{"grp-eng"},
|
||||
TargetUsers: []string{"user-alice"},
|
||||
Limits: agentNetworkTypes.PolicyLimits{
|
||||
TokenLimit: agentNetworkTypes.PolicyTokenLimit{Enabled: true, GroupCap: 100000, UserCap: 10000, WindowSeconds: 2_592_000},
|
||||
BudgetLimit: agentNetworkTypes.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 500, WindowSeconds: 2_592_000},
|
||||
},
|
||||
}
|
||||
require.NoError(t, f.store.SaveAgentNetworkBudgetRule(context.Background(), rule))
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/budget-rules/"+rule.ID, "")
|
||||
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
|
||||
|
||||
var got api.AgentNetworkBudgetRule
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
assert.Equal(t, "org-monthly", got.Name, "name must round-trip")
|
||||
assert.Equal(t, []string{"grp-eng"}, got.TargetGroups, "target groups must round-trip")
|
||||
assert.Equal(t, []string{"user-alice"}, got.TargetUsers, "target users must round-trip")
|
||||
assert.Equal(t, int64(100000), got.Limits.TokenLimit.GroupCap, "token group cap must round-trip")
|
||||
assert.Equal(t, int64(2_592_000), got.Limits.BudgetLimit.WindowSeconds, "budget window must round-trip")
|
||||
}
|
||||
|
||||
// TestBudgetRuleHandler_ListReturnsArray asserts the list endpoint returns a
|
||||
// JSON array (never null) for an account with no rules.
|
||||
func TestBudgetRuleHandler_ListReturnsArray(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/budget-rules", "")
|
||||
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
|
||||
assert.Equal(t, "[]", trimSpace(rec.Body.String()), "empty account must return an empty array, not null")
|
||||
}
|
||||
|
||||
// TestBudgetRuleHandler_RejectsMissingName covers the validation path (which
|
||||
// runs before the manager call, so it works without a wired accountManager).
|
||||
func TestBudgetRuleHandler_RejectsMissingName(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
body := `{
|
||||
"name": "",
|
||||
"limits": {
|
||||
"token_limit": {"enabled": false, "group_cap": 0, "user_cap": 0, "window_seconds": 0},
|
||||
"budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0}
|
||||
}
|
||||
}`
|
||||
rec := f.do(t, http.MethodPost, "/agent-network/budget-rules", body)
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code,
|
||||
"missing name must be rejected as a validation error (not a route/auth 4xx): got %d body=%s", rec.Code, rec.Body.String())
|
||||
assert.Contains(t, rec.Body.String(), "name",
|
||||
"rejection body must name the offending field, proving the validation path: %s", rec.Body.String())
|
||||
}
|
||||
|
||||
// TestBudgetRuleHandler_RejectsSubMinuteWindow proves budget rules reuse the
|
||||
// policy-limit validation (enabled limit needs window >= 60s).
|
||||
func TestBudgetRuleHandler_RejectsSubMinuteWindow(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
body := `{
|
||||
"name": "bad-window",
|
||||
"limits": {
|
||||
"token_limit": {"enabled": true, "group_cap": 1000, "user_cap": 0, "window_seconds": 30},
|
||||
"budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0}
|
||||
}
|
||||
}`
|
||||
rec := f.do(t, http.MethodPost, "/agent-network/budget-rules", body)
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code,
|
||||
"sub-minute window must be rejected as a validation error (not a route/auth 4xx): got %d body=%s", rec.Code, rec.Body.String())
|
||||
assert.Contains(t, rec.Body.String(), "window_seconds",
|
||||
"rejection body must name the offending window_seconds field, proving the validation path: %s", rec.Body.String())
|
||||
}
|
||||
|
||||
// TestSettingsHandler_GetExposesCollectionToggles asserts the GET settings wire
|
||||
// shape carries the account-level collection toggles after a store seed.
|
||||
func TestSettingsHandler_GetExposesCollectionToggles(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
require.NoError(t, f.store.SaveAgentNetworkSettings(context.Background(), &agentNetworkTypes.Settings{
|
||||
AccountID: testAccountID,
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
Subdomain: "violet",
|
||||
EnableLogCollection: true,
|
||||
EnablePromptCollection: true,
|
||||
RedactPii: false,
|
||||
}))
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/settings", "")
|
||||
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
|
||||
|
||||
var got api.AgentNetworkSettings
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
assert.True(t, got.EnableLogCollection, "log collection toggle must surface on the wire")
|
||||
assert.True(t, got.EnablePromptCollection, "prompt collection toggle must surface on the wire")
|
||||
assert.False(t, got.RedactPii, "redact toggle must surface its false value")
|
||||
assert.Equal(t, "violet.eu.proxy.netbird.io", got.Endpoint, "endpoint stays computed from immutable cluster+subdomain")
|
||||
}
|
||||
|
||||
func trimSpace(s string) string {
|
||||
for len(s) > 0 && (s[len(s)-1] == '\n' || s[len(s)-1] == ' ' || s[len(s)-1] == '\t' || s[len(s)-1] == '\r') {
|
||||
s = s[:len(s)-1]
|
||||
}
|
||||
for len(s) > 0 && (s[0] == '\n' || s[0] == ' ' || s[0] == '\t' || s[0] == '\r') {
|
||||
s = s[1:]
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
// addConsumptionEndpoints registers the read-only Agent Network
|
||||
// consumption listing — backs the dashboard's basic counter view.
|
||||
func (h *handler) addConsumptionEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/consumption", h.listConsumption).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) listConsumption(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.manager.ListConsumption(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]api.AgentNetworkConsumption, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, consumptionToAPI(row))
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func consumptionToAPI(c *types.Consumption) api.AgentNetworkConsumption {
|
||||
windowStart := c.WindowStartUTC
|
||||
updatedAt := c.UpdatedAt
|
||||
return api.AgentNetworkConsumption{
|
||||
DimensionKind: api.AgentNetworkConsumptionDimensionKind(c.DimensionKind),
|
||||
DimensionId: c.DimensionID,
|
||||
WindowSeconds: c.WindowSeconds,
|
||||
WindowStartUtc: windowStart,
|
||||
TokensInput: c.TokensInput,
|
||||
TokensOutput: c.TokensOutput,
|
||||
CostUsd: c.CostUSD,
|
||||
UpdatedAt: &updatedAt,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// addGuardrailEndpoints registers all Agent Network guardrail routes.
|
||||
func (h *handler) addGuardrailEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/guardrails", h.getAllGuardrails).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/guardrails", h.createGuardrail).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.getGuardrail).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.updateGuardrail).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.deleteGuardrail).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllGuardrails(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrails, err := h.manager.GetAllGuardrails(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*api.AgentNetworkGuardrail, 0, len(guardrails))
|
||||
for _, g := range guardrails {
|
||||
out = append(out, g.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getGuardrail(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrailID := mux.Vars(r)["guardrailId"]
|
||||
if guardrailID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrail, err := h.manager.GetGuardrail(r.Context(), userAuth.AccountId, userAuth.UserId, guardrailID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, guardrail.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) createGuardrail(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkGuardrailRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateGuardrail(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrail := types.NewGuardrail(userAuth.AccountId)
|
||||
guardrail.FromAPIRequest(&req)
|
||||
|
||||
created, err := h.manager.CreateGuardrail(r.Context(), userAuth.UserId, guardrail)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateGuardrail(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrailID := mux.Vars(r)["guardrailId"]
|
||||
if guardrailID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkGuardrailRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateGuardrail(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrail := &types.Guardrail{
|
||||
ID: guardrailID,
|
||||
AccountID: userAuth.AccountId,
|
||||
}
|
||||
guardrail.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdateGuardrail(r.Context(), userAuth.UserId, guardrail)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteGuardrail(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrailID := mux.Vars(r)["guardrailId"]
|
||||
if guardrailID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteGuardrail(r.Context(), userAuth.AccountId, userAuth.UserId, guardrailID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func validateGuardrail(req *api.AgentNetworkGuardrailRequest) error {
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
|
||||
c := req.Checks
|
||||
if c.ModelAllowlist.Enabled {
|
||||
for _, id := range c.ModelAllowlist.Models {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "model_allowlist.models must not contain empty entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,256 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountID = "acc-1"
|
||||
testUserID = "user-bob"
|
||||
)
|
||||
|
||||
// agentNetworkHandlerFixture builds a real agentnetwork.Manager with
|
||||
// a sqlite store and an always-allow permissions mock, then exposes
|
||||
// the HTTP handlers via a gorilla router. Tests issue requests
|
||||
// through httptest and assert on the wire shape — the same path the
|
||||
// dashboard exercises.
|
||||
type agentNetworkHandlerFixture struct {
|
||||
store store.Store
|
||||
manager agentnetwork.Manager
|
||||
router *mux.Router
|
||||
}
|
||||
|
||||
func newAgentNetworkHandlerFixture(t *testing.T) *agentNetworkHandlerFixture {
|
||||
t.Helper()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("sqlite store not properly supported on Windows yet")
|
||||
}
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(nbtypes.SqliteStoreEngine))
|
||||
|
||||
st, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
perms := permissions.NewMockManager(ctrl)
|
||||
// Always-allow: the handler tests are about wire shape, not
|
||||
// authz. Authz is covered by the manager's own tests.
|
||||
perms.EXPECT().
|
||||
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(true, context.Background(), nil).
|
||||
AnyTimes()
|
||||
|
||||
manager := agentnetwork.NewManager(st, perms, nil, nil)
|
||||
h := &handler{manager: manager}
|
||||
|
||||
router := mux.NewRouter()
|
||||
h.addPolicyEndpoints(router)
|
||||
h.addConsumptionEndpoints(router)
|
||||
h.addBudgetRuleEndpoints(router)
|
||||
h.addSettingsEndpoints(router)
|
||||
|
||||
return &agentNetworkHandlerFixture{
|
||||
store: st,
|
||||
manager: manager,
|
||||
router: router,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *agentNetworkHandlerFixture) do(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
var reader io.Reader
|
||||
if body != "" {
|
||||
reader = strings.NewReader(body)
|
||||
}
|
||||
req := httptest.NewRequest(method, path, reader)
|
||||
if body != "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
f.router.ServeHTTP(rec, req)
|
||||
return rec
|
||||
}
|
||||
|
||||
// seedProvider persists a minimal provider record so policy create
|
||||
// passes the manager's destination_provider_ids existence check.
|
||||
func (f *agentNetworkHandlerFixture) seedProvider(t *testing.T, id string) {
|
||||
t.Helper()
|
||||
require.NoError(t, f.store.SaveAgentNetworkProvider(context.Background(), &agentNetworkTypes.Provider{
|
||||
ID: id,
|
||||
AccountID: testAccountID,
|
||||
ProviderID: "openai_api",
|
||||
Name: "test-" + id,
|
||||
UpstreamURL: "https://api.openai.com",
|
||||
APIKey: "sk-test",
|
||||
Enabled: true,
|
||||
SessionPrivateKey: "test-priv-key",
|
||||
SessionPublicKey: "test-pub-key",
|
||||
}))
|
||||
}
|
||||
|
||||
// TestPolicyHandler_WindowSecondsRoundTrip ports bash 10 to Go:
|
||||
// assert that a policy with window_seconds on both Token + Budget
|
||||
// halves round-trips through GET unchanged AND that legacy
|
||||
// window_hours / window_days are absent from the JSON response. We
|
||||
// seed the policy directly via the store rather than POST-ing
|
||||
// because the create path goes through the manager's
|
||||
// accountManager.StoreEvent which we don't wire in this fixture; the
|
||||
// on-wire shape is what matters here, and the POST validation path
|
||||
// is covered separately by the RejectsSubMinuteWindow test.
|
||||
func TestPolicyHandler_WindowSecondsRoundTrip(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
policy := &agentNetworkTypes.Policy{
|
||||
ID: "ainpol_test",
|
||||
AccountID: testAccountID,
|
||||
Name: "round-trip",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-engineers"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: agentNetworkTypes.PolicyLimits{
|
||||
TokenLimit: agentNetworkTypes.PolicyTokenLimit{Enabled: true, GroupCap: 10000, UserCap: 5000, WindowSeconds: 86_400},
|
||||
BudgetLimit: agentNetworkTypes.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 10.0, UserCapUsd: 2.5, WindowSeconds: 2_592_000},
|
||||
},
|
||||
}
|
||||
require.NoError(t, f.store.SaveAgentNetworkPolicy(context.Background(), policy))
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/policies/"+policy.ID, "")
|
||||
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
|
||||
|
||||
var got api.AgentNetworkPolicy
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
assert.Equal(t, int64(86_400), got.Limits.TokenLimit.WindowSeconds, "token_limit.window_seconds must round-trip")
|
||||
assert.Equal(t, int64(2_592_000), got.Limits.BudgetLimit.WindowSeconds, "budget_limit.window_seconds must round-trip")
|
||||
|
||||
// Legacy field names must NOT appear in the response — would
|
||||
// signal that the management server is still emitting the old
|
||||
// shape and would fool a v1 dashboard into rendering days/hours.
|
||||
assert.NotContains(t, rec.Body.String(), "window_hours",
|
||||
"legacy window_hours field must be absent from the on-wire response")
|
||||
assert.NotContains(t, rec.Body.String(), "window_days",
|
||||
"legacy window_days field must be absent from the on-wire response")
|
||||
}
|
||||
|
||||
// TestPolicyHandler_RejectsSubMinuteWindow ports bash 20 to Go: an
|
||||
// enabled limit with window_seconds < 60 must surface as a 4xx
|
||||
// because anything finer than per-minute produces an untenable
|
||||
// volume of consumption rows for a feature whose value comes from
|
||||
// per-window cap enforcement.
|
||||
func TestPolicyHandler_RejectsSubMinuteWindow(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
f.seedProvider(t, "prov-1")
|
||||
|
||||
body := `{
|
||||
"name": "sub-minute-window",
|
||||
"enabled": true,
|
||||
"source_groups": ["grp-engineers"],
|
||||
"destination_provider_ids": ["prov-1"],
|
||||
"guardrail_ids": [],
|
||||
"limits": {
|
||||
"token_limit": {"enabled": true, "group_cap": 10000, "user_cap": 5000, "window_seconds": 30},
|
||||
"budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0}
|
||||
}
|
||||
}`
|
||||
rec := f.do(t, http.MethodPost, "/agent-network/policies", body)
|
||||
// 422 specifically (InvalidArgument) proves the window-validation path —
|
||||
// a route miss would be 404 and an auth failure 403, so a generic 4xx
|
||||
// would let those false-pass.
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code,
|
||||
"enabled token_limit with window_seconds<60 must be rejected as a validation error: got %d body=%s", rec.Code, rec.Body.String())
|
||||
assert.Contains(t, rec.Body.String(), "window_seconds",
|
||||
"rejection body must name the offending window_seconds field, proving it's the validation path: %s", rec.Body.String())
|
||||
}
|
||||
|
||||
// TestConsumptionHandler_EmptyAccountReturnsArray ports bash 30 to
|
||||
// Go: GET /agent-network/consumption on a clean account always
|
||||
// returns a JSON array (possibly empty), never a 404 / 500. The
|
||||
// dashboard depends on this shape to render its empty state.
|
||||
func TestConsumptionHandler_EmptyAccountReturnsArray(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/consumption", "")
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var rows []api.AgentNetworkConsumption
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rows),
|
||||
"response must always be a JSON array — even when empty: %s", rec.Body.String())
|
||||
assert.Empty(t, rows)
|
||||
}
|
||||
|
||||
// TestConsumptionHandler_PopulatedAccountListsRows mirrors the
|
||||
// /consumption read after a few RecordConsumption calls. Validates
|
||||
// the wire shape carries every field the dashboard reads (dim_kind,
|
||||
// dim_id, window_seconds, window_start_utc, tokens, cost_usd) and
|
||||
// rows are ordered window-newest-first.
|
||||
func TestConsumptionHandler_PopulatedAccountListsRows(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
require.NoError(t, f.manager.RecordConsumption(
|
||||
context.Background(), testAccountID,
|
||||
agentNetworkTypes.DimensionGroup, "grp-engineers",
|
||||
86_400, 100, 50, 0.0125,
|
||||
))
|
||||
require.NoError(t, f.manager.RecordConsumption(
|
||||
context.Background(), testAccountID,
|
||||
agentNetworkTypes.DimensionUser, testUserID,
|
||||
86_400, 100, 50, 0.0125,
|
||||
))
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/consumption", "")
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var rows []api.AgentNetworkConsumption
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rows))
|
||||
require.Len(t, rows, 2, "two RecordConsumption calls must yield two rows")
|
||||
|
||||
// Index by dim_kind so we can assert the full wire shape of each row,
|
||||
// including the dimension id and the aligned window start the dashboard
|
||||
// keys on. Both rows share totals and window.
|
||||
byKind := make(map[string]api.AgentNetworkConsumption, len(rows))
|
||||
for _, row := range rows {
|
||||
assert.Equal(t, int64(100), row.TokensInput)
|
||||
assert.Equal(t, int64(50), row.TokensOutput)
|
||||
assert.InDelta(t, 0.0125, row.CostUsd, 1e-9)
|
||||
assert.Equal(t, int64(86_400), row.WindowSeconds)
|
||||
assert.False(t, row.WindowStartUtc.IsZero(), "window_start_utc must be set on every row")
|
||||
byKind[string(row.DimensionKind)] = row
|
||||
}
|
||||
|
||||
groupRow, ok := byKind["group"]
|
||||
require.True(t, ok, "group dimension must surface")
|
||||
assert.Equal(t, "grp-engineers", groupRow.DimensionId, "group row must carry the source group id as dimension_id")
|
||||
|
||||
userRow, ok := byKind["user"]
|
||||
require.True(t, ok, "user dimension must surface")
|
||||
assert.Equal(t, testUserID, userRow.DimensionId, "user row must carry the user id as dimension_id")
|
||||
|
||||
// Both rows fall in the same aligned window (same length, recorded
|
||||
// together), so window_start_utc must match across them.
|
||||
assert.Equal(t, groupRow.WindowStartUtc, userRow.WindowStartUtc,
|
||||
"rows recorded in the same window must share the aligned window_start_utc")
|
||||
}
|
||||
@@ -0,0 +1,228 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// minWindowSeconds is the floor enforced on enabled token / budget
|
||||
// limit windows. One minute is short enough for fine-grained burst
|
||||
// control without producing untenable consumption-row volume at scale.
|
||||
const minWindowSeconds int64 = 60
|
||||
|
||||
// addPolicyEndpoints registers all Agent Network policy routes on the
|
||||
// shared handler.
|
||||
func (h *handler) addPolicyEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/policies", h.getAllPolicies).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/policies", h.createPolicy).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/policies/{policyId}", h.getPolicy).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/policies/{policyId}", h.updatePolicy).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/policies/{policyId}", h.deletePolicy).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policies, err := h.manager.GetAllPolicies(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*api.AgentNetworkPolicy, 0, len(policies))
|
||||
for _, p := range policies {
|
||||
out = append(out, p.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policyID := mux.Vars(r)["policyId"]
|
||||
if policyID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
policy, err := h.manager.GetPolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policyID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, policy.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkPolicyRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validatePolicy(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policy := types.NewPolicy(userAuth.AccountId)
|
||||
policy.FromAPIRequest(&req)
|
||||
|
||||
created, err := h.manager.CreatePolicy(r.Context(), userAuth.UserId, policy)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policyID := mux.Vars(r)["policyId"]
|
||||
if policyID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkPolicyRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validatePolicy(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policy := &types.Policy{
|
||||
ID: policyID,
|
||||
AccountID: userAuth.AccountId,
|
||||
}
|
||||
policy.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdatePolicy(r.Context(), userAuth.UserId, policy)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policyID := mux.Vars(r)["policyId"]
|
||||
if policyID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeletePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policyID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func validatePolicy(req *api.AgentNetworkPolicyRequest) error {
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
if len(req.SourceGroups) == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "source_groups must contain at least one group id")
|
||||
}
|
||||
for _, id := range req.SourceGroups {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "source_groups must not contain empty entries")
|
||||
}
|
||||
}
|
||||
if len(req.DestinationProviderIds) == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "destination_provider_ids must contain at least one provider id")
|
||||
}
|
||||
for _, id := range req.DestinationProviderIds {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "destination_provider_ids must not contain empty entries")
|
||||
}
|
||||
}
|
||||
if req.GuardrailIds != nil {
|
||||
for _, id := range *req.GuardrailIds {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "guardrail_ids must not contain empty entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.Limits != nil {
|
||||
if err := validatePolicyLimits(*req.Limits); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePolicyLimits(l api.AgentNetworkPolicyLimits) error {
|
||||
if l.TokenLimit.Enabled {
|
||||
if l.TokenLimit.WindowSeconds < minWindowSeconds {
|
||||
return status.Errorf(status.InvalidArgument, "limits.token_limit.window_seconds must be at least %d (one minute) when enabled", minWindowSeconds)
|
||||
}
|
||||
if l.TokenLimit.GroupCap < 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.token_limit.group_cap must not be negative")
|
||||
}
|
||||
if l.TokenLimit.UserCap < 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.token_limit.user_cap must not be negative")
|
||||
}
|
||||
if l.TokenLimit.GroupCap == 0 && l.TokenLimit.UserCap == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.token_limit requires group_cap or user_cap to be greater than zero when enabled")
|
||||
}
|
||||
}
|
||||
if l.BudgetLimit.Enabled {
|
||||
if l.BudgetLimit.WindowSeconds < minWindowSeconds {
|
||||
return status.Errorf(status.InvalidArgument, "limits.budget_limit.window_seconds must be at least %d (one minute) when enabled", minWindowSeconds)
|
||||
}
|
||||
if l.BudgetLimit.GroupCapUsd < 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.budget_limit.group_cap_usd must not be negative")
|
||||
}
|
||||
if l.BudgetLimit.UserCapUsd < 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.budget_limit.user_cap_usd must not be negative")
|
||||
}
|
||||
if l.BudgetLimit.GroupCapUsd == 0 && l.BudgetLimit.UserCapUsd == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.budget_limit requires group_cap_usd or user_cap_usd to be greater than zero when enabled")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
// Package handlers serves the Agent Network HTTP API.
|
||||
//
|
||||
// All persistence is delegated to agentnetwork.Manager so this layer only
|
||||
// translates between the wire format (api.AgentNetworkProvider*) and the
|
||||
// domain types.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/catalog"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager agentnetwork.Manager
|
||||
}
|
||||
|
||||
// RegisterEndpoints registers all Agent Network routes.
|
||||
func RegisterEndpoints(manager agentnetwork.Manager, router *mux.Router) {
|
||||
h := &handler{manager: manager}
|
||||
router.HandleFunc("/agent-network/catalog/providers", h.getCatalogProviders).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers", h.getAllProviders).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers", h.createProvider).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers/{providerId}", h.getProvider).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers/{providerId}", h.updateProvider).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers/{providerId}", h.deleteProvider).Methods("DELETE", "OPTIONS")
|
||||
h.addPolicyEndpoints(router)
|
||||
h.addGuardrailEndpoints(router)
|
||||
h.addSettingsEndpoints(router)
|
||||
h.addConsumptionEndpoints(router)
|
||||
h.addAccessLogEndpoints(router)
|
||||
h.addBudgetRuleEndpoints(router)
|
||||
}
|
||||
|
||||
func (h *handler) getCatalogProviders(w http.ResponseWriter, r *http.Request) {
|
||||
if _, err := nbcontext.GetUserAuthFromContext(r.Context()); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
entries := catalog.All()
|
||||
out := make([]api.AgentNetworkCatalogProvider, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
out = append(out, e.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getAllProviders(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.manager.GetAllProviders(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*api.AgentNetworkProvider, 0, len(providers))
|
||||
for _, p := range providers {
|
||||
out = append(out, p.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
providerID := mux.Vars(r)["providerId"]
|
||||
if providerID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.manager.GetProvider(r.Context(), userAuth.AccountId, userAuth.UserId, providerID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, provider.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) createProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkProviderRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validate(&req, true); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
provider := types.NewProvider(userAuth.AccountId)
|
||||
provider.FromAPIRequest(&req)
|
||||
|
||||
bootstrapCluster := ""
|
||||
if req.BootstrapCluster != nil {
|
||||
bootstrapCluster = *req.BootstrapCluster
|
||||
}
|
||||
|
||||
created, err := h.manager.CreateProvider(r.Context(), userAuth.UserId, provider, bootstrapCluster)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
providerID := mux.Vars(r)["providerId"]
|
||||
if providerID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkProviderRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validate(&req, false); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
provider := &types.Provider{
|
||||
ID: providerID,
|
||||
AccountID: userAuth.AccountId,
|
||||
}
|
||||
provider.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdateProvider(r.Context(), userAuth.UserId, provider)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
providerID := mux.Vars(r)["providerId"]
|
||||
if providerID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteProvider(r.Context(), userAuth.AccountId, userAuth.UserId, providerID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func validate(req *api.AgentNetworkProviderRequest, requireAPIKey bool) error {
|
||||
if strings.TrimSpace(req.ProviderId) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "provider_id is required")
|
||||
}
|
||||
if !catalog.IsKnown(req.ProviderId) {
|
||||
return status.Errorf(status.InvalidArgument, "provider_id %q is not a known catalog provider", req.ProviderId)
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
if strings.TrimSpace(req.UpstreamUrl) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "upstream_url is required")
|
||||
}
|
||||
u, err := url.Parse(strings.TrimSpace(req.UpstreamUrl))
|
||||
if err != nil || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
|
||||
return status.Errorf(status.InvalidArgument, "upstream_url must be a full http(s) URL")
|
||||
}
|
||||
if requireAPIKey && (req.ApiKey == nil || strings.TrimSpace(*req.ApiKey) == "") {
|
||||
return status.Errorf(status.InvalidArgument, "api_key is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// addSettingsEndpoints registers the Agent Network settings routes. The
|
||||
// settings row is bootstrapped server-side on first provider create; GET reads
|
||||
// it and PUT updates the mutable collection toggles (cluster/subdomain stay
|
||||
// immutable).
|
||||
func (h *handler) addSettingsEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/settings", h.getSettings).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/settings", h.updateSettings).Methods("PUT", "OPTIONS")
|
||||
}
|
||||
|
||||
// updateSettings applies the collection toggles to the account's settings row.
|
||||
func (h *handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkSettingsRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings := &types.Settings{AccountID: userAuth.AccountId}
|
||||
settings.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdateSettings(r.Context(), userAuth.UserId, settings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
// getSettings returns the account's agent-network settings. The settings
|
||||
// row is bootstrapped on first provider create, so freshly-onboarded
|
||||
// accounts have nothing to read. Rather than 404-ing in that case (which
|
||||
// the dashboard would have to special-case), return a JSON null with 200
|
||||
// so consumers can branch on the body alone.
|
||||
func (h *handler) getSettings(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.manager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
var sErr *status.Error
|
||||
if errors.As(err, &sErr) && sErr.Type() == status.NotFound {
|
||||
util.WriteJSONObject(r.Context(), w, nil)
|
||||
return
|
||||
}
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, settings.ToAPIResponse())
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
// Package labelgen produces DNS-safe Agent Network subdomain labels.
|
||||
package labelgen
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// pickAttempts caps the random retries before falling back to the
|
||||
// suffixed form. Eight is a soft compromise: with a near-empty taken
|
||||
// set the very first pick almost always succeeds; when the wordlist is
|
||||
// densely populated the fallback eventually fires anyway.
|
||||
const pickAttempts = 8
|
||||
|
||||
var (
|
||||
dedupOnce sync.Once
|
||||
uniqWords []string
|
||||
)
|
||||
|
||||
// uniqueWords returns the wordlist deduplicated and sorted for
|
||||
// deterministic exhaustion behaviour. Lazy-built once per process.
|
||||
func uniqueWords() []string {
|
||||
dedupOnce.Do(func() {
|
||||
seen := make(map[string]struct{}, len(words))
|
||||
uniqWords = make([]string, 0, len(words))
|
||||
for _, w := range words {
|
||||
if _, ok := seen[w]; ok {
|
||||
continue
|
||||
}
|
||||
seen[w] = struct{}{}
|
||||
uniqWords = append(uniqWords, w)
|
||||
}
|
||||
sort.Strings(uniqWords)
|
||||
})
|
||||
return uniqWords
|
||||
}
|
||||
|
||||
// PickUnique selects a label not already in `taken`. It tries up to
|
||||
// pickAttempts random picks; on exhaustion it scans the deduplicated
|
||||
// wordlist for any remaining free entry, and if none is left appends
|
||||
// `-<fallbackSuffix>` to a deterministic word and returns. The caller
|
||||
// is responsible for seeding rng (math/rand).
|
||||
func PickUnique(rng *rand.Rand, taken map[string]struct{}, fallbackSuffix string) string {
|
||||
pool := uniqueWords()
|
||||
if len(pool) == 0 {
|
||||
return fallbackSuffix
|
||||
}
|
||||
|
||||
for i := 0; i < pickAttempts; i++ {
|
||||
w := pool[rng.Intn(len(pool))]
|
||||
if _, ok := taken[w]; !ok {
|
||||
return w
|
||||
}
|
||||
}
|
||||
|
||||
for _, w := range pool {
|
||||
if _, ok := taken[w]; !ok {
|
||||
return w
|
||||
}
|
||||
}
|
||||
|
||||
w := pool[rng.Intn(len(pool))]
|
||||
return fmt.Sprintf("%s-%s", w, fallbackSuffix)
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package labelgen
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestPickUnique_DeterministicWithSeededRng locks the property the
|
||||
// caller relies on: same seed + same taken set → same pick. Without
|
||||
// that, the bootstrap flow can't reproduce a label across retries.
|
||||
func TestPickUnique_DeterministicWithSeededRng(t *testing.T) {
|
||||
taken := map[string]struct{}{}
|
||||
|
||||
rngA := rand.New(rand.NewSource(42))
|
||||
rngB := rand.New(rand.NewSource(42))
|
||||
|
||||
a := PickUnique(rngA, taken, "abcd")
|
||||
b := PickUnique(rngB, taken, "abcd")
|
||||
|
||||
assert.Equal(t, a, b, "Same seed and taken set must produce identical pick")
|
||||
}
|
||||
|
||||
// TestPickUnique_AvoidsTakenWordsWhenMostAreReserved seeds taken with
|
||||
// every word in the pool except a handful and confirms PickUnique
|
||||
// finds one of the remaining free entries instead of returning the
|
||||
// fallback form.
|
||||
func TestPickUnique_AvoidsTakenWordsWhenMostAreReserved(t *testing.T) {
|
||||
pool := uniqueWords()
|
||||
require.NotEmpty(t, pool, "wordlist must be populated for the test to mean anything")
|
||||
|
||||
free := map[string]struct{}{
|
||||
pool[0]: {},
|
||||
pool[len(pool)/2]: {},
|
||||
pool[len(pool)-1]: {},
|
||||
}
|
||||
|
||||
taken := make(map[string]struct{}, len(pool))
|
||||
for _, w := range pool {
|
||||
if _, ok := free[w]; ok {
|
||||
continue
|
||||
}
|
||||
taken[w] = struct{}{}
|
||||
}
|
||||
|
||||
rng := rand.New(rand.NewSource(7))
|
||||
got := PickUnique(rng, taken, "abcd")
|
||||
|
||||
_, isFree := free[got]
|
||||
assert.True(t, isFree, "PickUnique must return one of the free words; got %q", got)
|
||||
assert.NotContains(t, got, "-", "Free pick must not be the suffix fallback form")
|
||||
}
|
||||
|
||||
// TestPickUnique_FallsBackWhenAllReserved exhausts the pool and
|
||||
// confirms PickUnique appends the supplied suffix instead of
|
||||
// returning a duplicate.
|
||||
func TestPickUnique_FallsBackWhenAllReserved(t *testing.T) {
|
||||
pool := uniqueWords()
|
||||
|
||||
taken := make(map[string]struct{}, len(pool))
|
||||
for _, w := range pool {
|
||||
taken[w] = struct{}{}
|
||||
}
|
||||
|
||||
rng := rand.New(rand.NewSource(99))
|
||||
got := PickUnique(rng, taken, "abcd")
|
||||
|
||||
assert.True(t, strings.HasSuffix(got, "-abcd"), "Exhausted pool must produce <word>-<suffix>; got %q", got)
|
||||
|
||||
prefix := strings.TrimSuffix(got, "-abcd")
|
||||
found := false
|
||||
for _, w := range pool {
|
||||
if w == prefix {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Fallback prefix must be drawn from the wordlist; got %q", prefix)
|
||||
}
|
||||
|
||||
// TestUniqueWords_DropsDuplicates guards against authoring slips in
|
||||
// words.go: every entry must be unique and DNS-safe.
|
||||
func TestUniqueWords_DropsDuplicates(t *testing.T) {
|
||||
pool := uniqueWords()
|
||||
seen := make(map[string]struct{}, len(pool))
|
||||
for _, w := range pool {
|
||||
_, dup := seen[w]
|
||||
assert.False(t, dup, "Duplicate entry %q in deduplicated pool", w)
|
||||
seen[w] = struct{}{}
|
||||
assert.GreaterOrEqual(t, len(w), 4, "Word %q is shorter than 4 chars", w)
|
||||
assert.LessOrEqual(t, len(w), 12, "Word %q is longer than 12 chars", w)
|
||||
for _, r := range w {
|
||||
ok := r >= 'a' && r <= 'z'
|
||||
assert.True(t, ok, "Word %q contains non-lowercase-ASCII rune %q", w, r)
|
||||
}
|
||||
}
|
||||
assert.GreaterOrEqual(t, len(pool), 500, "Pool must contain at least 500 unique words")
|
||||
}
|
||||
136
management/internals/modules/agentnetwork/labelgen/words.go
Normal file
136
management/internals/modules/agentnetwork/labelgen/words.go
Normal file
@@ -0,0 +1,136 @@
|
||||
// Package labelgen produces DNS-safe Agent Network subdomain labels.
|
||||
//
|
||||
// The wordlist below is a curated subset drawn from public-domain
|
||||
// nature / common-noun pools (e.g. EFF's diceware lists). Every entry
|
||||
// is lowercase ASCII, 4–12 chars, no hyphens, no digits, and was
|
||||
// hand-checked to avoid offensive, brand, or region-specific terms.
|
||||
package labelgen
|
||||
|
||||
// words is the pool PickUnique selects from. The slice is intentionally
|
||||
// not sorted — random picks distribute across the list naturally.
|
||||
var words = []string{
|
||||
"acorn", "adobe", "agate", "alder", "almond", "alpine", "amber", "amethyst",
|
||||
"anchor", "antler", "apple", "apricot", "arcade", "arctic", "arrow", "ashen",
|
||||
"aspen", "atlas", "atom", "aurora", "autumn", "azure",
|
||||
"badger", "bamboo", "banana", "banjo", "barley", "barn", "basalt", "basil",
|
||||
"basin", "bayou", "beach", "beacon", "beaver", "beech", "beetle", "berry",
|
||||
"birch", "bison", "blossom", "blue", "bobcat", "bonsai", "boulder", "branch",
|
||||
"brass", "breeze", "bridge", "bright", "brook", "broom", "brown", "buffalo",
|
||||
"bumble", "burrow", "butter", "button",
|
||||
"cabin", "cactus", "calm", "camel", "campfire", "canary", "candle", "canoe",
|
||||
"canyon", "cardinal", "carrot", "cascade", "castle", "cedar", "celery", "cello",
|
||||
"cement", "cherry", "chestnut", "chime", "cinnamon", "cinder", "citron", "clay",
|
||||
"clear", "cliff", "clock", "cloud", "clover", "coast", "cobalt", "cobble",
|
||||
"cocoa", "coffee", "comet", "compass", "copper", "coral", "corner", "cosmos",
|
||||
"cotton", "cougar", "country", "coyote", "cove", "crane", "crater", "creek",
|
||||
"crescent", "crimson", "crocus", "crystal", "cypress",
|
||||
"daffodil", "dahlia", "daisy", "dawn", "deer", "delta", "denim", "desert",
|
||||
"dewdrop", "diamond", "dolphin", "doodle", "dove", "dragon", "drift", "drop",
|
||||
"dune", "dusk", "dusty",
|
||||
"eagle", "earth", "echo", "elder", "elkhorn", "ember", "emerald", "emperor",
|
||||
"evergreen", "evening",
|
||||
"falcon", "fawn", "feather", "fern", "fiddle", "field", "fiesta", "finch",
|
||||
"firepit", "firefly", "fjord", "flame", "flax", "fleece", "flint", "floral",
|
||||
"flower", "flute", "foal", "foggy", "forest", "fountain", "foxglove", "fresh",
|
||||
"frost", "fuchsia", "fudge",
|
||||
"gable", "galaxy", "garden", "garnet", "gazelle", "geode", "geyser", "ginger",
|
||||
"glacier", "glade", "glass", "glow", "gold", "goose", "gorge", "gourd",
|
||||
"granite", "grape", "grass", "gravel", "grayling", "greenery", "grizzly", "grove",
|
||||
"gull", "gumdrop", "gust",
|
||||
"hammock", "harbor", "harvest", "hawk", "hazel", "heather", "hedge", "heron",
|
||||
"hibiscus", "hickory", "hideaway", "highland", "hill", "hive", "hollow", "honey",
|
||||
"hopper", "horizon", "hummingbird", "husky",
|
||||
"iceberg", "indigo", "iris", "island", "ivory", "ivybush",
|
||||
"jade", "jasmine", "jasper", "jaybird", "jelly", "jewel", "jonquil", "journey",
|
||||
"juniper", "jupiter", "jute",
|
||||
"kale", "kangaroo", "kayak", "kelp", "kestrel", "kettle", "khaki", "kindling",
|
||||
"kingfisher", "kiwi", "knapweed", "koala",
|
||||
"lagoon", "lake", "lantern", "larch", "lark", "laurel", "lava", "lavender",
|
||||
"leaf", "lemon", "lichen", "light", "lilac", "lily", "lime", "limestone",
|
||||
"linden", "linen", "lion", "lobster", "locust", "loon", "lotus", "lumber",
|
||||
"lunar", "lupine", "lynx",
|
||||
"madrone", "magenta", "magnolia", "mahogany", "mallow", "mango", "manor", "maple",
|
||||
"marble", "marigold", "marina", "marlin", "marsh", "mauve", "meadow", "melody",
|
||||
"melon", "merlin", "metal", "midnight", "milk", "millet", "mineral", "mint",
|
||||
"mirror", "mist", "mitten", "molasses", "moon", "moose", "morning", "moss",
|
||||
"mountain", "mulberry", "muscat", "mustard",
|
||||
"narwhal", "navy", "nectar", "needle", "nest", "nettle", "newt", "nightfall",
|
||||
"noon", "nook", "north", "nova", "nutmeg",
|
||||
"oaken", "oasis", "oatmeal", "ocean", "ochre", "octagon", "olive", "onyx",
|
||||
"opal", "orange", "orbit", "orchard", "orchid", "oregano", "orion", "osprey",
|
||||
"otter", "outpost", "owlet", "oyster",
|
||||
"painter", "palace", "palm", "pansy", "panther", "papaya", "paprika", "parsley",
|
||||
"partridge", "passage", "pastel", "patio", "peach", "peacock", "pear", "pearl",
|
||||
"pebble", "pecan", "pelican", "penguin", "peony", "pepper", "perch", "peridot",
|
||||
"pewter", "phoenix", "pier", "pillar", "pine", "pineapple", "pinto", "piper",
|
||||
"pistachio", "plain", "planet", "plateau", "platinum", "plum", "plume", "polar",
|
||||
"pollen", "pond", "poplar", "poppy", "porcelain", "portal", "portrait", "potato",
|
||||
"prairie", "primrose", "prism", "puffin", "pumpkin",
|
||||
"quail", "quartz", "quaver", "quill", "quince", "quinoa",
|
||||
"rabbit", "raccoon", "radish", "rain", "rainbow", "raindrop", "rapids", "raspberry",
|
||||
"raven", "ravine", "redwood", "reed", "reef", "ridge", "river", "robin",
|
||||
"rocket", "rubyred", "rose", "rosemary", "rosewood", "ruffle", "rugby", "russet",
|
||||
"rustic", "ryefield",
|
||||
"saffron", "sage", "salmon", "sand", "sandstone", "sapphire", "savanna", "scarlet",
|
||||
"scout", "seal", "season", "seaweed", "sequoia", "shadow", "shamrock", "shell",
|
||||
"sherbet", "shore", "silver", "siskin", "skybloom", "skyline", "sleet", "smoke",
|
||||
"snail", "snapdragon", "snow", "snowflake", "snowy", "solar", "song", "sonic",
|
||||
"sorrel", "south", "sparkle", "sparrow", "spice", "spider", "spinach", "spire",
|
||||
"spring", "sprout", "spruce", "squirrel", "starfish", "starlight", "stoat", "stone",
|
||||
"stork", "storm", "stream", "studio", "summer", "sunbeam", "sundew", "sunny",
|
||||
"sunrise", "sunset", "swallow", "swan", "sweet", "sycamore",
|
||||
"tangelo", "tangerine", "tansy", "taupe", "teak", "teal", "thicket", "thistle",
|
||||
"thrush", "thunder", "tide", "tiger", "tinder", "topaz", "torch", "tortoise",
|
||||
"tower", "trail", "tranquil", "tundra", "tulip", "turquoise", "turtle", "twig",
|
||||
"twilight",
|
||||
"umber", "uplands",
|
||||
"valley", "vanilla", "velvet", "venus", "verdant", "verdigris", "vermilion", "violet",
|
||||
"vista", "vivid", "volcano", "vortex",
|
||||
"walnut", "warbler", "watercress", "waterfall", "wave", "waxwing", "weasel", "westwind",
|
||||
"whale", "whisker", "whisper", "wicker", "wildwood", "willow", "winter", "wisp",
|
||||
"wisteria", "wolf", "wombat", "woodland", "woolly", "wren", "wreath",
|
||||
"yarrow", "yellow", "yewtree", "yodel",
|
||||
"zebra", "zenith", "zephyr", "zinnia",
|
||||
"alabaster", "alfalfa", "almanac", "anise", "antelope", "arbor", "arena", "armadillo",
|
||||
"avocet", "azalea", "balsam", "bayou", "beacon", "blizzard", "bluebell", "bluebird",
|
||||
"bluejay", "bobolink", "borage", "boreal", "buckeye", "buckthorn", "buttercup",
|
||||
"cabana", "calico", "canopy", "caraway", "cardamom", "cattail", "celadon", "centaur",
|
||||
"chambray", "chamois", "champlain", "chestnuts", "chickadee", "chinook", "chipmunk", "cinnabar",
|
||||
"cirrus", "citrine", "clematis", "copperhead",
|
||||
"crocodile", "currant", "cuttlebone", "daffy", "dapple", "delphinium", "dervish", "diamondback",
|
||||
"dogwood", "dolphins", "dragonfly", "driftwood", "dusk", "dustpan", "ebony", "edelweiss",
|
||||
"emperor", "endive", "estuary", "everglade", "fairway", "feldspar", "fennel", "fieldstone",
|
||||
"firebrand", "firefly", "fireweed", "firework", "flagstone", "fossil", "frostbite", "galleon",
|
||||
"gardener", "geranium", "gingko", "ginseng", "goldfish", "goldfinch", "goldenrod", "graphite",
|
||||
"greenfinch", "guppy", "haiku", "halibut", "hammerhead", "harbinger", "harvest", "hatchling",
|
||||
"havana", "hawthorn", "hazelnut", "heartwood", "henna", "heron", "highrise", "homestead",
|
||||
"honeycomb", "honeydew", "horseshoe", "hyacinth", "iceland", "icicle", "indigobird", "ironwood",
|
||||
"jacaranda", "jamboree", "javelina", "jellyfish", "junebug", "kaleido", "kayaker", "kerchief",
|
||||
"keystone", "kingdom", "labrador", "lacewing", "ladybug", "lakeside", "lamplight", "leopard",
|
||||
"lighthouse", "lilypad", "lullaby", "magnet", "mahonia", "mandolin", "manzanita", "maraschino",
|
||||
"mariner", "marsupial", "mastodon", "matterhorn", "mayflower", "mayfly", "meadowlark", "merlot",
|
||||
"meteor", "midshipman", "millpond", "mimosa", "minnow", "mockingbird", "molten", "monarch",
|
||||
"monsoon", "moondust", "moonlight", "moorland", "morning", "mossland", "mountain", "mulch",
|
||||
"narcissus", "nautilus", "nettlebush", "northstar", "nuthatch", "obsidian", "okra", "olivine",
|
||||
"opalescent", "orchidea", "orchard", "ornament", "outrigger", "oxalis", "paddler", "paintbrush",
|
||||
"papyrus", "paradise", "pasture", "patchwork", "pathway", "peridot", "periwinkle", "petalbloom",
|
||||
"petrel", "petunia", "phlox", "pikeperch", "pinecone", "pioneer", "pipevine", "platypus",
|
||||
"pomelo", "pondweed", "porpoise", "powder", "promise", "puddle", "pumice", "puzzle",
|
||||
"quetzal", "quicksilver", "raccoon", "ragwort", "rainforest", "ramble", "rapid", "rascal",
|
||||
"raspberry", "redbud", "redfern", "redpoll", "reedling", "ringtail", "riverbed", "riverbird",
|
||||
"riverstone", "rockcress", "roebuck", "rosebay", "rosehip", "rosemary", "rowan", "rumble",
|
||||
"runaway", "rustler", "sagebrush", "sailcloth", "salamander", "salsify", "samphire", "sandbar",
|
||||
"sanddollar", "sandpiper", "santolina", "sapodilla", "sassafras", "scallion", "schooner", "seafoam",
|
||||
"seafrost", "seagrass", "seahorse", "seaport", "seashell", "seaspray", "shamble", "shimmer",
|
||||
"shoreline", "silkmoth", "silverfox", "skylark", "snapdragon", "snowberry", "snowdrop", "snowfall",
|
||||
"snowmelt", "softwood", "songbird", "sorghum", "southwind", "speedwell", "spinnaker", "spruce",
|
||||
"starlight", "starling", "stormcloud", "summit", "sundance", "sundew", "sundial", "sunflower",
|
||||
"surface", "swallowtail", "sweetcorn", "sycamore", "tabletop", "tamarack", "tamarind", "tangerine",
|
||||
"tarragon", "telescope", "thicket", "thrasher", "thunder", "thyme", "tideline", "timberland",
|
||||
"tinderbox", "topiary", "torchwood", "totem", "tradewind", "treasure", "tremolo", "trinket",
|
||||
"trumpetvine", "tugboat", "tundra", "turnstone", "underbrush", "vagabond", "valerian", "vanilla",
|
||||
"velveteen", "vermilion", "vinca", "vineyard", "violet", "voyager", "wagonwheel", "walnutwood",
|
||||
"watermark", "watershed", "waterway", "wavefront", "westerly", "whaleback", "whetstone", "wicker",
|
||||
"wildbloom", "wildflower", "wilderness", "windsong", "windward", "winterberry", "woodbine", "woodfern",
|
||||
"woodland", "woodthrush", "woolgrass", "yellowfin", "zenithal", "zucchini",
|
||||
}
|
||||
896
management/internals/modules/agentnetwork/manager.go
Normal file
896
management/internals/modules/agentnetwork/manager.go
Normal file
@@ -0,0 +1,896 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/labelgen"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// ensureSessionKeys mints an ed25519 session keypair on the provider
|
||||
// when one is missing. Idempotent: skips when both fields are already
|
||||
// populated (e.g. update or migrated rows). The keys are used by the
|
||||
// synthesised reverse-proxy service to sign / verify session JWTs
|
||||
// after a successful OIDC handshake.
|
||||
func ensureSessionKeys(p *types.Provider) error {
|
||||
if p.SessionPrivateKey != "" && p.SessionPublicKey != "" {
|
||||
return nil
|
||||
}
|
||||
pair, err := sessionkey.GenerateKeyPair()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate provider session keys: %w", err)
|
||||
}
|
||||
p.SessionPrivateKey = pair.PrivateKey
|
||||
p.SessionPublicKey = pair.PublicKey
|
||||
return nil
|
||||
}
|
||||
|
||||
// Manager governs the lifecycle of Agent Network providers and policies.
|
||||
type Manager interface {
|
||||
GetAllProviders(ctx context.Context, accountID, userID string) ([]*types.Provider, error)
|
||||
GetProvider(ctx context.Context, accountID, userID, providerID string) (*types.Provider, error)
|
||||
CreateProvider(ctx context.Context, userID string, provider *types.Provider, bootstrapCluster string) (*types.Provider, error)
|
||||
UpdateProvider(ctx context.Context, userID string, provider *types.Provider) (*types.Provider, error)
|
||||
DeleteProvider(ctx context.Context, accountID, userID, providerID string) error
|
||||
|
||||
GetAllPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||
GetPolicy(ctx context.Context, accountID, userID, policyID string) (*types.Policy, error)
|
||||
CreatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error)
|
||||
UpdatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error)
|
||||
DeletePolicy(ctx context.Context, accountID, userID, policyID string) error
|
||||
|
||||
GetAllGuardrails(ctx context.Context, accountID, userID string) ([]*types.Guardrail, error)
|
||||
GetGuardrail(ctx context.Context, accountID, userID, guardrailID string) (*types.Guardrail, error)
|
||||
CreateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error)
|
||||
UpdateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error)
|
||||
DeleteGuardrail(ctx context.Context, accountID, userID, guardrailID string) error
|
||||
|
||||
GetAllBudgetRules(ctx context.Context, accountID, userID string) ([]*types.AccountBudgetRule, error)
|
||||
GetBudgetRule(ctx context.Context, accountID, userID, ruleID string) (*types.AccountBudgetRule, error)
|
||||
CreateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error)
|
||||
UpdateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error)
|
||||
DeleteBudgetRule(ctx context.Context, accountID, userID, ruleID string) error
|
||||
|
||||
GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error)
|
||||
UpdateSettings(ctx context.Context, userID string, settings *types.Settings) (*types.Settings, error)
|
||||
|
||||
ListConsumption(ctx context.Context, accountID, userID string) ([]*types.Consumption, error)
|
||||
ListAccessLogs(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error)
|
||||
GetUsageOverview(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter, granularity types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error)
|
||||
StartAccessLogCleanup(ctx context.Context, cleanupIntervalHours int)
|
||||
RecordConsumption(ctx context.Context, accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds, tokensIn, tokensOut int64, costUSD float64) error
|
||||
RecordAccountBudgetUsage(ctx context.Context, accountID, userID string, groupIDs []string, tokensIn, tokensOut int64, costUSD float64) error
|
||||
RecordUsage(ctx context.Context, in RecordUsageInput) error
|
||||
SelectPolicyForRequest(ctx context.Context, in PolicySelectionInput) (*PolicySelectionResult, error)
|
||||
}
|
||||
|
||||
// PolicySelectionInput is the per-request selection envelope. The
|
||||
// proxy populates it from CapturedData (account, user, groups) plus
|
||||
// the provider llm_router resolved.
|
||||
type PolicySelectionInput struct {
|
||||
AccountID string
|
||||
UserID string
|
||||
GroupIDs []string
|
||||
ProviderID string
|
||||
}
|
||||
|
||||
// PolicySelectionResult names the policy that "pays" for this request
|
||||
// plus the deny envelope when every applicable policy has exhausted
|
||||
// every cap. AttributionGroupID is the lowest group id (string sort)
|
||||
// of caller_groups ∩ selected_policy.source_groups; empty when no
|
||||
// group dimension applies. WindowSeconds is the chosen policy's
|
||||
// effective window length in seconds (token_limit's wins when both
|
||||
// halves are enabled with mismatched windows; budget_limit's
|
||||
// otherwise; 0 when no caps are configured at all).
|
||||
type PolicySelectionResult struct {
|
||||
Allow bool
|
||||
SelectedPolicyID string
|
||||
AttributionGroupID string
|
||||
WindowSeconds int64
|
||||
DenyCode string
|
||||
DenyReason string
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
proxyController proxy.Controller
|
||||
|
||||
// reconcileCache holds the last set of synthesised proxy mappings
|
||||
// per account so reconcile can emit precise Create/Update/Delete
|
||||
// updates instead of a full re-push on every mutation. Keyed by
|
||||
// accountID, then by synthesised service ID.
|
||||
reconcileMu sync.Mutex
|
||||
reconcileCache map[string]map[string]*proto.ProxyMapping
|
||||
|
||||
// labelRngMu guards labelRng. PickUnique consumes math/rand.Source
|
||||
// state; concurrent provider creates would otherwise race.
|
||||
labelRngMu sync.Mutex
|
||||
labelRng *rand.Rand
|
||||
}
|
||||
|
||||
// NewManager constructs the persistent Agent Network manager. The
|
||||
// manager persists provider/policy/guardrail configuration and, on
|
||||
// every mutation, reconciles the in-memory synthesised reverse-proxy
|
||||
// services with the proxy cluster via proxyController. Pass nil for
|
||||
// proxyController to disable the reconcile push (useful in tests).
|
||||
func NewManager(
|
||||
store store.Store,
|
||||
permissionsManager permissions.Manager,
|
||||
accountManager account.Manager,
|
||||
proxyController proxy.Controller,
|
||||
) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
proxyController: proxyController,
|
||||
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
|
||||
labelRng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllProviders(ctx context.Context, accountID, userID string) ([]*types.Provider, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetProvider(ctx context.Context, accountID, userID, providerID string) (*types.Provider, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthNone, accountID, providerID)
|
||||
}
|
||||
|
||||
// CreateProvider persists a new provider for the account. bootstrapCluster
|
||||
// is used only when the per-account agent-network Settings row hasn't
|
||||
// been created yet; otherwise it is ignored (the cluster is pinned on
|
||||
// Settings and every provider in the account routes through it).
|
||||
func (m *managerImpl) CreateProvider(ctx context.Context, userID string, provider *types.Provider, bootstrapCluster string) (*types.Provider, error) {
|
||||
if err := m.requirePermission(ctx, provider.AccountID, userID, operations.Create); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// An empty api_key would silently produce a synthesised service
|
||||
// that 401s on every upstream request. Surface the misconfiguration
|
||||
// at create time instead.
|
||||
if strings.TrimSpace(provider.APIKey) == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "api_key is required when creating an agent network provider")
|
||||
}
|
||||
|
||||
if provider.ID == "" {
|
||||
fresh := types.NewProvider(provider.AccountID)
|
||||
provider.ID = fresh.ID
|
||||
provider.CreatedAt = fresh.CreatedAt
|
||||
provider.UpdatedAt = fresh.UpdatedAt
|
||||
}
|
||||
|
||||
if err := ensureSessionKeys(provider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.store.SaveAgentNetworkProvider(ctx, provider); err != nil {
|
||||
return nil, fmt.Errorf("save agent network provider: %w", err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(bootstrapCluster) != "" {
|
||||
if _, err := m.bootstrapSettingsIfNeeded(ctx, provider.AccountID, bootstrapCluster); err != nil {
|
||||
// The provider create has already succeeded; logging the
|
||||
// bootstrap miss matches the plan's PoC behaviour. The synth
|
||||
// path treats a missing settings row as a no-op, and the next
|
||||
// provider create retries the bootstrap.
|
||||
log.WithContext(ctx).Debugf("agent-network bootstrap settings for account %s on cluster %s: %v", provider.AccountID, bootstrapCluster, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, provider.ID, provider.AccountID, activity.AgentNetworkProviderCreated, provider.EventMeta())
|
||||
m.reconcile(ctx, provider.AccountID)
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateProvider(ctx context.Context, userID string, provider *types.Provider) (*types.Provider, error) {
|
||||
if err := m.requirePermission(ctx, provider.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthUpdate, provider.AccountID, provider.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get agent network provider: %w", err)
|
||||
}
|
||||
|
||||
// Preserve the API key if the caller didn't rotate it. A
|
||||
// whitespace-only value is treated as "not rotated" rather than a
|
||||
// real key, but it must not silently overwrite a valid stored key.
|
||||
if provider.APIKey == "" {
|
||||
provider.APIKey = existing.APIKey
|
||||
} else if strings.TrimSpace(provider.APIKey) == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "api_key must be non-blank when rotating an agent network provider")
|
||||
}
|
||||
// Always preserve the session keypair across updates so existing
|
||||
// session cookies stay valid. The keys are server-managed and
|
||||
// never surfaced through the API.
|
||||
provider.SessionPrivateKey = existing.SessionPrivateKey
|
||||
provider.SessionPublicKey = existing.SessionPublicKey
|
||||
if err := ensureSessionKeys(provider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
provider.CreatedAt = existing.CreatedAt
|
||||
provider.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkProvider(ctx, provider); err != nil {
|
||||
return nil, fmt.Errorf("save agent network provider: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, provider.ID, provider.AccountID, activity.AgentNetworkProviderUpdated, provider.EventMeta())
|
||||
m.reconcile(ctx, provider.AccountID)
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteProvider(ctx context.Context, accountID, userID, providerID string) error {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
provider, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthUpdate, accountID, providerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent network provider: %w", err)
|
||||
}
|
||||
|
||||
// Refuse to delete while any policy still references this provider.
|
||||
// The operator must detach it first.
|
||||
policies, err := m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent network policies: %w", err)
|
||||
}
|
||||
var blocking []string
|
||||
for _, p := range policies {
|
||||
if slices.Contains(p.DestinationProviderIDs, providerID) {
|
||||
blocking = append(blocking, p.Name)
|
||||
}
|
||||
}
|
||||
if len(blocking) > 0 {
|
||||
return status.Errorf(
|
||||
status.InvalidArgument,
|
||||
"provider is in use by %d %s (%s); detach it before deleting",
|
||||
len(blocking),
|
||||
pluralize(len(blocking), "policy", "policies"),
|
||||
strings.Join(blocking, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
if err := m.store.DeleteAgentNetworkProvider(ctx, accountID, providerID); err != nil {
|
||||
return fmt.Errorf("failed to delete agent network provider: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, providerID, accountID, activity.AgentNetworkProviderDeleted, provider.EventMeta())
|
||||
m.reconcile(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func pluralize(n int, singular, plural string) string {
|
||||
if n == 1 {
|
||||
return singular
|
||||
}
|
||||
return plural
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPolicy(ctx context.Context, accountID, userID, policyID string) (*types.Policy, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error) {
|
||||
if err := m.requirePermission(ctx, policy.AccountID, userID, operations.Create); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if policy.ID == "" {
|
||||
fresh := types.NewPolicy(policy.AccountID)
|
||||
policy.ID = fresh.ID
|
||||
policy.CreatedAt = fresh.CreatedAt
|
||||
policy.UpdatedAt = fresh.UpdatedAt
|
||||
}
|
||||
|
||||
if err := m.validateProviderRefs(ctx, policy.AccountID, policy.DestinationProviderIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.store.SaveAgentNetworkPolicy(ctx, policy); err != nil {
|
||||
return nil, fmt.Errorf("failed to save agent network policy: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, policy.ID, policy.AccountID, activity.AgentNetworkPolicyCreated, policy.EventMeta())
|
||||
m.reconcile(ctx, policy.AccountID)
|
||||
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error) {
|
||||
if err := m.requirePermission(ctx, policy.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthUpdate, policy.AccountID, policy.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get agent network policy: %w", err)
|
||||
}
|
||||
|
||||
if err := m.validateProviderRefs(ctx, policy.AccountID, policy.DestinationProviderIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
policy.CreatedAt = existing.CreatedAt
|
||||
policy.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkPolicy(ctx, policy); err != nil {
|
||||
return nil, fmt.Errorf("failed to save agent network policy: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, policy.ID, policy.AccountID, activity.AgentNetworkPolicyUpdated, policy.EventMeta())
|
||||
m.reconcile(ctx, policy.AccountID)
|
||||
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeletePolicy(ctx context.Context, accountID, userID, policyID string) error {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
policy, err := m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent network policy: %w", err)
|
||||
}
|
||||
|
||||
if err := m.store.DeleteAgentNetworkPolicy(ctx, accountID, policyID); err != nil {
|
||||
return fmt.Errorf("failed to delete agent network policy: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, policyID, accountID, activity.AgentNetworkPolicyDeleted, policy.EventMeta())
|
||||
m.reconcile(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllGuardrails(ctx context.Context, accountID, userID string) ([]*types.Guardrail, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetGuardrail(ctx context.Context, accountID, userID, guardrailID string) (*types.Guardrail, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthNone, accountID, guardrailID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error) {
|
||||
if err := m.requirePermission(ctx, guardrail.AccountID, userID, operations.Create); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if guardrail.ID == "" {
|
||||
fresh := types.NewGuardrail(guardrail.AccountID)
|
||||
guardrail.ID = fresh.ID
|
||||
guardrail.CreatedAt = fresh.CreatedAt
|
||||
guardrail.UpdatedAt = fresh.UpdatedAt
|
||||
}
|
||||
|
||||
if err := m.store.SaveAgentNetworkGuardrail(ctx, guardrail); err != nil {
|
||||
return nil, fmt.Errorf("failed to save agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, guardrail.ID, guardrail.AccountID, activity.AgentNetworkGuardrailCreated, guardrail.EventMeta())
|
||||
m.reconcile(ctx, guardrail.AccountID)
|
||||
|
||||
return guardrail, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error) {
|
||||
if err := m.requirePermission(ctx, guardrail.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthUpdate, guardrail.AccountID, guardrail.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
guardrail.CreatedAt = existing.CreatedAt
|
||||
guardrail.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkGuardrail(ctx, guardrail); err != nil {
|
||||
return nil, fmt.Errorf("failed to save agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, guardrail.ID, guardrail.AccountID, activity.AgentNetworkGuardrailUpdated, guardrail.EventMeta())
|
||||
m.reconcile(ctx, guardrail.AccountID)
|
||||
|
||||
return guardrail, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteGuardrail(ctx context.Context, accountID, userID, guardrailID string) error {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
guardrail, err := m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthUpdate, accountID, guardrailID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
if err := m.store.DeleteAgentNetworkGuardrail(ctx, accountID, guardrailID); err != nil {
|
||||
return fmt.Errorf("failed to delete agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, guardrailID, accountID, activity.AgentNetworkGuardrailDeleted, guardrail.EventMeta())
|
||||
m.reconcile(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllBudgetRules returns every account-level budget rule for the account.
|
||||
func (m *managerImpl) GetAllBudgetRules(ctx context.Context, accountID, userID string) ([]*types.AccountBudgetRule, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// GetBudgetRule returns a single account-level budget rule.
|
||||
func (m *managerImpl) GetBudgetRule(ctx context.Context, accountID, userID, ruleID string) (*types.AccountBudgetRule, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthNone, accountID, ruleID)
|
||||
}
|
||||
|
||||
// CreateBudgetRule persists a new account-level budget rule. Budget rules are
|
||||
// enforced at request time (CheckLLMPolicyLimits), not baked into the synth
|
||||
// proxy config, so no reconcile is needed.
|
||||
func (m *managerImpl) CreateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
|
||||
if err := m.requirePermission(ctx, rule.AccountID, userID, operations.Create); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rule.ID == "" {
|
||||
fresh := types.NewAccountBudgetRule(rule.AccountID)
|
||||
rule.ID = fresh.ID
|
||||
rule.CreatedAt = fresh.CreatedAt
|
||||
rule.UpdatedAt = fresh.UpdatedAt
|
||||
}
|
||||
|
||||
if err := m.store.SaveAgentNetworkBudgetRule(ctx, rule); err != nil {
|
||||
return nil, fmt.Errorf("save agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, rule.ID, rule.AccountID, activity.AgentNetworkBudgetRuleCreated, rule.EventMeta())
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// UpdateBudgetRule updates an existing account-level budget rule.
|
||||
func (m *managerImpl) UpdateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
|
||||
if err := m.requirePermission(ctx, rule.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthUpdate, rule.AccountID, rule.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
rule.CreatedAt = existing.CreatedAt
|
||||
rule.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkBudgetRule(ctx, rule); err != nil {
|
||||
return nil, fmt.Errorf("save agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, rule.ID, rule.AccountID, activity.AgentNetworkBudgetRuleUpdated, rule.EventMeta())
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// DeleteBudgetRule removes an account-level budget rule.
|
||||
func (m *managerImpl) DeleteBudgetRule(ctx context.Context, accountID, userID, ruleID string) error {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rule, err := m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthUpdate, accountID, ruleID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.store.DeleteAgentNetworkBudgetRule(ctx, accountID, ruleID); err != nil {
|
||||
return fmt.Errorf("delete agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, ruleID, accountID, activity.AgentNetworkBudgetRuleDeleted, rule.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateSettings applies the mutable account-level settings — the collection
|
||||
// toggles — onto the existing row. Cluster and Subdomain are immutable and are
|
||||
// preserved from the persisted row regardless of the input. Because the
|
||||
// collection toggles change the synthesised service config (prompt-capture
|
||||
// gating, access-log emission), a reconcile is triggered so the proxy and peer
|
||||
// network maps converge on the new state.
|
||||
func (m *managerImpl) UpdateSettings(ctx context.Context, userID string, settings *types.Settings) (*types.Settings, error) {
|
||||
if err := m.requirePermission(ctx, settings.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthUpdate, settings.AccountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get agent network settings: %w", err)
|
||||
}
|
||||
|
||||
existing.EnableLogCollection = settings.EnableLogCollection
|
||||
existing.EnablePromptCollection = settings.EnablePromptCollection
|
||||
existing.RedactPii = settings.RedactPii
|
||||
existing.AccessLogRetentionDays = settings.AccessLogRetentionDays
|
||||
existing.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkSettings(ctx, existing); err != nil {
|
||||
return nil, fmt.Errorf("save agent network settings: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, settings.AccountID, settings.AccountID, activity.AgentNetworkSettingsUpdated, map[string]any{
|
||||
"log_collection": existing.EnableLogCollection,
|
||||
"prompt_collection": existing.EnablePromptCollection,
|
||||
"redact_pii": existing.RedactPii,
|
||||
})
|
||||
m.reconcile(ctx, settings.AccountID)
|
||||
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
// validateProviderRefs ensures every destination provider id refers to a
|
||||
// provider that exists in the same account.
|
||||
func (m *managerImpl) validateProviderRefs(ctx context.Context, accountID string, providerIDs []string) error {
|
||||
if len(providerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, id := range providerIDs {
|
||||
if _, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthNone, accountID, id); err != nil {
|
||||
// Only a genuine not-found means the reference is invalid; a
|
||||
// store/runtime error must propagate as-is rather than be
|
||||
// masked as a client validation error.
|
||||
var sErr *status.Error
|
||||
if errors.As(err, &sErr) && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "destination_provider_ids: provider %s does not exist", id)
|
||||
}
|
||||
return fmt.Errorf("get destination provider %s: %w", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSettings returns the agent-network settings row for the account.
|
||||
// Returns the underlying status.NotFound when no row has been
|
||||
// bootstrapped yet (i.e. the account has no providers).
|
||||
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// bootstrapSettingsIfNeeded creates the per-account agent-network
|
||||
// settings row when missing. The cluster comes from the create-time
|
||||
// hint the dashboard sends (auto-picked from the active cluster list);
|
||||
// the subdomain is picked from the curated wordlist avoiding
|
||||
// collisions on the same cluster. Idempotent: if a row already exists
|
||||
// it is returned untouched and the hint is ignored.
|
||||
func (m *managerImpl) bootstrapSettingsIfNeeded(ctx context.Context, accountID, providerCluster string) (*types.Settings, error) {
|
||||
if accountID == "" {
|
||||
return nil, fmt.Errorf("bootstrap settings: account id is required")
|
||||
}
|
||||
if strings.TrimSpace(providerCluster) == "" {
|
||||
return nil, fmt.Errorf("bootstrap settings: provider cluster is required")
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err == nil {
|
||||
return existing, nil
|
||||
}
|
||||
var sErr *status.Error
|
||||
if !errors.As(err, &sErr) || sErr.Type() != status.NotFound {
|
||||
return nil, fmt.Errorf("get agent network settings: %w", err)
|
||||
}
|
||||
|
||||
siblings, err := m.store.GetAgentNetworkSettingsByCluster(ctx, store.LockingStrengthNone, providerCluster)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list agent network settings on cluster: %w", err)
|
||||
}
|
||||
taken := make(map[string]struct{}, len(siblings))
|
||||
for _, s := range siblings {
|
||||
taken[s.Subdomain] = struct{}{}
|
||||
}
|
||||
|
||||
suffix := accountID
|
||||
if len(suffix) > 4 {
|
||||
suffix = suffix[:4]
|
||||
}
|
||||
|
||||
m.labelRngMu.Lock()
|
||||
subdomain := labelgen.PickUnique(m.labelRng, taken, suffix)
|
||||
m.labelRngMu.Unlock()
|
||||
|
||||
now := time.Now().UTC()
|
||||
settings := &types.Settings{
|
||||
AccountID: accountID,
|
||||
Cluster: providerCluster,
|
||||
Subdomain: subdomain,
|
||||
// Logs on by default; usage is collected regardless. Retention bounds
|
||||
// how long full log rows are kept.
|
||||
EnableLogCollection: true,
|
||||
AccessLogRetentionDays: types.DefaultAccessLogRetentionDays,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := m.store.SaveAgentNetworkSettings(ctx, settings); err != nil {
|
||||
return nil, fmt.Errorf("save agent network settings: %w", err)
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// ListConsumption returns every consumption row recorded for the
|
||||
// account, ordered window-newest-first. Backs the dashboard's basic
|
||||
// counter view; permission gate is the same Read role that gates
|
||||
// every other agent-network surface.
|
||||
func (m *managerImpl) ListConsumption(ctx context.Context, accountID, userID string) ([]*types.Consumption, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.ListAgentNetworkConsumption(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// ListAccessLogs returns a paginated, server-side-filtered page of
|
||||
// agent-network access logs plus the total count matching the filter.
|
||||
func (m *managerImpl) ListAccessLogs(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return m.store.GetAgentNetworkAccessLogs(ctx, store.LockingStrengthNone, accountID, filter)
|
||||
}
|
||||
|
||||
// GetUsageOverview returns the filtered usage rows aggregated into time buckets
|
||||
// at the requested granularity, oldest-first.
|
||||
func (m *managerImpl) GetUsageOverview(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter, granularity types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := m.store.GetAgentNetworkUsageRows(ctx, store.LockingStrengthNone, accountID, filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return types.AggregateUsageByGranularity(rows, granularity), nil
|
||||
}
|
||||
|
||||
// StartAccessLogCleanup launches a background sweep that periodically deletes
|
||||
// each account's agent-network access-log rows older than that account's
|
||||
// AccessLogRetentionDays. Usage records are never swept. A non-positive
|
||||
// interval defaults to 24h.
|
||||
func (m *managerImpl) StartAccessLogCleanup(ctx context.Context, cleanupIntervalHours int) {
|
||||
if cleanupIntervalHours <= 0 {
|
||||
cleanupIntervalHours = 24
|
||||
}
|
||||
interval := time.Duration(cleanupIntervalHours) * time.Hour
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
m.cleanupAccessLogsOnce(ctx) // run once on startup
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.cleanupAccessLogsOnce(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// cleanupAccessLogsOnce sweeps every account's expired access-log rows against
|
||||
// its configured retention. Best-effort: a per-account failure is logged and
|
||||
// the sweep continues.
|
||||
func (m *managerImpl) cleanupAccessLogsOnce(ctx context.Context) {
|
||||
settings, err := m.store.GetAllAgentNetworkSettings(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("agent-network access-log cleanup: list settings: %v", err)
|
||||
return
|
||||
}
|
||||
for _, s := range settings {
|
||||
if s.AccessLogRetentionDays <= 0 {
|
||||
continue // keep indefinitely
|
||||
}
|
||||
cutoff := time.Now().UTC().AddDate(0, 0, -s.AccessLogRetentionDays)
|
||||
deleted, err := m.store.DeleteOldAgentNetworkAccessLogs(ctx, s.AccountID, cutoff)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("agent-network access-log cleanup for account %s: %v", s.AccountID, err)
|
||||
continue
|
||||
}
|
||||
if deleted > 0 {
|
||||
log.WithContext(ctx).Infof("agent-network access-log cleanup: deleted %d rows for account %s (retention %d days)", deleted, s.AccountID, s.AccessLogRetentionDays)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecordConsumption increments the (dim, window) counter by the
|
||||
// supplied deltas. The window_start is computed from time.Now under
|
||||
// the supplied window_seconds so callers don't have to pre-align —
|
||||
// the proxy's post-flight path simply hands us tokens + cost and
|
||||
// which dimension we're booking against.
|
||||
func (m *managerImpl) RecordConsumption(ctx context.Context, accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds, tokensIn, tokensOut int64, costUSD float64) error {
|
||||
if accountID == "" || dimID == "" || windowSeconds <= 0 {
|
||||
return status.Errorf(status.InvalidArgument, "account_id, dim_id and window_seconds must be set")
|
||||
}
|
||||
windowStart := types.WindowStart(time.Now(), windowSeconds)
|
||||
return m.store.IncrementAgentNetworkConsumption(ctx, accountID, kind, dimID, windowSeconds, windowStart, tokensIn, tokensOut, costUSD)
|
||||
}
|
||||
|
||||
func (m *managerImpl) requirePermission(ctx context.Context, accountID, userID string, op operations.Operation) error {
|
||||
ok, _, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.AgentNetwork, op)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockManager struct{}
|
||||
|
||||
// NewManagerMock returns a no-op manager useful for tests.
|
||||
func NewManagerMock() Manager {
|
||||
return &mockManager{}
|
||||
}
|
||||
|
||||
func (*mockManager) GetAllProviders(_ context.Context, _, _ string) ([]*types.Provider, error) {
|
||||
return []*types.Provider{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetProvider(_ context.Context, _, _, _ string) (*types.Provider, error) {
|
||||
return &types.Provider{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) CreateProvider(_ context.Context, _ string, p *types.Provider, _ string) (*types.Provider, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (*mockManager) UpdateProvider(_ context.Context, _ string, p *types.Provider) (*types.Provider, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (*mockManager) DeleteProvider(_ context.Context, _, _, _ string) error { return nil }
|
||||
|
||||
func (*mockManager) GetAllPolicies(_ context.Context, _, _ string) ([]*types.Policy, error) {
|
||||
return []*types.Policy{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetPolicy(_ context.Context, _, _, _ string) (*types.Policy, error) {
|
||||
return &types.Policy{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) CreatePolicy(_ context.Context, _ string, p *types.Policy) (*types.Policy, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (*mockManager) UpdatePolicy(_ context.Context, _ string, p *types.Policy) (*types.Policy, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (*mockManager) DeletePolicy(_ context.Context, _, _, _ string) error { return nil }
|
||||
|
||||
func (*mockManager) GetAllGuardrails(_ context.Context, _, _ string) ([]*types.Guardrail, error) {
|
||||
return []*types.Guardrail{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetGuardrail(_ context.Context, _, _, _ string) (*types.Guardrail, error) {
|
||||
return &types.Guardrail{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) CreateGuardrail(_ context.Context, _ string, g *types.Guardrail) (*types.Guardrail, error) {
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (*mockManager) UpdateGuardrail(_ context.Context, _ string, g *types.Guardrail) (*types.Guardrail, error) {
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (*mockManager) DeleteGuardrail(_ context.Context, _, _, _ string) error { return nil }
|
||||
|
||||
func (*mockManager) GetAllBudgetRules(_ context.Context, _, _ string) ([]*types.AccountBudgetRule, error) {
|
||||
return []*types.AccountBudgetRule{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetBudgetRule(_ context.Context, _, _, _ string) (*types.AccountBudgetRule, error) {
|
||||
return &types.AccountBudgetRule{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) CreateBudgetRule(_ context.Context, _ string, r *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (*mockManager) UpdateBudgetRule(_ context.Context, _ string, r *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (*mockManager) DeleteBudgetRule(_ context.Context, _, _, _ string) error { return nil }
|
||||
|
||||
func (*mockManager) GetSettings(_ context.Context, _, _ string) (*types.Settings, error) {
|
||||
return nil, status.Errorf(status.NotFound, "agent network settings not found")
|
||||
}
|
||||
|
||||
func (*mockManager) UpdateSettings(_ context.Context, _ string, s *types.Settings) (*types.Settings, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (*mockManager) ListConsumption(_ context.Context, _, _ string) ([]*types.Consumption, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*mockManager) ListAccessLogs(_ context.Context, _, _ string, _ types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetUsageOverview(_ context.Context, _, _ string, _ types.AgentNetworkAccessLogFilter, _ types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*mockManager) StartAccessLogCleanup(_ context.Context, _ int) {}
|
||||
|
||||
func (*mockManager) RecordConsumption(_ context.Context, _ string, _ types.ConsumptionDimension, _ string, _, _, _ int64, _ float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*mockManager) RecordAccountBudgetUsage(_ context.Context, _, _ string, _ []string, _, _ int64, _ float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*mockManager) RecordUsage(_ context.Context, _ RecordUsageInput) error {
|
||||
return nil
|
||||
}
|
||||
660
management/internals/modules/agentnetwork/policyselect.go
Normal file
660
management/internals/modules/agentnetwork/policyselect.go
Normal file
@@ -0,0 +1,660 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// validateUsageDeltas rejects negative or non-finite usage counters before they
|
||||
// reach the consumption store, so a bad delta can't decrement or poison totals.
|
||||
// The store batch method enforces the same invariant; this is the manager-level
|
||||
// guard so direct callers fail fast with a clear error.
|
||||
func validateUsageDeltas(tokensIn, tokensOut int64, costUSD float64) error {
|
||||
if tokensIn < 0 || tokensOut < 0 || costUSD < 0 || math.IsNaN(costUSD) || math.IsInf(costUSD, 0) {
|
||||
return status.Errorf(status.InvalidArgument, "usage deltas must be non-negative and finite")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deny codes the proxy surfaces back to the caller when every
|
||||
// applicable policy is exhausted. The proxy converts these into
|
||||
// upstream-shaped error responses.
|
||||
const (
|
||||
//nolint:gosec // policy deny code label, not a credential
|
||||
denyCodeTokenCapExceeded = "llm_policy.token_cap_exceeded"
|
||||
//nolint:gosec // policy deny code label, not a credential
|
||||
denyCodeBudgetCapExceeded = "llm_policy.budget_cap_exceeded"
|
||||
//nolint:gosec // account deny code label, not a credential
|
||||
denyCodeAccountTokenCapExceeded = "llm_account.token_cap_exceeded"
|
||||
//nolint:gosec // account deny code label, not a credential
|
||||
denyCodeAccountBudgetCapExceeded = "llm_account.budget_cap_exceeded"
|
||||
)
|
||||
|
||||
// consumptionCache holds the consumption counters prefetched for one
|
||||
// policy-selection request, keyed by ConsumptionKey. A miss returns a zero
|
||||
// counter — the same contract the store's single-row getter uses for absent
|
||||
// rows — so the eval logic is identical whether a counter exists yet or not.
|
||||
type consumptionCache map[types.ConsumptionKey]*types.Consumption
|
||||
|
||||
func (c consumptionCache) get(accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds int64, windowStart time.Time) *types.Consumption {
|
||||
key := types.ConsumptionKey{Kind: kind, DimID: dimID, WindowSeconds: windowSeconds, WindowStartUTC: windowStart.UTC()}
|
||||
if row, ok := c[key]; ok && row != nil {
|
||||
return row
|
||||
}
|
||||
return &types.Consumption{
|
||||
AccountID: accountID,
|
||||
DimensionKind: kind,
|
||||
DimensionID: dimID,
|
||||
WindowSeconds: windowSeconds,
|
||||
WindowStartUTC: windowStart.UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// addLimitKeys records the user/group consumption keys a single enabled (token
|
||||
// or budget) limit window reads for the given attribution group, into a dedup
|
||||
// set. attrGroup may be empty (no group dimension applies).
|
||||
func addLimitKeys(set map[types.ConsumptionKey]struct{}, userID, attrGroup string, windowSeconds int64, now time.Time) {
|
||||
if windowSeconds <= 0 {
|
||||
return
|
||||
}
|
||||
ws := types.WindowStart(now, windowSeconds)
|
||||
if userID != "" {
|
||||
set[types.ConsumptionKey{Kind: types.DimensionUser, DimID: userID, WindowSeconds: windowSeconds, WindowStartUTC: ws}] = struct{}{}
|
||||
}
|
||||
if attrGroup != "" {
|
||||
set[types.ConsumptionKey{Kind: types.DimensionGroup, DimID: attrGroup, WindowSeconds: windowSeconds, WindowStartUTC: ws}] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// prefetchConsumption loads, in one store round-trip, every consumption counter
|
||||
// that the account-budget ceiling and the candidate policies will read while
|
||||
// scoring this request. This replaces the per-cap point reads the selector
|
||||
// previously issued one at a time (the N+1 on the hot path).
|
||||
func (m *managerImpl) prefetchConsumption(ctx context.Context, in PolicySelectionInput, rules []*types.AccountBudgetRule, candidates []*types.Policy, now time.Time) (consumptionCache, error) {
|
||||
set := make(map[types.ConsumptionKey]struct{})
|
||||
for _, p := range candidates {
|
||||
attr := lowestIntersect(p.SourceGroups, in.GroupIDs)
|
||||
if p.Limits.TokenLimit.Enabled {
|
||||
addLimitKeys(set, in.UserID, attr, p.Limits.TokenLimit.WindowSeconds, now)
|
||||
}
|
||||
if p.Limits.BudgetLimit.Enabled {
|
||||
addLimitKeys(set, in.UserID, attr, p.Limits.BudgetLimit.WindowSeconds, now)
|
||||
}
|
||||
}
|
||||
for _, r := range rules {
|
||||
if r == nil || !r.Enabled || !budgetRuleApplies(r, in) {
|
||||
continue
|
||||
}
|
||||
attr := lowestIntersect(r.TargetGroups, in.GroupIDs)
|
||||
if r.Limits.TokenLimit.Enabled {
|
||||
addLimitKeys(set, in.UserID, attr, r.Limits.TokenLimit.WindowSeconds, now)
|
||||
}
|
||||
if r.Limits.BudgetLimit.Enabled {
|
||||
addLimitKeys(set, in.UserID, attr, r.Limits.BudgetLimit.WindowSeconds, now)
|
||||
}
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return consumptionCache{}, nil
|
||||
}
|
||||
keys := make([]types.ConsumptionKey, 0, len(set))
|
||||
for k := range set {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
rows, err := m.store.GetAgentNetworkConsumptionBatch(ctx, store.LockingStrengthNone, in.AccountID, keys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch read consumption: %w", err)
|
||||
}
|
||||
return consumptionCache(rows), nil
|
||||
}
|
||||
|
||||
// SelectPolicyForRequest picks the policy that "pays" for the
|
||||
// incoming request. The chosen policy is the one with the largest
|
||||
// pool that still has headroom — drain the bigger bucket first,
|
||||
// fall through to the next-biggest only when the current one's
|
||||
// group cap or shared per-user cap is exhausted. This matches
|
||||
// operator intuition for layered tiers ("privileged group has the
|
||||
// 10k budget, regular group has 1k as the safety net") and avoids
|
||||
// the load-balancer flapping that fraction-based scoring produces
|
||||
// once any cap has been touched.
|
||||
//
|
||||
// Ordering across non-exhausted candidates:
|
||||
// 1. Policies with NO enabled caps (catch-all-allow) win over any
|
||||
// capped policy — operators who configure unlimited access
|
||||
// expect requests to attribute there until they explicitly add
|
||||
// caps.
|
||||
// 2. Larger group token cap wins.
|
||||
// 3. Larger group budget USD cap wins.
|
||||
// 4. Larger user token cap wins.
|
||||
// 5. Larger user budget USD cap wins.
|
||||
// 6. Older created_at wins (deterministic final tiebreak so
|
||||
// multi-node selection converges).
|
||||
//
|
||||
// Returns Allow=true with empty SelectedPolicyID when no policy in
|
||||
// the account targets the (provider, caller-groups) combination —
|
||||
// llm_router is the gate that owns "no policy authorises this
|
||||
// request" semantics; this function trusts that authorisation has
|
||||
// already happened upstream and only does the limit-aware
|
||||
// attribution.
|
||||
func (m *managerImpl) SelectPolicyForRequest(ctx context.Context, in PolicySelectionInput) (*PolicySelectionResult, error) {
|
||||
if in.AccountID == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "account_id is required")
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, in.AccountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list account budget rules: %w", err)
|
||||
}
|
||||
policies, err := m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, in.AccountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list account policies: %w", err)
|
||||
}
|
||||
candidates := filterApplicablePolicies(policies, in)
|
||||
|
||||
// Prefetch every consumption counter the ceiling + candidate policies will
|
||||
// read, in a single store round-trip, then score against the cache.
|
||||
cache, err := m.prefetchConsumption(ctx, in, rules, candidates, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Account-level budget rules are an always-on ceiling, evaluated
|
||||
// independently of policy selection (they bind even for catch-all-allow
|
||||
// policies or requests that match no policy). All applicable rules must
|
||||
// pass — this is where min-wins lives.
|
||||
if deny, code, reason := checkAccountBudget(in, rules, cache, now); deny {
|
||||
return &PolicySelectionResult{Allow: false, DenyCode: code, DenyReason: reason}, nil
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return &PolicySelectionResult{Allow: true}, nil
|
||||
}
|
||||
scored, lastDenyCode, lastDenyReason := scoreCandidates(in, candidates, cache, now)
|
||||
if len(scored) == 0 {
|
||||
return &PolicySelectionResult{
|
||||
Allow: false,
|
||||
DenyCode: lastDenyCode,
|
||||
DenyReason: lastDenyReason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sort.SliceStable(scored, func(i, j int) bool {
|
||||
// Catch-all-allow (no caps configured) wins outright over
|
||||
// any capped policy.
|
||||
iNoCap := isUncapped(scored[i].policy)
|
||||
jNoCap := isUncapped(scored[j].policy)
|
||||
if iNoCap != jNoCap {
|
||||
return iNoCap
|
||||
}
|
||||
// Bigger pool drains first. Group caps dominate (shared
|
||||
// across the group) before individual caps.
|
||||
if a, b := groupCapTokens(scored[i].policy), groupCapTokens(scored[j].policy); a != b {
|
||||
return a > b
|
||||
}
|
||||
if a, b := groupCapBudgetUsd(scored[i].policy), groupCapBudgetUsd(scored[j].policy); a != b {
|
||||
return a > b
|
||||
}
|
||||
if a, b := userCapTokens(scored[i].policy), userCapTokens(scored[j].policy); a != b {
|
||||
return a > b
|
||||
}
|
||||
if a, b := userCapBudgetUsd(scored[i].policy), userCapBudgetUsd(scored[j].policy); a != b {
|
||||
return a > b
|
||||
}
|
||||
return scored[i].policy.CreatedAt.Before(scored[j].policy.CreatedAt)
|
||||
})
|
||||
|
||||
winner := scored[0]
|
||||
return &PolicySelectionResult{
|
||||
Allow: true,
|
||||
SelectedPolicyID: winner.policy.ID,
|
||||
AttributionGroupID: winner.attributionGroup,
|
||||
WindowSeconds: winner.windowSeconds,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// filterApplicablePolicies returns the enabled policies that target
|
||||
// the requested provider and have at least one of the caller's groups
|
||||
// in their source_groups. Caller's group set is matched
|
||||
// case-sensitively against policy.SourceGroups.
|
||||
func filterApplicablePolicies(policies []*types.Policy, in PolicySelectionInput) []*types.Policy {
|
||||
if len(policies) == 0 {
|
||||
return nil
|
||||
}
|
||||
groupSet := make(map[string]struct{}, len(in.GroupIDs))
|
||||
for _, g := range in.GroupIDs {
|
||||
if g != "" {
|
||||
groupSet[g] = struct{}{}
|
||||
}
|
||||
}
|
||||
out := make([]*types.Policy, 0, len(policies))
|
||||
for _, p := range policies {
|
||||
if p == nil || !p.Enabled {
|
||||
continue
|
||||
}
|
||||
if !sliceContains(p.DestinationProviderIDs, in.ProviderID) {
|
||||
continue
|
||||
}
|
||||
if !anyGroupMatches(p.SourceGroups, groupSet) {
|
||||
continue
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// candidate is the per-policy intermediate the selector ranks. A
|
||||
// policy that's been exhausted on any enabled cap never makes it
|
||||
// into this slice; the selector's deny envelope carries the latest
|
||||
// exhaustion's reason out separately.
|
||||
type candidate struct {
|
||||
policy *types.Policy
|
||||
attributionGroup string
|
||||
windowSeconds int64
|
||||
}
|
||||
|
||||
// scoreCandidates evaluates every applicable policy against the
|
||||
// caller's current consumption. Exhausted policies are filtered out
|
||||
// of the returned slice; the most recent exhaustion's deny code +
|
||||
// human reason is returned alongside so the caller can surface it
|
||||
// when no candidate survives.
|
||||
func scoreCandidates(
|
||||
in PolicySelectionInput,
|
||||
candidates []*types.Policy,
|
||||
cache consumptionCache,
|
||||
now time.Time,
|
||||
) ([]candidate, string, string) {
|
||||
out := make([]candidate, 0, len(candidates))
|
||||
var lastDenyCode, lastDenyReason string
|
||||
|
||||
for _, p := range candidates {
|
||||
c, exhausted, denyCode, denyReason := scoreOne(in, p, cache, now)
|
||||
if exhausted {
|
||||
lastDenyCode = denyCode
|
||||
lastDenyReason = denyReason
|
||||
continue
|
||||
}
|
||||
out = append(out, c)
|
||||
}
|
||||
return out, lastDenyCode, lastDenyReason
|
||||
}
|
||||
|
||||
// scoreOne checks a single policy for cap exhaustion. Returns the
|
||||
// candidate envelope when the policy still has headroom on every
|
||||
// enabled cap; reports exhausted=true with a deny code naming the
|
||||
// offending cap kind otherwise.
|
||||
func scoreOne(
|
||||
in PolicySelectionInput,
|
||||
p *types.Policy,
|
||||
cache consumptionCache,
|
||||
now time.Time,
|
||||
) (candidate, bool, string, string) {
|
||||
attrGroup := lowestIntersect(p.SourceGroups, in.GroupIDs)
|
||||
c := candidate{
|
||||
policy: p,
|
||||
attributionGroup: attrGroup,
|
||||
windowSeconds: effectiveWindowSeconds(p),
|
||||
}
|
||||
|
||||
if p.Limits.TokenLimit.Enabled && p.Limits.TokenLimit.WindowSeconds > 0 {
|
||||
if exhausted, reason := evalTokenCap(cache, in.AccountID, in.UserID, attrGroup, p.Limits.TokenLimit, now, "policy "+p.ID); exhausted {
|
||||
return candidate{}, true, denyCodeTokenCapExceeded, reason
|
||||
}
|
||||
}
|
||||
|
||||
if p.Limits.BudgetLimit.Enabled && p.Limits.BudgetLimit.WindowSeconds > 0 {
|
||||
if exhausted, reason := evalBudgetCap(cache, in.AccountID, in.UserID, attrGroup, p.Limits.BudgetLimit, now, "policy "+p.ID); exhausted {
|
||||
return candidate{}, true, denyCodeBudgetCapExceeded, reason
|
||||
}
|
||||
}
|
||||
|
||||
return c, false, "", ""
|
||||
}
|
||||
|
||||
// evalTokenCap reports whether the token limit is already exhausted for the
|
||||
// caller in its own window. attrGroup may be empty (no group dimension applies).
|
||||
// label identifies the cap source ("policy <id>" or "account rule <id>") for the
|
||||
// deny reason. It is the shared primitive behind both policy and account-rule
|
||||
// enforcement.
|
||||
func evalTokenCap(
|
||||
cache consumptionCache,
|
||||
accountID, userID, attrGroup string,
|
||||
tl types.PolicyTokenLimit,
|
||||
now time.Time,
|
||||
label string,
|
||||
) (bool, string) {
|
||||
windowStart := types.WindowStart(now, tl.WindowSeconds)
|
||||
|
||||
if tl.UserCap > 0 && userID != "" {
|
||||
row := cache.get(accountID, types.DimensionUser, userID, tl.WindowSeconds, windowStart)
|
||||
used := row.TokensInput + row.TokensOutput
|
||||
if used >= tl.UserCap {
|
||||
return true, fmt.Sprintf("user token cap exhausted on %s (used %d of %d)", label, used, tl.UserCap)
|
||||
}
|
||||
}
|
||||
|
||||
if tl.GroupCap > 0 && attrGroup != "" {
|
||||
row := cache.get(accountID, types.DimensionGroup, attrGroup, tl.WindowSeconds, windowStart)
|
||||
used := row.TokensInput + row.TokensOutput
|
||||
if used >= tl.GroupCap {
|
||||
return true, fmt.Sprintf("group token cap exhausted on %s (used %d of %d)", label, used, tl.GroupCap)
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// evalBudgetCap is the budget (USD) counterpart of evalTokenCap.
|
||||
func evalBudgetCap(
|
||||
cache consumptionCache,
|
||||
accountID, userID, attrGroup string,
|
||||
bl types.PolicyBudgetLimit,
|
||||
now time.Time,
|
||||
label string,
|
||||
) (bool, string) {
|
||||
windowStart := types.WindowStart(now, bl.WindowSeconds)
|
||||
|
||||
if bl.UserCapUsd > 0 && userID != "" {
|
||||
row := cache.get(accountID, types.DimensionUser, userID, bl.WindowSeconds, windowStart)
|
||||
if row.CostUSD >= bl.UserCapUsd {
|
||||
return true, fmt.Sprintf("user budget cap exhausted on %s (used $%.4f of $%.4f)", label, row.CostUSD, bl.UserCapUsd)
|
||||
}
|
||||
}
|
||||
|
||||
if bl.GroupCapUsd > 0 && attrGroup != "" {
|
||||
row := cache.get(accountID, types.DimensionGroup, attrGroup, bl.WindowSeconds, windowStart)
|
||||
if row.CostUSD >= bl.GroupCapUsd {
|
||||
return true, fmt.Sprintf("group budget cap exhausted on %s (used $%.4f of $%.4f)", label, row.CostUSD, bl.GroupCapUsd)
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// checkAccountBudget evaluates every applicable account-level budget rule as an
|
||||
// all-must-pass ceiling. A rule applies when the caller is in its TargetUsers,
|
||||
// one of its TargetGroups, or it has no targets at all (account-wide). Returns
|
||||
// deny=true with an llm_account.* code on the first exhausted rule. Group caps
|
||||
// attribute to the lowest intersecting group (the same model policies use), so
|
||||
// multi-group behavior is unchanged.
|
||||
func checkAccountBudget(in PolicySelectionInput, rules []*types.AccountBudgetRule, cache consumptionCache, now time.Time) (bool, string, string) {
|
||||
for _, r := range rules {
|
||||
if r == nil || !r.Enabled || !budgetRuleApplies(r, in) {
|
||||
continue
|
||||
}
|
||||
attrGroup := lowestIntersect(r.TargetGroups, in.GroupIDs)
|
||||
label := "account rule " + r.ID
|
||||
|
||||
if r.Limits.TokenLimit.Enabled && r.Limits.TokenLimit.WindowSeconds > 0 {
|
||||
if exhausted, reason := evalTokenCap(cache, in.AccountID, in.UserID, attrGroup, r.Limits.TokenLimit, now, label); exhausted {
|
||||
return true, denyCodeAccountTokenCapExceeded, reason
|
||||
}
|
||||
}
|
||||
|
||||
if r.Limits.BudgetLimit.Enabled && r.Limits.BudgetLimit.WindowSeconds > 0 {
|
||||
if exhausted, reason := evalBudgetCap(cache, in.AccountID, in.UserID, attrGroup, r.Limits.BudgetLimit, now, label); exhausted {
|
||||
return true, denyCodeAccountBudgetCapExceeded, reason
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
// budgetRuleApplies reports whether an account budget rule binds the caller:
|
||||
// a direct user match, a group intersection, or an untargeted (account-wide)
|
||||
// rule.
|
||||
func budgetRuleApplies(r *types.AccountBudgetRule, in PolicySelectionInput) bool {
|
||||
if len(r.TargetUsers) == 0 && len(r.TargetGroups) == 0 {
|
||||
return true
|
||||
}
|
||||
if in.UserID != "" && sliceContains(r.TargetUsers, in.UserID) {
|
||||
return true
|
||||
}
|
||||
groupSet := make(map[string]struct{}, len(in.GroupIDs))
|
||||
for _, g := range in.GroupIDs {
|
||||
if g != "" {
|
||||
groupSet[g] = struct{}{}
|
||||
}
|
||||
}
|
||||
return anyGroupMatches(r.TargetGroups, groupSet)
|
||||
}
|
||||
|
||||
// RecordAccountBudgetUsage fans the served request's usage out to every
|
||||
// applicable account budget rule's own (dimension, window) counter. The user
|
||||
// dimension is always booked when a rule has a user-applicable cap; the group
|
||||
// dimension books against the rule's lowest intersecting group. This runs
|
||||
// alongside the policy-window record so account ceilings accumulate in their own
|
||||
// windows (commonly monthly) independently of the per-policy window.
|
||||
func (m *managerImpl) RecordAccountBudgetUsage(ctx context.Context, accountID, userID string, groupIDs []string, tokensIn, tokensOut int64, costUSD float64) error {
|
||||
if accountID == "" {
|
||||
return status.Errorf(status.InvalidArgument, "account_id is required")
|
||||
}
|
||||
if err := validateUsageDeltas(tokensIn, tokensOut, costUSD); err != nil {
|
||||
return err
|
||||
}
|
||||
rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list account budget rules: %w", err)
|
||||
}
|
||||
set := make(map[types.ConsumptionKey]struct{})
|
||||
addAccountBudgetKeys(set, PolicySelectionInput{AccountID: accountID, UserID: userID, GroupIDs: groupIDs}, rules, time.Now().UTC())
|
||||
if len(set) == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.store.IncrementAgentNetworkConsumptionBatch(ctx, accountID, keysSlice(set), tokensIn, tokensOut, costUSD)
|
||||
}
|
||||
|
||||
// RecordUsageInput carries everything RecordUsage books for one served request.
|
||||
type RecordUsageInput struct {
|
||||
AccountID string
|
||||
UserID string
|
||||
AttributionGroupID string // selected policy's attribution group (policy window)
|
||||
GroupIDs []string
|
||||
WindowSeconds int64 // selected policy's window; 0 means no policy cap
|
||||
TokensIn int64
|
||||
TokensOut int64
|
||||
CostUSD float64
|
||||
}
|
||||
|
||||
// RecordUsage books a served request's usage against every counter it touches —
|
||||
// the selected policy's per-(user, group) window plus every applicable account
|
||||
// budget rule's own window — deduplicated and written in a single transaction.
|
||||
// Two counters that collapse to the same (dimension, window) tuple are booked
|
||||
// once, so a single request can never double-count against one cap.
|
||||
func (m *managerImpl) RecordUsage(ctx context.Context, in RecordUsageInput) error {
|
||||
if in.AccountID == "" {
|
||||
return status.Errorf(status.InvalidArgument, "account_id is required")
|
||||
}
|
||||
if err := validateUsageDeltas(in.TokensIn, in.TokensOut, in.CostUSD); err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
set := make(map[types.ConsumptionKey]struct{})
|
||||
|
||||
// Policy-window dimensions are booked only when a policy cap bound this
|
||||
// request (window > 0). A zero window means catch-all-allow / no policy cap;
|
||||
// the account fan-out below still books against the budget rules' windows.
|
||||
if in.WindowSeconds > 0 {
|
||||
addLimitKeys(set, in.UserID, in.AttributionGroupID, in.WindowSeconds, now)
|
||||
}
|
||||
|
||||
rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, in.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list account budget rules: %w", err)
|
||||
}
|
||||
addAccountBudgetKeys(set, PolicySelectionInput{AccountID: in.AccountID, UserID: in.UserID, GroupIDs: in.GroupIDs}, rules, now)
|
||||
|
||||
if len(set) == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.store.IncrementAgentNetworkConsumptionBatch(ctx, in.AccountID, keysSlice(set), in.TokensIn, in.TokensOut, in.CostUSD)
|
||||
}
|
||||
|
||||
// addAccountBudgetKeys adds the (dimension, window) keys a served request books
|
||||
// against every applicable account budget rule into the dedup set.
|
||||
func addAccountBudgetKeys(set map[types.ConsumptionKey]struct{}, in PolicySelectionInput, rules []*types.AccountBudgetRule, now time.Time) {
|
||||
for _, r := range rules {
|
||||
if r == nil || !r.Enabled || !budgetRuleApplies(r, in) {
|
||||
continue
|
||||
}
|
||||
attrGroup := lowestIntersect(r.TargetGroups, in.GroupIDs)
|
||||
for _, window := range ruleWindows(r) {
|
||||
addLimitKeys(set, in.UserID, attrGroup, window, now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// keysSlice flattens a ConsumptionKey set into a slice.
|
||||
func keysSlice(set map[types.ConsumptionKey]struct{}) []types.ConsumptionKey {
|
||||
keys := make([]types.ConsumptionKey, 0, len(set))
|
||||
for k := range set {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// ruleWindows returns the distinct enabled window lengths a budget rule books
|
||||
// against (token window and/or budget window, deduplicated).
|
||||
func ruleWindows(r *types.AccountBudgetRule) []int64 {
|
||||
var windows []int64
|
||||
if r.Limits.TokenLimit.Enabled && r.Limits.TokenLimit.WindowSeconds > 0 {
|
||||
windows = append(windows, r.Limits.TokenLimit.WindowSeconds)
|
||||
}
|
||||
if r.Limits.BudgetLimit.Enabled && r.Limits.BudgetLimit.WindowSeconds > 0 {
|
||||
bw := r.Limits.BudgetLimit.WindowSeconds
|
||||
if len(windows) == 0 || windows[0] != bw {
|
||||
windows = append(windows, bw)
|
||||
}
|
||||
}
|
||||
return windows
|
||||
}
|
||||
|
||||
// effectiveWindowSeconds returns the window length the proxy should
|
||||
// hand back to RecordLLMUsage. When both halves are enabled with
|
||||
// different windows, token_limit wins (the more common config); when
|
||||
// only one is enabled that one wins; when neither is enabled the
|
||||
// returned value is 0 — RecordLLMUsage treats 0 as "no limit
|
||||
// tracking" and skips the increment, which is the right pass-through
|
||||
// for catch-all-allow policies with no caps configured.
|
||||
func effectiveWindowSeconds(p *types.Policy) int64 {
|
||||
if p.Limits.TokenLimit.Enabled && p.Limits.TokenLimit.WindowSeconds > 0 {
|
||||
return p.Limits.TokenLimit.WindowSeconds
|
||||
}
|
||||
if p.Limits.BudgetLimit.Enabled && p.Limits.BudgetLimit.WindowSeconds > 0 {
|
||||
return p.Limits.BudgetLimit.WindowSeconds
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// lowestIntersect returns the lowest-by-string-sort element of
|
||||
// callerGroups ∩ sourceGroups. Empty when the intersection is empty.
|
||||
// Lowest is deterministic so multi-node selection converges.
|
||||
func lowestIntersect(sourceGroups, callerGroups []string) string {
|
||||
if len(sourceGroups) == 0 || len(callerGroups) == 0 {
|
||||
return ""
|
||||
}
|
||||
srcSet := make(map[string]struct{}, len(sourceGroups))
|
||||
for _, g := range sourceGroups {
|
||||
srcSet[g] = struct{}{}
|
||||
}
|
||||
var best string
|
||||
for _, g := range callerGroups {
|
||||
if _, ok := srcSet[g]; !ok {
|
||||
continue
|
||||
}
|
||||
if best == "" || g < best {
|
||||
best = g
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func anyGroupMatches(sourceGroups []string, callerSet map[string]struct{}) bool {
|
||||
for _, g := range sourceGroups {
|
||||
if _, ok := callerSet[g]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isUncapped reports whether a policy has any enabled cap with a
|
||||
// positive limit value. Mirrors the eval functions' guards: a policy
|
||||
// with token_limit.enabled=true but every cap value at 0 still
|
||||
// counts as uncapped because the eval would query nothing and bind
|
||||
// nothing.
|
||||
func isUncapped(p *types.Policy) bool {
|
||||
tl := p.Limits.TokenLimit
|
||||
if tl.Enabled && tl.WindowSeconds > 0 && (tl.GroupCap > 0 || tl.UserCap > 0) {
|
||||
return false
|
||||
}
|
||||
bl := p.Limits.BudgetLimit
|
||||
if bl.Enabled && bl.WindowSeconds > 0 && (bl.GroupCapUsd > 0 || bl.UserCapUsd > 0) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// groupCapTokens returns the policy's group-token cap when the token
|
||||
// limit is enabled, zero otherwise. Drives the primary "bigger pool
|
||||
// first" sort.
|
||||
func groupCapTokens(p *types.Policy) int64 {
|
||||
if p.Limits.TokenLimit.Enabled {
|
||||
return p.Limits.TokenLimit.GroupCap
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// groupCapBudgetUsd returns the policy's group-budget cap in USD
|
||||
// when the budget limit is enabled, zero otherwise. Secondary sort
|
||||
// key after token group cap so budget-only policies still order
|
||||
// predictably.
|
||||
func groupCapBudgetUsd(p *types.Policy) float64 {
|
||||
if p.Limits.BudgetLimit.Enabled {
|
||||
return p.Limits.BudgetLimit.GroupCapUsd
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// userCapTokens returns the policy's per-user token cap when the
|
||||
// token limit is enabled, zero otherwise. Tertiary sort key, used
|
||||
// when group caps tie or are absent.
|
||||
func userCapTokens(p *types.Policy) int64 {
|
||||
if p.Limits.TokenLimit.Enabled {
|
||||
return p.Limits.TokenLimit.UserCap
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// userCapBudgetUsd returns the policy's per-user budget cap in USD
|
||||
// when the budget limit is enabled, zero otherwise. Quaternary sort
|
||||
// key for budget-only policies whose group caps tie or are absent.
|
||||
func userCapBudgetUsd(p *types.Policy) float64 {
|
||||
if p.Limits.BudgetLimit.Enabled {
|
||||
return p.Limits.BudgetLimit.UserCapUsd
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func sliceContains(haystack []string, needle string) bool {
|
||||
for _, v := range haystack {
|
||||
if v == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// mockManager fallback so tests that don't care about selection still
|
||||
// compile.
|
||||
func (*mockManager) SelectPolicyForRequest(_ context.Context, _ PolicySelectionInput) (*PolicySelectionResult, error) {
|
||||
return &PolicySelectionResult{Allow: true}, nil
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// GC-2 no-mock enforcement tests for the account-budget ceiling. They drive the
|
||||
// real store + real consumption accounting through SelectPolicyForRequest and
|
||||
// RecordAccountBudgetUsage, asserting min-wins (account binds independently of
|
||||
// policy), targeting (groups + direct users), and the record fan-out.
|
||||
|
||||
func accountWideUserTokenRule(id string, userCap, window int64) *types.AccountBudgetRule {
|
||||
r := types.NewAccountBudgetRule(realSelectAccount)
|
||||
r.ID = id
|
||||
r.Limits.TokenLimit = types.PolicyTokenLimit{Enabled: true, UserCap: userCap, WindowSeconds: window}
|
||||
return r
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AccountCeilingBindsEvenWithUncappedPolicy proves
|
||||
// min-wins: the account user ceiling denies once exhausted even though a
|
||||
// catch-all-allow (uncapped) policy would otherwise pass the request. The
|
||||
// account gate runs independently of and ahead of policy selection.
|
||||
func TestSelectPolicy_RealStore_AccountCeilingBindsEvenWithUncappedPolicy(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// An uncapped (catch-all-allow) policy: enabled token limit, zero caps.
|
||||
uncapped := capPolicy("pol-open", realSelectAccount, []string{"grp-eng"}, "prov-1", 0, 86_400)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, uncapped))
|
||||
|
||||
// Account-wide user ceiling of 100 tokens in an hourly window.
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-1", 100, 3_600)))
|
||||
|
||||
in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"}
|
||||
|
||||
// Fresh: account ceiling has headroom, uncapped policy wins.
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "fresh account ceiling must allow")
|
||||
|
||||
// Drain the account user ceiling via the fan-out path.
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", []string{"grp-eng"}, 100, 0, 0))
|
||||
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "account ceiling must deny even though the policy is uncapped (min-wins)")
|
||||
assert.Equal(t, denyCodeAccountTokenCapExceeded, res.DenyCode, "deny must carry the llm_account.* code")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AccountGroupCeiling proves a group-targeted rule
|
||||
// binds the caller's group dimension.
|
||||
func TestSelectPolicy_RealStore_AccountGroupCeiling(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
rule := types.NewAccountBudgetRule(realSelectAccount)
|
||||
rule.ID = "ainbud-grp"
|
||||
rule.TargetGroups = []string{"grp-eng"}
|
||||
rule.Limits.BudgetLimit = types.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 5.0, WindowSeconds: 2_592_000}
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, rule))
|
||||
|
||||
in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"}
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "fresh group ceiling must allow")
|
||||
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", []string{"grp-eng"}, 0, 0, 5.0))
|
||||
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "group budget ceiling must deny once spent")
|
||||
assert.Equal(t, denyCodeAccountBudgetCapExceeded, res.DenyCode, "account budget deny code")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AccountTargetUsersBindsOnlyThatUser proves a
|
||||
// TargetUsers rule tightens only the named user, leaving others unbound.
|
||||
func TestSelectPolicy_RealStore_AccountTargetUsersBindsOnlyThatUser(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
rule := types.NewAccountBudgetRule(realSelectAccount)
|
||||
rule.ID = "ainbud-alice"
|
||||
rule.TargetUsers = []string{"alice"}
|
||||
rule.Limits.TokenLimit = types.PolicyTokenLimit{Enabled: true, UserCap: 100, WindowSeconds: 3_600}
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, rule))
|
||||
|
||||
// Record alice's usage to the rule window.
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "alice", nil, 100, 0, 0))
|
||||
|
||||
aliceIn := PolicySelectionInput{AccountID: realSelectAccount, UserID: "alice", ProviderID: "prov-1"}
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, aliceIn)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "alice is bound by the TargetUsers rule and is exhausted")
|
||||
|
||||
bobIn := PolicySelectionInput{AccountID: realSelectAccount, UserID: "bob", ProviderID: "prov-1"}
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, bobIn)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "bob is not in TargetUsers, so the rule must not bind him")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AccountRuleRecordsToOwnWindow proves the record
|
||||
// fan-out books usage in the rule's own window (distinct from any policy
|
||||
// window), so the account ceiling accumulates independently.
|
||||
func TestSelectPolicy_RealStore_AccountRuleRecordsToOwnWindow(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-w", 100, 3_600)))
|
||||
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", nil, 60, 0, 0))
|
||||
|
||||
// Same user, a policy-style daily window must NOT see the account-window
|
||||
// usage — windows are independent counters.
|
||||
dailyRow, err := s.GetAgentNetworkConsumption(ctx, store.LockingStrengthNone, realSelectAccount, types.DimensionUser, "user-1", 86_400, types.WindowStart(time.Now().UTC(), 86_400))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), dailyRow.TokensInput+dailyRow.TokensOutput, "daily window must be untouched by the hourly account-rule record")
|
||||
|
||||
// A second record pushes the hourly account window to its cap → deny.
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", nil, 40, 0, 0))
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", ProviderID: "prov-1"})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "100 tokens recorded in the rule's hourly window must exhaust the 100-token ceiling")
|
||||
assert.Equal(t, denyCodeAccountTokenCapExceeded, res.DenyCode, "account token deny code")
|
||||
}
|
||||
|
||||
// TestRecordUsage_RealStore_BooksPolicyAndAccountWindows proves the batched
|
||||
// post-flight write books the selected policy's window AND every applicable
|
||||
// account rule's (independent) window in a single call — the #6 batched-write
|
||||
// path the proxy's RecordLLMUsage RPC now uses.
|
||||
func TestRecordUsage_RealStore_BooksPolicyAndAccountWindows(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Policy: 100-token group cap on a daily window. Account rule: 100-token
|
||||
// user ceiling on an hourly window — an independent counter.
|
||||
policy := capPolicy("pol-1", realSelectAccount, []string{"grp-eng"}, "prov-1", 100, 86_400)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, policy))
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-1", 100, 3_600)))
|
||||
|
||||
in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"}
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
require.True(t, res.Allow)
|
||||
require.Equal(t, "pol-1", res.SelectedPolicyID)
|
||||
|
||||
// One batched record books the policy window (group + user @86400) and the
|
||||
// account rule window (user @3600) atomically.
|
||||
require.NoError(t, mgr.RecordUsage(ctx, RecordUsageInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
AttributionGroupID: res.AttributionGroupID,
|
||||
GroupIDs: []string{"grp-eng"},
|
||||
WindowSeconds: res.WindowSeconds,
|
||||
TokensIn: 100,
|
||||
}))
|
||||
|
||||
// The next selection denies — the account hourly ceiling binds first.
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "usage booked by RecordUsage must enforce on the next request")
|
||||
|
||||
// Prove BOTH windows were booked in the one call via a direct batch read.
|
||||
now := time.Now().UTC()
|
||||
userKey := types.ConsumptionKey{Kind: types.DimensionUser, DimID: "user-1", WindowSeconds: 3_600, WindowStartUTC: types.WindowStart(now, 3_600)}
|
||||
groupKey := types.ConsumptionKey{Kind: types.DimensionGroup, DimID: "grp-eng", WindowSeconds: 86_400, WindowStartUTC: types.WindowStart(now, 86_400)}
|
||||
rows, err := s.GetAgentNetworkConsumptionBatch(ctx, store.LockingStrengthNone, realSelectAccount, []types.ConsumptionKey{userKey, groupKey})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, rows, userKey, "account rule user/hourly window booked")
|
||||
require.Contains(t, rows, groupKey, "policy group/daily window booked")
|
||||
assert.Equal(t, int64(100), rows[userKey].TokensInput, "account hourly user counter")
|
||||
assert.Equal(t, int64(100), rows[groupKey].TokensInput, "policy daily group counter")
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// This file is the no-mock regression guard for policy limit enforcement.
|
||||
// policyselect_test.go pins the same behavior through a gomock store with
|
||||
// explicit call-sequence expectations — brittle precisely where the upcoming
|
||||
// account-budget work (GC-2) refactors the cap-eval primitive and adds an
|
||||
// account-level gate. These tests drive the REAL sqlite store + REAL
|
||||
// consumption accounting and assert observable behavior (allow / deny /
|
||||
// selection / attribution), not which store methods get called. They must keep
|
||||
// passing unchanged after GC-2 lands, which is what proves "current behavior is
|
||||
// not changed."
|
||||
|
||||
const realSelectAccount = "acc-realselect-1"
|
||||
|
||||
// newRealSelectorMgr builds a managerImpl backed by a real sqlite test store.
|
||||
func newRealSelectorMgr(t *testing.T) (*managerImpl, store.Store) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
t.Cleanup(cleanup)
|
||||
return &managerImpl{store: s}, s
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_NoApplicablePolicies pins the pass-through:
|
||||
// nothing targets the (provider, groups) combination, so the selector allows
|
||||
// without attribution or consumption tracking.
|
||||
func TestSelectPolicy_RealStore_NoApplicablePolicies(t *testing.T) {
|
||||
mgr, _ := newRealSelectorMgr(t)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-x"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "no applicable policy must pass through as allow")
|
||||
assert.Empty(t, res.SelectedPolicyID, "no selection when nothing applies")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AllowAndLowestGroupAttribution pins the v1
|
||||
// attribution rule (lowest intersecting group by string sort) through the
|
||||
// real store, with a fresh (zero) consumption row.
|
||||
func TestSelectPolicy_RealStore_AllowAndLowestGroupAttribution(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
p := capPolicy("pol-A", realSelectAccount, []string{"grp-zz", "grp-aa", "grp-mm"}, "prov-1", 10_000, 86_400)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p))
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-zz", "grp-aa", "grp-mm"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "fresh state under cap must allow")
|
||||
assert.Equal(t, "pol-A", res.SelectedPolicyID, "only applicable policy must be selected")
|
||||
assert.Equal(t, "grp-aa", res.AttributionGroupID, "lowest-by-sort intersecting group must win")
|
||||
assert.Equal(t, int64(86_400), res.WindowSeconds, "selected policy's window must be returned")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_LargerPoolWins_FallsThroughWhenExhausted pins the
|
||||
// core selection behavior end to end. The two policies bind DISTINCT groups so
|
||||
// they read separate counters — the only shape where fall-through actually
|
||||
// yields headroom (policies on the same group share one counter, as
|
||||
// policyselect_test.go notes). Larger pool wins fresh; after real consumption
|
||||
// drains the larger group, selection falls through to the smaller; once both
|
||||
// counters are exhausted the request is denied.
|
||||
func TestSelectPolicy_RealStore_LargerPoolWins_FallsThroughWhenExhausted(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
tight := capPolicy("pol-tight", realSelectAccount, []string{"grp-tight"}, "prov-1", 100, 86_400)
|
||||
tight.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
wide := capPolicy("pol-wide", realSelectAccount, []string{"grp-wide"}, "prov-1", 10_000, 86_400)
|
||||
wide.CreatedAt = time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, tight))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, wide))
|
||||
|
||||
// Caller is in both groups, so both policies apply with independent counters.
|
||||
in := PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-tight", "grp-wide"},
|
||||
ProviderID: "prov-1",
|
||||
}
|
||||
|
||||
// Fresh: larger pool wins.
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-wide", res.SelectedPolicyID, "larger pool drains first")
|
||||
|
||||
// Drain only the wide group's counter to its cap.
|
||||
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-wide", 86_400, 10_000, 0, 0))
|
||||
|
||||
// Wide exhausted, tight's separate counter is fresh → fall through to tight.
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "tight pool has its own untouched counter")
|
||||
assert.Equal(t, "pol-tight", res.SelectedPolicyID, "selection falls through to the smaller pool once the larger is exhausted")
|
||||
|
||||
// Drain the tight group's counter too → both exhausted → deny.
|
||||
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-tight", 86_400, 100, 0, 0))
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "both group counters exhausted must deny")
|
||||
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode, "deny code names the offending cap kind")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_BudgetCapDenies pins budget (USD) enforcement
|
||||
// through the real store: once recorded cost reaches the cap, deny.
|
||||
func TestSelectPolicy_RealStore_BudgetCapDenies(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
p := &types.Policy{
|
||||
ID: "pol-budget",
|
||||
AccountID: realSelectAccount,
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-eng"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
BudgetLimit: types.PolicyBudgetLimit{
|
||||
Enabled: true,
|
||||
GroupCapUsd: 5.0,
|
||||
WindowSeconds: 86_400,
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p))
|
||||
|
||||
in := PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-eng"},
|
||||
ProviderID: "prov-1",
|
||||
}
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "fresh budget must allow")
|
||||
|
||||
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-eng", 86_400, 0, 0, 5.0))
|
||||
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "cost at the cap must deny")
|
||||
assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode, "budget deny code must be surfaced")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_GroupCounterSharedAcrossPolicies pins that two
|
||||
// policies on the same group+window read one shared consumption counter: usage
|
||||
// recorded once is visible to both, so exhausting the group budget denies
|
||||
// regardless of which policy would attribute.
|
||||
func TestSelectPolicy_RealStore_GroupCounterSharedAcrossPolicies(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
a := capPolicy("pol-a", realSelectAccount, []string{"grp-eng"}, "prov-1", 1_000, 86_400)
|
||||
b := capPolicy("pol-b", realSelectAccount, []string{"grp-eng"}, "prov-1", 1_000, 86_400)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, a))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, b))
|
||||
|
||||
in := PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-eng"},
|
||||
ProviderID: "prov-1",
|
||||
}
|
||||
|
||||
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-eng", 86_400, 1_000, 0, 0))
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "shared group counter at cap denies both equal policies")
|
||||
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode, "token deny code on the shared counter")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_DisabledPolicyIgnored pins that a disabled policy
|
||||
// is invisible to selection even when it otherwise matches.
|
||||
func TestSelectPolicy_RealStore_DisabledPolicyIgnored(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
p := capPolicy("pol-disabled", realSelectAccount, []string{"grp-eng"}, "prov-1", 10_000, 86_400)
|
||||
p.Enabled = false
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p))
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-eng"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "no enabled policy applies → pass-through allow")
|
||||
assert.Empty(t, res.SelectedPolicyID, "disabled policy must not be selected")
|
||||
}
|
||||
641
management/internals/modules/agentnetwork/policyselect_test.go
Normal file
641
management/internals/modules/agentnetwork/policyselect_test.go
Normal file
@@ -0,0 +1,641 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
nbstatus "github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func newSelectorMgr(t *testing.T, ctrl *gomock.Controller) (*managerImpl, *store.MockStore) {
|
||||
t.Helper()
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
// SelectPolicyForRequest evaluates the account-budget ceiling before policy
|
||||
// selection. These policy-selection tests don't exercise account rules, so
|
||||
// default to "no rules" — the no-mock policyselect_realstore_test.go covers
|
||||
// the account gate's behavior end to end.
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkBudgetRules(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nil, nil).
|
||||
AnyTimes()
|
||||
return &managerImpl{store: mockStore}, mockStore
|
||||
}
|
||||
|
||||
type usedKey struct {
|
||||
kind types.ConsumptionDimension
|
||||
dimID string
|
||||
window int64
|
||||
}
|
||||
|
||||
// expectConsumptionBatch stubs the batched consumption read to return the
|
||||
// supplied per-(kind, dim, window) counters, filling each row's window start
|
||||
// from the actual request keys so it always matches what the selector computed.
|
||||
// Keys absent from used resolve to zero counters.
|
||||
func expectConsumptionBatch(mockStore *store.MockStore, used map[usedKey]*types.Consumption) {
|
||||
mockStore.EXPECT().
|
||||
GetAgentNetworkConsumptionBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, _ store.LockingStrength, _ string, keys []types.ConsumptionKey) (map[types.ConsumptionKey]*types.Consumption, error) {
|
||||
out := make(map[types.ConsumptionKey]*types.Consumption)
|
||||
for _, k := range keys {
|
||||
if row, ok := used[usedKey{k.Kind, k.DimID, k.WindowSeconds}]; ok {
|
||||
rc := *row
|
||||
rc.WindowStartUTC = k.WindowStartUTC
|
||||
out[k] = &rc
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}).
|
||||
AnyTimes()
|
||||
}
|
||||
|
||||
func capPolicy(id, account string, sourceGroups []string, providerID string, tokenCap int64, windowSec int64) *types.Policy {
|
||||
return &types.Policy{
|
||||
ID: id,
|
||||
AccountID: account,
|
||||
Enabled: true,
|
||||
SourceGroups: sourceGroups,
|
||||
DestinationProviderIDs: []string{providerID},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{
|
||||
Enabled: true,
|
||||
GroupCap: tokenCap,
|
||||
WindowSeconds: windowSec,
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectPolicy_NoApplicablePolicies covers the pass-through path:
|
||||
// llm_router authorisation is upstream of selection; when the
|
||||
// selector finds no policy targeting the (provider, caller-groups)
|
||||
// combination, it returns Allow with no attribution and lets the
|
||||
// request continue without consumption tracking.
|
||||
func TestSelectPolicy_NoApplicablePolicies(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{}, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-x"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "no applicable policies = pass-through allow")
|
||||
assert.Empty(t, res.SelectedPolicyID, "no selection when nothing applies")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_AllowWithLowestGroupAttribution proves the v1
|
||||
// attribution rule: when the caller's groups intersect a policy's
|
||||
// source_groups in multiple positions, the selector picks the lowest
|
||||
// group id by string sort so multi-node selection converges.
|
||||
func TestSelectPolicy_AllowWithLowestGroupAttribution(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policy := capPolicy("pol-A", "acc-1", []string{"grp-zz", "grp-aa", "grp-mm"}, "prov-1", 10_000, 86_400)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policy}, nil)
|
||||
// Fresh: zero consumption across the board.
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-zz", "grp-aa", "grp-mm"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow)
|
||||
assert.Equal(t, "pol-A", res.SelectedPolicyID)
|
||||
assert.Equal(t, "grp-aa", res.AttributionGroupID,
|
||||
"lowest-by-sort intersection wins so multi-node selection converges")
|
||||
assert.Equal(t, int64(86_400), res.WindowSeconds)
|
||||
}
|
||||
|
||||
// TestSelectPolicy_LargerPoolWinsAcrossUsageLevels proves the core
|
||||
// selection rule: among multiple applicable policies with caps, the
|
||||
// selector picks the one with the larger absolute pool — at every
|
||||
// usage level, not just at fresh state. The smaller-pool policy is
|
||||
// only reached when the larger one is exhausted. This is the
|
||||
// "drain biggest first" semantic operators expect for layered
|
||||
// tiers; a fraction-based score would flap between the two as
|
||||
// soon as one is partially used.
|
||||
func TestSelectPolicy_LargerPoolWinsAcrossUsageLevels(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
tight := capPolicy("pol-tight", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400)
|
||||
tight.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
wide := capPolicy("pol-wide", "acc-1", []string{"grp-engineers"}, "prov-1", 10_000, 86_400)
|
||||
wide.CreatedAt = time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{tight, wide}, nil)
|
||||
|
||||
// Both partially used. tight at 50/100 (50% used); wide at
|
||||
// 50/10000 (0.5% used). Old fraction-based algo would pick wide
|
||||
// here too — but for the wrong reason ("more relative slack").
|
||||
// New algo picks wide because its initial group cap is bigger
|
||||
// (10000 > 100), and that decision is stable as wide drains.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 50},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-wide", res.SelectedPolicyID,
|
||||
"the policy with the bigger initial pool wins — operators expect 'drain the privileged tier first', not load-balance across tiers")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_StaysOnLargerPoolAfterPartialDrain locks the
|
||||
// stickiness contract reported by operators: with two policies
|
||||
// where A has a 200-token group cap and B has 150, the very first
|
||||
// request goes to A AND every subsequent request continues to land
|
||||
// on A until A's group cap is exhausted — at which point B becomes
|
||||
// the only candidate. A fraction-based score would flap to B as
|
||||
// soon as A had any consumption (B's 1.0 fraction beats A's 0.75)
|
||||
// even though A still has more absolute headroom; that produced
|
||||
// confusing per-policy attribution ledger entries and stranded
|
||||
// A's remaining capacity behind B's exhaustion.
|
||||
func TestSelectPolicy_StaysOnLargerPoolAfterPartialDrain(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policyA := capPolicy("pol-A-200", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400)
|
||||
policyB := capPolicy("pol-B-150", "acc-1", []string{"grp-engineers"}, "prov-1", 150, 86_400)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policyA, policyB}, nil)
|
||||
|
||||
// A is partially drained (50/200 used = 25% used; 75% headroom
|
||||
// remaining). B is fresh (0/150). The old fraction-based score
|
||||
// would pick B here (1.0 > 0.75 fraction); the new pool-size
|
||||
// score sticks with A (200 > 150 absolute cap).
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 50},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-A-200", res.SelectedPolicyID,
|
||||
"once attribution lands on the bigger pool it must STAY there until exhausted — operators expect 'drain A then B', not 'flip to B as soon as A is touched'")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_FallsThroughToSmallerPoolWhenLargerExhausted
|
||||
// proves the second half of the stickiness contract: once the
|
||||
// larger-pool policy IS exhausted, the smaller one takes over.
|
||||
// Without this we'd deny on requests the smaller policy is fully
|
||||
// equipped to serve.
|
||||
func TestSelectPolicy_FallsThroughToSmallerPoolWhenLargerExhausted(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policyA := capPolicy("pol-A-200", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400)
|
||||
// B uses a different window length so it has an INDEPENDENT counter — the
|
||||
// realistic shape for fall-through. On the SAME (group, window) tuple the
|
||||
// counter is shared, so A's cap of 200 being reached would also exhaust B's
|
||||
// 150; independent counters are what let A exhaust while B retains headroom.
|
||||
policyB := capPolicy("pol-B-150", "acc-1", []string{"grp-engineers"}, "prov-1", 150, 3_600)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policyA, policyB}, nil)
|
||||
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 200}, // A: 200 >= 200 → exhausted
|
||||
{types.DimensionGroup, "grp-engineers", 3_600}: {TokensInput: 100}, // B: 100 < 150 → headroom
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-B-150", res.SelectedPolicyID,
|
||||
"once the bigger pool is exhausted, the smaller one must take over — denying when capacity remains would strand B's allowance")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_TiebreakByLargerGroupPool covers the user-reported
|
||||
// bug: an admin in two groups (Users + Admins) where Users is bound
|
||||
// by a smaller-group-cap policy (50 group, 100 user) and Admins is
|
||||
// bound by a bigger-group-cap policy (100 group, 20 user) MUST get
|
||||
// attributed to the Admins policy on the first request.
|
||||
//
|
||||
// Without this rule, the fresh-state fraction is 1.0 for both and
|
||||
// the older policy wins by created_at. The first 24-token request
|
||||
// then drains the shared user counter past Admins's tight 20-token
|
||||
// user cap, locking Admins out of selection forever. The 100-token
|
||||
// Admins group pool ends up stranded while requests pile onto the
|
||||
// 50-token Users pool — the opposite of what the operator intended
|
||||
// when they put the bigger pool on the privileged group.
|
||||
func TestSelectPolicy_TiebreakByLargerGroupPool(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
// Policy A: Users group, smaller group pool, looser per-user cap.
|
||||
policyA := &types.Policy{
|
||||
ID: "pol-Users",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-Users"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{
|
||||
Enabled: true, GroupCap: 50, UserCap: 100, WindowSeconds: 86_400,
|
||||
},
|
||||
},
|
||||
// Older — would win the legacy created_at tiebreak.
|
||||
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
// Policy B: Admins group, bigger group pool, tighter per-user cap.
|
||||
policyB := &types.Policy{
|
||||
ID: "pol-Admins",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-Admins"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{
|
||||
Enabled: true, GroupCap: 100, UserCap: 20, WindowSeconds: 86_400,
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policyA, policyB}, nil)
|
||||
// Fresh state: every cap evaluation reads zero usage.
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-Users", "grp-Admins"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-Admins", res.SelectedPolicyID,
|
||||
"the bigger group pool wins the fresh-state tiebreak — picking Users first would burn the shared user counter past Admins's tight user cap on the very first request and strand the bigger Admins pool")
|
||||
assert.Equal(t, "grp-Admins", res.AttributionGroupID)
|
||||
}
|
||||
|
||||
// TestSelectPolicy_TiebreakByCreatedAt proves the deterministic
|
||||
// final tiebreak: when two applicable policies have the same
|
||||
// headroom fraction AND the same group cap (so the larger-pool rule
|
||||
// can't differentiate either), the older policy wins so attribution
|
||||
// is stable across replays.
|
||||
func TestSelectPolicy_TiebreakByCreatedAt(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
older := capPolicy("pol-old", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400)
|
||||
older.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
newer := capPolicy("pol-new", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400)
|
||||
newer.CreatedAt = time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{newer, older}, nil)
|
||||
// Both at zero consumption → identical headroom fraction.
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-old", res.SelectedPolicyID,
|
||||
"older policy wins on equal-headroom tiebreak so attribution is stable across replays")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_DeniesWhenAllExhausted proves the deny envelope:
|
||||
// when every applicable policy has at least one cap fully exhausted,
|
||||
// the selector returns Allow=false with the most-recent exhaustion's
|
||||
// deny code + human reason. The proxy's middleware surfaces this as
|
||||
// a 403 with the canonical llm_policy.* code.
|
||||
func TestSelectPolicy_DeniesWhenAllExhausted(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
a := capPolicy("pol-a", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400)
|
||||
b := capPolicy("pol-b", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{a, b}, nil)
|
||||
|
||||
// Shared group counter at 200: A (cap 100) and B (cap 200) both exhausted.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 200},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "every applicable policy exhausted = deny")
|
||||
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode)
|
||||
assert.Contains(t, res.DenyReason, "token cap exhausted",
|
||||
"deny reason must name the exhausted cap kind for operator debugging")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_UncappedPolicyAlwaysWinsAgainstCapped proves the
|
||||
// catch-all-allow contract: a policy with NO enabled caps wins
|
||||
// against any capped policy regardless of how much headroom the
|
||||
// capped one has, because operators who configure unlimited access
|
||||
// expect requests to attribute there until they explicitly add caps.
|
||||
func TestSelectPolicy_UncappedPolicyAlwaysWinsAgainstCapped(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
uncapped := &types.Policy{
|
||||
ID: "pol-uncapped",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-engineers"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
// All Limits.*.Enabled = false (zero-value).
|
||||
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
wide := capPolicy("pol-wide", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000_000, 86_400)
|
||||
wide.CreatedAt = time.Date(2025, 12, 1, 0, 0, 0, 0, time.UTC) // older than uncapped
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{uncapped, wide}, nil)
|
||||
// Only the wide policy reads consumption; uncapped doesn't query
|
||||
// because it has no enabled caps.
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-uncapped", res.SelectedPolicyID,
|
||||
"a no-caps policy must always win selection — that's how operators express 'unlimited access through this path'")
|
||||
assert.Equal(t, int64(0), res.WindowSeconds, "no caps configured = WindowSeconds=0 so RecordLLMUsage skips counter writes")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_DisabledPolicyIgnored proves disabled policies
|
||||
// don't count toward selection — even when they'd otherwise be the
|
||||
// best match. Operators disable a policy to take it offline; the
|
||||
// selector must respect that and route through whatever's left.
|
||||
func TestSelectPolicy_DisabledPolicyIgnored(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
disabled := capPolicy("pol-disabled", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000_000, 86_400)
|
||||
disabled.Enabled = false
|
||||
enabled := capPolicy("pol-enabled", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{disabled, enabled}, nil)
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-enabled", res.SelectedPolicyID,
|
||||
"disabled policies must be ignored at selection time")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_StoreErrorPropagates locks the no-fail-open
|
||||
// contract: a transient store error must surface to the caller, not
|
||||
// be silently treated as "no policies = allow". A false allow on the
|
||||
// hot path would let a request slip past every cap.
|
||||
func TestSelectPolicy_StoreErrorPropagates(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return(nil, errors.New("boom"))
|
||||
|
||||
_, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
})
|
||||
require.Error(t, err, "store errors must surface — never fail open on the hot path")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RejectsEmptyAccount is the input-validation guard:
|
||||
// empty account_id is a programmer error and must surface as
|
||||
// InvalidArgument, not as a silent zero-result lookup.
|
||||
func TestSelectPolicy_RejectsEmptyAccount(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, _ := newSelectorMgr(t, ctrl)
|
||||
|
||||
_, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{})
|
||||
require.Error(t, err)
|
||||
var sErr *nbstatus.Error
|
||||
require.True(t, errors.As(err, &sErr))
|
||||
assert.Equal(t, nbstatus.InvalidArgument, sErr.Type())
|
||||
}
|
||||
|
||||
// TestSelectPolicy_SharesGroupCounterAcrossPolicies locks the
|
||||
// counter-keying design fork: counters are keyed on (account,
|
||||
// dim_kind, dim_id, window_hours, window_start) — NOT on policy_id.
|
||||
// Two policies that target the same group with the SAME window length
|
||||
// share one bucket: spend booked under policy A is visible to policy
|
||||
// B's headroom calculation and counts toward B's cap.
|
||||
//
|
||||
// This is what makes "operator's per-group enforcement" sane — caps
|
||||
// describe how much a GROUP can use, not how much each policy owes.
|
||||
func TestSelectPolicy_SharesGroupCounterAcrossPolicies(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
// Two policies, both targeting grp-engineers + prov-1, same 24h
|
||||
// window length. Different cap sizes.
|
||||
policyA := capPolicy("pol-A", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400)
|
||||
policyB := capPolicy("pol-B", "acc-1", []string{"grp-engineers"}, "prov-1", 5_000, 86_400)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policyA, policyB}, nil)
|
||||
// Both policies query the SAME consumption row — same dim_id,
|
||||
// same window_hours, same window_start. The mock returns the
|
||||
// same row for both calls, simulating the shared counter.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 800},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// 800 used → policy A has 200 tokens left of 1000 (20% headroom);
|
||||
// policy B has 4200 left of 5000 (84% headroom). B wins.
|
||||
assert.Equal(t, "pol-B", res.SelectedPolicyID,
|
||||
"the SAME 800 tokens count toward both policies — counters share the (group, window) key, caps differ per policy")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_AntiFallThroughOnLowestGroup locks the no-fall-
|
||||
// through behaviour: when a caller is in multiple of a policy's
|
||||
// source_groups and the lowest-by-sort group is exhausted, we DENY
|
||||
// rather than fall through to a less-loaded sibling. Per-group caps
|
||||
// are independent (each group has its own bucket), but attribution
|
||||
// is one-shot — operators wanting fall-through must split into
|
||||
// separate policies.
|
||||
//
|
||||
// This nails down semantics future contributors might "improve" into
|
||||
// fall-through behaviour by accident.
|
||||
func TestSelectPolicy_AntiFallThroughOnLowestGroup(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
// Policy targets two groups; caller is in both.
|
||||
policy := capPolicy("pol-1", "acc-1", []string{"grp-aaa", "grp-bbb"}, "prov-1", 100, 86_400)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policy}, nil)
|
||||
|
||||
// grp-aaa is the lowest by sort → attribution picks it, and the
|
||||
// prefetch only collects the attribution group's key. We exhaust
|
||||
// grp-aaa (100/100); grp-bbb's counter is never requested because the
|
||||
// selector attributes one-shot to the lowest group, so it can't fall
|
||||
// through to a less-loaded sibling.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-aaa", 86_400}: {TokensInput: 100},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-aaa", "grp-bbb"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow,
|
||||
"lowest-group-by-sort attribution does NOT fall through to a less-loaded sibling — operators wanting fall-through must split into separate policies")
|
||||
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode)
|
||||
assert.Contains(t, res.DenyReason, "pol-1",
|
||||
"deny reason names the exhausted policy id so operators can grep it from the access log")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_BudgetOnlyExhaustionDenies covers the symmetric
|
||||
// path to TestSelectPolicy_DeniesWhenAllExhausted but for the budget
|
||||
// cap: a policy with token_limit DISABLED and budget_limit at-cap
|
||||
// must deny with llm_policy.budget_cap_exceeded (not the token code).
|
||||
//
|
||||
// Without this, the budget evaluation path in evalBudgetCap could
|
||||
// silently regress and we'd still pass DeniesWhenAllExhausted (which
|
||||
// only exercises tokens).
|
||||
func TestSelectPolicy_BudgetOnlyExhaustionDenies(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policy := &types.Policy{
|
||||
ID: "pol-budget",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-engineers"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{Enabled: false},
|
||||
BudgetLimit: types.PolicyBudgetLimit{
|
||||
Enabled: true,
|
||||
GroupCapUsd: 10.00,
|
||||
WindowSeconds: 86_400,
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policy}, nil)
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {CostUSD: 10.50}, // over the $10 cap
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "budget cap exhausted must deny independently of any token cap state")
|
||||
assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode,
|
||||
"deny code must be the budget code — token-only deny would silently regress the budget evaluation path")
|
||||
assert.Contains(t, res.DenyReason, "budget", "deny reason names the budget cap kind for operator debugging")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_BudgetTighterThanTokenWins is the dual-cap headroom
|
||||
// fork: when both Token and Budget are enabled on the same policy,
|
||||
// the SMALLER remaining ratio gates the policy. A policy with
|
||||
// abundant token headroom but near-zero budget headroom must deny on
|
||||
// budget, not pass on tokens.
|
||||
func TestSelectPolicy_BudgetTighterThanTokenWins(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policy := &types.Policy{
|
||||
ID: "pol-dual",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-engineers"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{Enabled: true, GroupCap: 10_000_000, WindowSeconds: 86_400},
|
||||
BudgetLimit: types.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 1.00, WindowSeconds: 86_400},
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policy}, nil)
|
||||
// One shared counter carries both token usage (ample headroom) and cost
|
||||
// (at the $1 budget cap); the tighter budget cap gates the policy.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 100, CostUSD: 1.00},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow,
|
||||
"the tighter of (token, budget) wins — abundant token headroom must NOT mask an exhausted budget")
|
||||
assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode)
|
||||
}
|
||||
131
management/internals/modules/agentnetwork/reconcile.go
Normal file
131
management/internals/modules/agentnetwork/reconcile.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// reconcile recomputes the synthesised reverse-proxy services for an
|
||||
// account, diffs them against the previously-synthesised set in the
|
||||
// in-memory cache, and emits Create / Update / Delete proxy mappings
|
||||
// to the affected clusters. Also triggers a peer-side network-map
|
||||
// recompute via accountManager.UpdateAccountPeers so the
|
||||
// private-service ACL injection picks up the new state immediately.
|
||||
//
|
||||
// Reconcile failures are logged and swallowed — the underlying CRUD
|
||||
// has already completed, and the next mutation (or proxy reconnect)
|
||||
// will re-converge the cluster's view.
|
||||
func (m *managerImpl) reconcile(ctx context.Context, accountID string) {
|
||||
if accountID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if m.accountManager != nil {
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{
|
||||
Resource: types.UpdateResourceService,
|
||||
Operation: types.UpdateOperationUpdate,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
if m.proxyController == nil {
|
||||
return
|
||||
}
|
||||
|
||||
services, err := SynthesizeServices(ctx, m.store, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).WithError(err).Warnf("agent-network reconcile: synthesise services for account %s", accountID)
|
||||
return
|
||||
}
|
||||
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
current := make(map[string]*proto.ProxyMapping, len(services))
|
||||
for _, svc := range services {
|
||||
if svc == nil || svc.ID == "" {
|
||||
continue
|
||||
}
|
||||
current[svc.ID] = svc.ToProtoMapping(rpservice.Update, "", oidcCfg)
|
||||
}
|
||||
|
||||
m.reconcileMu.Lock()
|
||||
previous := m.reconcileCache[accountID]
|
||||
if previous == nil {
|
||||
previous = make(map[string]*proto.ProxyMapping)
|
||||
}
|
||||
|
||||
creates, updates, deletes := diffMappings(previous, current)
|
||||
if len(current) == 0 {
|
||||
delete(m.reconcileCache, accountID)
|
||||
} else {
|
||||
m.reconcileCache[accountID] = current
|
||||
}
|
||||
m.reconcileMu.Unlock()
|
||||
|
||||
for _, mapping := range creates {
|
||||
mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping))
|
||||
}
|
||||
for _, mapping := range updates {
|
||||
mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping))
|
||||
}
|
||||
for _, mapping := range deletes {
|
||||
mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping))
|
||||
}
|
||||
}
|
||||
|
||||
// diffMappings classifies the previous→current transition for a
|
||||
// single account into Create / Update / Delete sets.
|
||||
//
|
||||
// Cluster moves (current.cluster != previous.cluster) are surfaced as
|
||||
// a Delete on the old cluster + Create on the new — handled by
|
||||
// emitting both a delete (on previous mapping) and a create (on the
|
||||
// current mapping) for that service ID.
|
||||
func diffMappings(previous, current map[string]*proto.ProxyMapping) (creates, updates, deletes []*proto.ProxyMapping) {
|
||||
for id, cur := range current {
|
||||
prev, existed := previous[id]
|
||||
switch {
|
||||
case !existed:
|
||||
creates = append(creates, cur)
|
||||
case prev.GetDomain() == "" || cur.GetAccountId() == prev.GetAccountId() && currentClusterChanged(prev, cur):
|
||||
deletes = append(deletes, prev)
|
||||
creates = append(creates, cur)
|
||||
default:
|
||||
updates = append(updates, cur)
|
||||
}
|
||||
}
|
||||
for id, prev := range previous {
|
||||
if _, stillThere := current[id]; !stillThere {
|
||||
deletes = append(deletes, prev)
|
||||
}
|
||||
}
|
||||
return creates, updates, deletes
|
||||
}
|
||||
|
||||
func currentClusterChanged(prev, cur *proto.ProxyMapping) bool {
|
||||
return clusterFromMapping(prev) != clusterFromMapping(cur)
|
||||
}
|
||||
|
||||
// clusterFromMapping returns the cluster the mapping should be sent
|
||||
// to. ProxyMapping doesn't carry the cluster directly, so we rely on
|
||||
// the synthesised service's domain (`<slug>.<cluster>`) and split on
|
||||
// the first '.'.
|
||||
func clusterFromMapping(m *proto.ProxyMapping) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
domain := m.GetDomain()
|
||||
for i := 0; i < len(domain); i++ {
|
||||
if domain[i] == '.' {
|
||||
return domain[i+1:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
232
management/internals/modules/agentnetwork/reconcile_test.go
Normal file
232
management/internals/modules/agentnetwork/reconcile_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func newReconcileMgr(t *testing.T, ctrl *gomock.Controller) (*managerImpl, *store.MockStore, *proxy.MockController) {
|
||||
t.Helper()
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockProxy := proxy.NewMockController(ctrl)
|
||||
return &managerImpl{
|
||||
store: mockStore,
|
||||
proxyController: mockProxy,
|
||||
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
|
||||
}, mockStore, mockProxy
|
||||
}
|
||||
|
||||
func newReconcileTestProvider() *types.Provider {
|
||||
return &types.Provider{
|
||||
ID: "prov-1",
|
||||
AccountID: "acct-1",
|
||||
ProviderID: "openai_api",
|
||||
Name: "OpenAI",
|
||||
UpstreamURL: "https://api.openai.com",
|
||||
APIKey: "sk-test-key",
|
||||
Enabled: true,
|
||||
SessionPrivateKey: "test-priv-key",
|
||||
SessionPublicKey: "test-pub-key",
|
||||
}
|
||||
}
|
||||
|
||||
func newReconcileTestPolicy(providerID, sourceGroupID string) *types.Policy {
|
||||
return &types.Policy{
|
||||
ID: "pol-1",
|
||||
AccountID: "acct-1",
|
||||
Name: "engineers",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{sourceGroupID},
|
||||
DestinationProviderIDs: []string{providerID},
|
||||
}
|
||||
}
|
||||
|
||||
func newReconcileTestSettings() *types.Settings {
|
||||
return &types.Settings{
|
||||
AccountID: "acct-1",
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
Subdomain: "violet",
|
||||
}
|
||||
}
|
||||
|
||||
func expectReconcileSynthInputs(mockStore *store.MockStore, ctx context.Context, providers []*types.Provider, policies []*types.Policy, guardrails []*types.Guardrail) {
|
||||
mockStore.EXPECT().
|
||||
GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(newReconcileTestSettings(), nil)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(providers, nil)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(policies, nil)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(guardrails, nil)
|
||||
}
|
||||
|
||||
func TestReconcile_FirstSynth_EmitsCreate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl)
|
||||
provider := newReconcileTestProvider()
|
||||
policy := newReconcileTestPolicy(provider.ID, "grp-eng")
|
||||
|
||||
expectReconcileSynthInputs(mockStore, ctx, []*types.Provider{provider}, []*types.Policy{policy}, []*types.Guardrail{})
|
||||
mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{})
|
||||
|
||||
var sentMappings []*proto.ProxyMapping
|
||||
mockProxy.EXPECT().
|
||||
SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), "eu.proxy.netbird.io").
|
||||
Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) {
|
||||
sentMappings = append(sentMappings, m)
|
||||
})
|
||||
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
|
||||
require.Len(t, sentMappings, 1, "first synth must emit one mapping")
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, sentMappings[0].Type, "first synth is a Create")
|
||||
assert.Equal(t, "agent-net-svc-acct-1", sentMappings[0].Id, "stable account-scoped virtual service id")
|
||||
assert.Equal(t, "violet.eu.proxy.netbird.io", sentMappings[0].Domain, "domain comes from settings (subdomain.cluster)")
|
||||
|
||||
mgr.reconcileMu.Lock()
|
||||
cached := mgr.reconcileCache["acct-1"]
|
||||
mgr.reconcileMu.Unlock()
|
||||
require.Len(t, cached, 1, "cache must hold the synth result for next diff")
|
||||
}
|
||||
|
||||
func TestReconcile_NoChange_EmitsNothingExtra(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl)
|
||||
provider := newReconcileTestProvider()
|
||||
policy := newReconcileTestPolicy(provider.ID, "grp-eng")
|
||||
|
||||
// Two identical synth runs.
|
||||
mockStore.EXPECT().
|
||||
GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(newReconcileTestSettings(), nil).Times(2)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return([]*types.Provider{provider}, nil).Times(2)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return([]*types.Policy{policy}, nil).Times(2)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return([]*types.Guardrail{}, nil).Times(2)
|
||||
mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).Times(2)
|
||||
|
||||
createCalls := 0
|
||||
updateCalls := 0
|
||||
mockProxy.EXPECT().
|
||||
SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), gomock.Any()).
|
||||
Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) {
|
||||
switch m.Type {
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
|
||||
createCalls++
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
|
||||
updateCalls++
|
||||
}
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
|
||||
assert.Equal(t, 1, createCalls, "first reconcile creates")
|
||||
assert.Equal(t, 1, updateCalls, "second reconcile re-pushes as Modified (no semantic change but mapping fields refresh)")
|
||||
}
|
||||
|
||||
func TestReconcile_PolicyRemoved_EmitsDelete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl)
|
||||
provider := newReconcileTestProvider()
|
||||
policy := newReconcileTestPolicy(provider.ID, "grp-eng")
|
||||
|
||||
gomock.InOrder(
|
||||
// First reconcile: provider + policy, synthesised.
|
||||
mockStore.EXPECT().GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").Return(newReconcileTestSettings(), nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Provider{provider}, nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Policy{policy}, nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Guardrail{}, nil),
|
||||
// Second reconcile: policy gone, provider stays but no longer referenced.
|
||||
mockStore.EXPECT().GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").Return(newReconcileTestSettings(), nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Provider{provider}, nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Policy{}, nil),
|
||||
)
|
||||
mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()
|
||||
|
||||
var seenTypes []proto.ProxyMappingUpdateType
|
||||
mockProxy.EXPECT().
|
||||
SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), "eu.proxy.netbird.io").
|
||||
Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) {
|
||||
seenTypes = append(seenTypes, m.Type)
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
|
||||
require.Len(t, seenTypes, 2, "create then delete")
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, seenTypes[0])
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, seenTypes[1])
|
||||
|
||||
mgr.reconcileMu.Lock()
|
||||
_, present := mgr.reconcileCache["acct-1"]
|
||||
mgr.reconcileMu.Unlock()
|
||||
assert.False(t, present, "cache for the account must be cleared once nothing is synthesised")
|
||||
}
|
||||
|
||||
func TestReconcile_NilProxyController_NoOp(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mgr := &managerImpl{
|
||||
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
|
||||
}
|
||||
// Must not panic; must not query the store.
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
}
|
||||
|
||||
func TestReconcile_EmptyAccountID_NoOp(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mgr, _, _ := newReconcileMgr(t, ctrl)
|
||||
// Empty accountID short-circuits before any store call.
|
||||
mgr.reconcile(ctx, "")
|
||||
}
|
||||
|
||||
func TestClusterFromMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
want string
|
||||
}{
|
||||
{"simple", "openai.eu.proxy.netbird.io", "eu.proxy.netbird.io"},
|
||||
{"deeply nested", "a.b.c.d", "b.c.d"},
|
||||
{"no dot", "openai", ""},
|
||||
{"empty", "", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := clusterFromMapping(&proto.ProxyMapping{Domain: tt.domain})
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
1059
management/internals/modules/agentnetwork/synthesizer.go
Normal file
1059
management/internals/modules/agentnetwork/synthesizer.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,178 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// decodeServiceGuardrailConfig pulls the llm_guardrail middleware config off the
|
||||
// synthesised service's single target.
|
||||
func decodeServiceGuardrailConfig(t *testing.T, svc *rpservice.Service) guardrailConfig {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, svc.Targets, "synth service must carry a target")
|
||||
for _, mw := range svc.Targets[0].Options.Middlewares {
|
||||
if mw.ID == middlewareIDLLMGuardrail {
|
||||
var cfg guardrailConfig
|
||||
require.NoError(t, json.Unmarshal(mw.ConfigJSON, &cfg), "guardrail config must decode")
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
t.Fatal("llm_guardrail middleware not present on synthesised service")
|
||||
return guardrailConfig{}
|
||||
}
|
||||
|
||||
// decodeMiddlewareRawConfig returns the raw ConfigJSON bytes for the named
|
||||
// middleware on the synth service's target, or fails the test.
|
||||
func decodeMiddlewareRawConfig(t *testing.T, svc *rpservice.Service, id string) []byte {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, svc.Targets, "synth service must carry a target")
|
||||
for _, mw := range svc.Targets[0].Options.Middlewares {
|
||||
if mw.ID == id {
|
||||
return mw.ConfigJSON
|
||||
}
|
||||
}
|
||||
t.Fatalf("middleware %q not present on synthesised service", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveGuardrailAndPolicy persists a guardrail with prompt capture + redact + a
|
||||
// model allowlist, referenced by one enabled policy. Shared by the GC-3 tests.
|
||||
func saveGuardrailAndPolicy(t *testing.T, ctx context.Context, s store.Store, provider *types.Provider) {
|
||||
t.Helper()
|
||||
guardrail := &types.Guardrail{
|
||||
ID: "ainguard-1",
|
||||
AccountID: testAccountID,
|
||||
Name: "strict",
|
||||
Checks: types.GuardrailChecks{
|
||||
ModelAllowlist: types.GuardrailModelAllowlist{Enabled: true, Models: []string{"gpt-5.4"}},
|
||||
PromptCapture: types.GuardrailPromptCapture{Enabled: true, RedactPii: true},
|
||||
},
|
||||
}
|
||||
require.NoError(t, s.SaveAgentNetworkGuardrail(ctx, guardrail))
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", guardrail.ID)))
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_PromptCaptureAccountIsSoleControl is the
|
||||
// GC-3 contract: the account master switch (EnablePromptCollection) is the
|
||||
// SOLE control for capture enablement. Policy-level guardrail prompt_capture is
|
||||
// ignored for enablement — operators don't need to attach a capture guardrail
|
||||
// to a policy just to turn capture on for the account. Off by default.
|
||||
func TestSynthesizeServices_RealStore_PromptCaptureAccountIsSoleControl(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
// Account collection master switch OFF (default).
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
saveGuardrailAndPolicy(t, ctx, s, newSynthTestProvider())
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
cfg := decodeServiceGuardrailConfig(t, services[0])
|
||||
assert.Equal(t, []string{"gpt-5.4"}, cfg.ModelAllowlist,
|
||||
"model allowlist is a pure policy guardrail and must always reach the config")
|
||||
assert.False(t, cfg.PromptCapture.Enabled,
|
||||
"prompt capture must be off when the account toggle is off, even with a capture-enabled guardrail")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_PromptCaptureFlowsWhenAccountOptsIn proves
|
||||
// the account toggle is sufficient on its own — even with NO guardrail
|
||||
// attached to the policy, capture fires when the account opts in. Redact is
|
||||
// the OR of account + guardrail.
|
||||
func TestSynthesizeServices_RealStore_PromptCaptureFlowsWhenAccountOptsIn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.EnablePromptCollection = true
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
|
||||
// Save a provider and a policy with NO guardrails attached — proves the
|
||||
// account toggle is sufficient on its own.
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
cfg := decodeServiceGuardrailConfig(t, services[0])
|
||||
assert.True(t, cfg.PromptCapture.Enabled,
|
||||
"account toggle alone must enable capture; no guardrail attachment required")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_AccountRedactWithoutGuardrailRedact proves
|
||||
// the redact OR-merge from the account side: account RedactPii on, guardrail
|
||||
// redact off, capture on at both levels.
|
||||
func TestSynthesizeServices_RealStore_AccountRedactWithoutGuardrailRedact(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.EnablePromptCollection = true
|
||||
settings.RedactPii = true
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
guardrail := &types.Guardrail{
|
||||
ID: "ainguard-noredact",
|
||||
AccountID: testAccountID,
|
||||
Name: "capture-only",
|
||||
Checks: types.GuardrailChecks{
|
||||
PromptCapture: types.GuardrailPromptCapture{Enabled: true, RedactPii: false},
|
||||
},
|
||||
}
|
||||
require.NoError(t, s.SaveAgentNetworkGuardrail(ctx, guardrail))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", guardrail.ID)))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
cfg := decodeServiceGuardrailConfig(t, services[0])
|
||||
assert.True(t, cfg.PromptCapture.Enabled, "capture on (account + guardrail)")
|
||||
assert.True(t, cfg.PromptCapture.RedactPii, "account RedactPii must apply even when the guardrail leaves it off (OR)")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_NoGuardrail_CaptureOff pins the default:
|
||||
// with no guardrail referenced, the synth service's guardrail config has prompt
|
||||
// capture disabled and an empty allowlist. This is the "off by default" baseline
|
||||
// the account switch must preserve.
|
||||
func TestSynthesizeServices_RealStore_NoGuardrail_CaptureOff(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1, "exactly one synth service expected")
|
||||
|
||||
cfg := decodeServiceGuardrailConfig(t, services[0])
|
||||
assert.Empty(t, cfg.ModelAllowlist, "no guardrail → no allowlist")
|
||||
assert.False(t, cfg.PromptCapture.Enabled, "no guardrail → prompt capture off by default")
|
||||
assert.False(t, cfg.PromptCapture.RedactPii, "no guardrail → redact off by default")
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// TestSynthesizeServices_RealStore_LogCollectionOff_SuppressesAccessLog drives the
|
||||
// happy default: account settings ship with EnableLogCollection=false, so the
|
||||
// synthesised target opts out of access-log emission (DisableAccessLog=true) and
|
||||
// the proto mapping the proxy receives reflects that.
|
||||
func TestSynthesizeServices_RealStore_LogCollectionOff_SuppressesAccessLog(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1, "exactly one synth service expected")
|
||||
require.NotEmpty(t, services[0].Targets, "synth service must carry a target")
|
||||
assert.True(t, services[0].Targets[0].Options.DisableAccessLog,
|
||||
"EnableLogCollection=false (default) must produce DisableAccessLog=true on the synth target")
|
||||
|
||||
mapping := services[0].ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{})
|
||||
require.NotEmpty(t, mapping.GetPath(), "proto mapping must carry a path")
|
||||
assert.True(t, mapping.GetPath()[0].GetOptions().GetDisableAccessLog(),
|
||||
"proto mapping must propagate DisableAccessLog=true so the proxy suppresses access-log emission")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_LogCollectionOn_PermitsAccessLog asserts the
|
||||
// inverse: once the account opts in, the synth target leaves DisableAccessLog
|
||||
// at its default false and the proto wire stays unset.
|
||||
func TestSynthesizeServices_RealStore_LogCollectionOn_PermitsAccessLog(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.EnableLogCollection = true
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1, "exactly one synth service expected")
|
||||
require.NotEmpty(t, services[0].Targets, "synth service must carry a target")
|
||||
assert.False(t, services[0].Targets[0].Options.DisableAccessLog,
|
||||
"EnableLogCollection=true must leave DisableAccessLog=false on the synth target")
|
||||
|
||||
mapping := services[0].ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{})
|
||||
require.NotEmpty(t, mapping.GetPath(), "proto mapping must carry a path")
|
||||
assert.False(t, mapping.GetPath()[0].GetOptions().GetDisableAccessLog(),
|
||||
"proto mapping must propagate DisableAccessLog=false so access-log emission stays on")
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// parserRedactConfig mirrors the on-wire shape of the redact + capture knobs
|
||||
// that both llm_request_parser and llm_response_parser unmarshal. We don't
|
||||
// import the proxy-side packages from a management test (cross-module), so we
|
||||
// decode the JSON directly and assert on the fields that are part of the
|
||||
// synth contract.
|
||||
type parserRedactConfig struct {
|
||||
RedactPii bool `json:"redact_pii,omitempty"`
|
||||
CapturePrompt *bool `json:"capture_prompt,omitempty"` // present only on the request parser
|
||||
CaptureCompletion *bool `json:"capture_completion,omitempty"` // present only on the response parser
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_ParserConfigsCarryRedactPii is the
|
||||
// management-side contract test for the request/response parser redaction
|
||||
// wiring. When settings.RedactPii is true, the synthesised middleware chain
|
||||
// MUST stamp redact_pii=true on both llm_request_parser and llm_response_parser
|
||||
// configs — otherwise the parsers ship raw prompts / completions to the
|
||||
// access log even though the account has opted in. This is exactly the live
|
||||
// leak path that motivated the parser-side redaction in the first place.
|
||||
func TestSynthesizeServices_RealStore_ParserConfigsCarryRedactPii(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.RedactPii = true
|
||||
settings.EnablePromptCollection = true
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1, "exactly one synth service expected")
|
||||
|
||||
for _, parserID := range []string{middlewareIDLLMRequestParser, middlewareIDLLMResponseParser} {
|
||||
raw := decodeMiddlewareRawConfig(t, services[0], parserID)
|
||||
var cfg parserRedactConfig
|
||||
require.NoError(t, json.Unmarshal(raw, &cfg), "%s config must be valid JSON", parserID)
|
||||
assert.True(t, cfg.RedactPii, "%s config must carry redact_pii=true when settings.RedactPii is on (otherwise the parser ships raw prompts/completions to the access log)", parserID)
|
||||
}
|
||||
// The capture flag is set explicitly to enable_prompt_collection on each
|
||||
// parser. With it on here, both must allow emission.
|
||||
reqCfg := decodeParserConfig(t, services[0], middlewareIDLLMRequestParser)
|
||||
require.NotNil(t, reqCfg.CapturePrompt, "request parser must carry an explicit capture_prompt")
|
||||
assert.True(t, *reqCfg.CapturePrompt, "capture_prompt=true when EnablePromptCollection=true")
|
||||
respCfg := decodeParserConfig(t, services[0], middlewareIDLLMResponseParser)
|
||||
require.NotNil(t, respCfg.CaptureCompletion, "response parser must carry an explicit capture_completion")
|
||||
assert.True(t, *respCfg.CaptureCompletion, "capture_completion=true when EnablePromptCollection=true")
|
||||
}
|
||||
|
||||
// decodeParserConfig is a small helper around decodeMiddlewareRawConfig that
|
||||
// also unmarshals into parserRedactConfig.
|
||||
func decodeParserConfig(t *testing.T, svc *rpservice.Service, parserID string) parserRedactConfig {
|
||||
t.Helper()
|
||||
raw := decodeMiddlewareRawConfig(t, svc, parserID)
|
||||
var cfg parserRedactConfig
|
||||
require.NoError(t, json.Unmarshal(raw, &cfg), "%s config must be valid JSON", parserID)
|
||||
return cfg
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_ParserConfigsSuppressCaptureWhenLogCollectionOnly
|
||||
// is the contract test for the bug: enable_log_collection=true with
|
||||
// enable_prompt_collection=false MUST result in capture_prompt=false on the
|
||||
// request parser AND capture_completion=false on the response parser, so the
|
||||
// access-log row stays metadata-only (provider, model, tokens, cost) and
|
||||
// carries NO prompt input nor response output. Without this, operators who
|
||||
// want billing-style logs end up with raw user prompts and model outputs in
|
||||
// every access-log entry.
|
||||
func TestSynthesizeServices_RealStore_ParserConfigsSuppressCaptureWhenLogCollectionOnly(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.EnableLogCollection = true // operator wants logs ON
|
||||
settings.EnablePromptCollection = false // but NOT content capture
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
reqCfg := decodeParserConfig(t, services[0], middlewareIDLLMRequestParser)
|
||||
require.NotNil(t, reqCfg.CapturePrompt, "request parser must carry an explicit capture_prompt gate")
|
||||
assert.False(t, *reqCfg.CapturePrompt, "capture_prompt MUST be false when EnablePromptCollection is off — otherwise llm.request_prompt_raw leaks user input into the access log")
|
||||
|
||||
respCfg := decodeParserConfig(t, services[0], middlewareIDLLMResponseParser)
|
||||
require.NotNil(t, respCfg.CaptureCompletion, "response parser must carry an explicit capture_completion gate")
|
||||
assert.False(t, *respCfg.CaptureCompletion, "capture_completion MUST be false when EnablePromptCollection is off — otherwise llm.response_completion leaks model output into the access log")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_ParserConfigsOmitRedactPiiWhenOff proves
|
||||
// the inverse: with the account toggle off, the parser configs stay clean (no
|
||||
// redact_pii field, which the parsers treat as zero / no redaction). This is
|
||||
// the operator-opt-out path — the access log keeps raw prompts/completions
|
||||
// for debugging until the operator opts in.
|
||||
func TestSynthesizeServices_RealStore_ParserConfigsOmitRedactPiiWhenOff(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
|
||||
// Default settings: RedactPii = false.
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
for _, parserID := range []string{middlewareIDLLMRequestParser, middlewareIDLLMResponseParser} {
|
||||
raw := decodeMiddlewareRawConfig(t, services[0], parserID)
|
||||
// Inspect the decoded JSON directly: a struct decode would also pass
|
||||
// if redact_pii were present-but-false. The contract is that the key
|
||||
// is omitted entirely while the account toggle is off.
|
||||
var rawCfg map[string]json.RawMessage
|
||||
require.NoError(t, json.Unmarshal(raw, &rawCfg), "%s config must be valid JSON", parserID)
|
||||
assert.NotContains(t, rawCfg, "redact_pii",
|
||||
"%s config must omit redact_pii entirely while the account toggle is off", parserID)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// decodeServiceRouterConfig finds the llm_router middleware on the synthesised
|
||||
// service's single target and decodes its config — the model→provider routing
|
||||
// table the proxy authorises against.
|
||||
func decodeServiceRouterConfig(t *testing.T, svc *rpservice.Service) routerConfig {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, svc.Targets, "synth service must carry a target")
|
||||
for _, mw := range svc.Targets[0].Options.Middlewares {
|
||||
if mw.ID == middlewareIDLLMRouter {
|
||||
var cfg routerConfig
|
||||
require.NoError(t, json.Unmarshal(mw.ConfigJSON, &cfg), "router config must decode")
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
t.Fatal("llm_router middleware not present on synthesised service")
|
||||
return routerConfig{}
|
||||
}
|
||||
|
||||
// decodeMappingRouterConfig is the proto-wire equivalent: it pulls the
|
||||
// llm_router config off the ProxyMapping the proxy actually receives.
|
||||
func decodeMappingRouterConfig(t *testing.T, m *proto.ProxyMapping) routerConfig {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, m.GetPath(), "mapping must carry a path")
|
||||
for _, mw := range m.GetPath()[0].GetOptions().GetMiddlewares() {
|
||||
if mw.GetId() == middlewareIDLLMRouter {
|
||||
var cfg routerConfig
|
||||
require.NoError(t, json.Unmarshal(mw.GetConfigJson(), &cfg), "wire router config must decode")
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
t.Fatal("llm_router middleware not present on proxy mapping")
|
||||
return routerConfig{}
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_SurvivesStatusToggle drives synthesis through
|
||||
// a REAL sqlite store (Save → gorm/JSON serialize → reload → decrypt) instead of
|
||||
// a MockStore, so it exercises the field round-trip that a provider/policy edit
|
||||
// actually hits. Mock-based tests can't catch a field that dies in persistence;
|
||||
// this one can. It then performs the exact operation that reproduced the live
|
||||
// 403 — disable then re-enable the provider — and asserts the re-enabled state
|
||||
// is fully routable again.
|
||||
func TestSynthesizeServices_RealStore_SurvivesStatusToggle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
assertRoutable := func(t *testing.T, stage string) {
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err, stage)
|
||||
require.Len(t, services, 1, "%s: exactly one synth service expected", stage)
|
||||
svc := services[0]
|
||||
|
||||
assert.True(t, svc.Private, "%s: synth service must be Private after store round-trip", stage)
|
||||
assert.Equal(t, []string{"grp-eng"}, svc.AccessGroups, "%s: AccessGroups must survive the round-trip", stage)
|
||||
|
||||
m := svc.ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{})
|
||||
assert.True(t, m.GetPrivate(), "%s: proto mapping Private must be true (proxy gates tunnel-peer auth on it)", stage)
|
||||
|
||||
cfg := decodeServiceRouterConfig(t, svc)
|
||||
require.Len(t, cfg.Providers, 1, "%s: the enabled+linked provider must appear in the router config", stage)
|
||||
assert.Equal(t, []string{"gpt-5.4"}, cfg.Providers[0].Models, "%s: provider models must reach the route", stage)
|
||||
assert.Equal(t, []string{"grp-eng"}, cfg.Providers[0].AllowedGroupIDs, "%s: policy source groups must reach the route", stage)
|
||||
}
|
||||
|
||||
assertRoutable(t, "initial")
|
||||
|
||||
provider.Enabled = false
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
disabled, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err, "synthesis must not error with a disabled provider")
|
||||
for _, svc := range disabled {
|
||||
assert.Empty(t, decodeServiceRouterConfig(t, svc).Providers,
|
||||
"a disabled provider must not appear in the router config (otherwise it would route while off)")
|
||||
}
|
||||
|
||||
provider.Enabled = true
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
assertRoutable(t, "after disable->enable")
|
||||
}
|
||||
|
||||
// captureController is a proxy.Controller that records the mappings reconcile
|
||||
// pushes, so the test can inspect the exact wire payload — Private flag and
|
||||
// router config included.
|
||||
type captureController struct {
|
||||
rpproxy.Controller
|
||||
pushed []*proto.ProxyMapping
|
||||
}
|
||||
|
||||
func (c *captureController) GetOIDCValidationConfig() rpproxy.OIDCValidationConfig {
|
||||
return rpproxy.OIDCValidationConfig{}
|
||||
}
|
||||
|
||||
func (c *captureController) SendServiceUpdateToCluster(_ context.Context, _ string, update *proto.ProxyMapping, _ string) {
|
||||
c.pushed = append(c.pushed, update)
|
||||
}
|
||||
|
||||
// noopAccountManager satisfies the reconcile path's accountManager dependency.
|
||||
type noopAccountManager struct {
|
||||
account.Manager
|
||||
}
|
||||
|
||||
func (noopAccountManager) UpdateAccountPeers(context.Context, string, nbtypes.UpdateReason) {}
|
||||
|
||||
// TestReconcile_RealStore_PushesPrivateAfterStatusToggle reproduces the live
|
||||
// path end-to-end below the gRPC boundary: a real store + the real
|
||||
// managerImpl.reconcile + a capturing proxy controller. It runs the operation
|
||||
// that broke in production — provider disable then re-enable — and asserts the
|
||||
// mapping reconcile pushes to the cluster after re-enable is Private=true and
|
||||
// carries the routable provider. If reconcile ever pushes private=false (the
|
||||
// symptom that left UserGroups empty → no_authorised_provider), this fails.
|
||||
func TestReconcile_RealStore_PushesPrivateAfterStatusToggle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
ctrl := &captureController{}
|
||||
m := &managerImpl{
|
||||
store: s,
|
||||
accountManager: noopAccountManager{},
|
||||
proxyController: ctrl,
|
||||
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
m.reconcile(ctx, testAccountID) // initial, provider enabled
|
||||
|
||||
provider.Enabled = false
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
m.reconcile(ctx, testAccountID) // disabled
|
||||
|
||||
provider.Enabled = true
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
m.reconcile(ctx, testAccountID) // re-enabled — the reproduction step
|
||||
|
||||
require.NotEmpty(t, ctrl.pushed, "reconcile must push at least one mapping")
|
||||
last := ctrl.pushed[len(ctrl.pushed)-1]
|
||||
|
||||
assert.Equal(t, newSynthTestSettings().Endpoint(), last.GetDomain(), "synth domain on the wire")
|
||||
assert.True(t, last.GetPrivate(),
|
||||
"reconcile-pushed mapping after re-enable MUST be Private=true; a false here is the exact bug — the proxy skips ValidateTunnelPeer, UserGroups stays empty, and llm_router denies no_authorised_provider")
|
||||
|
||||
cfg := decodeMappingRouterConfig(t, last)
|
||||
require.Len(t, cfg.Providers, 1, "re-enabled provider must be back in the pushed router config")
|
||||
assert.Equal(t, []string{"gpt-5.4"}, cfg.Providers[0].Models, "model must be routable again after re-enable")
|
||||
assert.Equal(t, []string{"grp-eng"}, cfg.Providers[0].AllowedGroupIDs, "authorised groups must be present after re-enable")
|
||||
}
|
||||
1098
management/internals/modules/agentnetwork/synthesizer_test.go
Normal file
1098
management/internals/modules/agentnetwork/synthesizer_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user