mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-27 10:19:55 +00:00
Compare commits
35 Commits
v0.73.1
...
agent-netw
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
668af0dc4f | ||
|
|
5f130959ea | ||
|
|
5644279888 | ||
|
|
f22ac6d271 | ||
|
|
9f485be2f9 | ||
|
|
c83e46fbe1 | ||
|
|
405607c584 | ||
|
|
29f55d4255 | ||
|
|
3993fa32e4 | ||
|
|
6ade3839aa | ||
|
|
d4d158a8f3 | ||
|
|
6613d194ef | ||
|
|
769e12840d | ||
|
|
350a96c640 | ||
|
|
615631567a | ||
|
|
f4daf59bcd | ||
|
|
ff2787e184 | ||
|
|
e20b62ad65 | ||
|
|
18b38943aa | ||
|
|
a400828b89 | ||
|
|
e2bb328a34 | ||
|
|
221b9c012c | ||
|
|
17b2044596 | ||
|
|
07101c59ac | ||
|
|
51b6f6291b | ||
|
|
2ebf26006a | ||
|
|
211a26019a | ||
|
|
6c26178ad5 | ||
|
|
af3b7e4497 | ||
|
|
e84f6527f7 | ||
|
|
ac9529ea8c | ||
|
|
f736ef9647 | ||
|
|
cf58bf1ba9 | ||
|
|
522b8ed969 | ||
|
|
c9e99659ea |
@@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -59,12 +59,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: true
|
||||
|
||||
2
.github/workflows/git-town.yml
vendored
2
.github/workflows/git-town.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
|
||||
|
||||
6
.github/workflows/golang-test-darwin.yml
vendored
6
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,12 +16,12 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
|
||||
4
.github/workflows/golang-test-freebsd.yml
vendored
4
.github/workflows/golang-test-freebsd.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
id: test
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
|
||||
58
.github/workflows/golang-test-linux.yml
vendored
58
.github/workflows/golang-test-linux.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -119,12 +119,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -162,7 +162,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -175,12 +175,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -246,12 +246,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -290,7 +290,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -306,12 +306,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -347,7 +347,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -363,12 +363,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -407,7 +407,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -424,12 +424,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -484,7 +484,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -529,12 +529,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -579,10 +579,11 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags devcert -run=^$ -bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
|
||||
env:
|
||||
GIT_BRANCH: ${{ github.ref_name }}
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
@@ -623,12 +624,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -673,12 +674,13 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags=benchmark \
|
||||
-run=^$ \
|
||||
-bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/server/http/...
|
||||
env:
|
||||
GIT_BRANCH: ${{ github.ref_name }}
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
@@ -692,12 +694,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -734,7 +736,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
|
||||
4
.github/workflows/golang-test-windows.yml
vendored
4
.github/workflows/golang-test-windows.yml
vendored
@@ -18,12 +18,12 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
id: go
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
8
.github/workflows/golangci-lint.yml
vendored
8
.github/workflows/golangci-lint.yml
vendored
@@ -15,13 +15,13 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals,flate,recordin,unparseable
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Check for duplicate constants
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
2
.github/workflows/install-script-test.yml
vendored
2
.github/workflows/install-script-test.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
10
.github/workflows/mobile-build-validation.yml
vendored
10
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,11 +16,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
||||
uses: actions/setup-java@ad2b38190b15e4d6bdf0c97fb4fca8412226d287
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
@@ -54,11 +54,11 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
|
||||
32
.github/workflows/release.yml
vendored
32
.github/workflows/release.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
@@ -166,7 +166,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -186,9 +186,9 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
||||
uses: docker/setup-qemu-action@06116385d9baf250c9f4dcb4858b16962ea869c3 #v4.1.0
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
||||
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 #v4.1.0
|
||||
- name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
@@ -221,7 +221,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
@@ -347,7 +347,7 @@ jobs:
|
||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
@@ -374,7 +374,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -420,7 +420,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
@@ -464,12 +464,12 @@ jobs:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -488,7 +488,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
@@ -522,7 +522,7 @@ jobs:
|
||||
downloadPath: '${{ github.workspace }}\temp'
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -534,13 +534,13 @@ jobs:
|
||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
- name: Download release artifacts
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
with:
|
||||
name: release
|
||||
path: release
|
||||
|
||||
- name: Download UI release artifacts
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: release-ui
|
||||
|
||||
12
.github/workflows/test-infrastructure-files.yml
vendored
12
.github/workflows/test-infrastructure-files.yml
vendored
@@ -68,12 +68,12 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
@@ -207,7 +207,7 @@ jobs:
|
||||
- name: Build management docker image
|
||||
working-directory: management
|
||||
run: |
|
||||
docker build -t netbirdio/management:latest .
|
||||
docker build -t netbirdio/management:latest --build-arg TARGETPLATFORM=. .
|
||||
|
||||
- name: Build signal binary
|
||||
working-directory: signal
|
||||
@@ -216,7 +216,7 @@ jobs:
|
||||
- name: Build signal docker image
|
||||
working-directory: signal
|
||||
run: |
|
||||
docker build -t netbirdio/signal:latest .
|
||||
docker build -t netbirdio/signal:latest --build-arg TARGETPLATFORM=. .
|
||||
|
||||
- name: Build relay binary
|
||||
working-directory: relay
|
||||
@@ -225,7 +225,7 @@ jobs:
|
||||
- name: Build relay docker image
|
||||
working-directory: relay
|
||||
run: |
|
||||
docker build -t netbirdio/relay:latest .
|
||||
docker build -t netbirdio/relay:latest --build-arg TARGETPLATFORM=. .
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files/artifacts
|
||||
@@ -256,7 +256,7 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
8
.github/workflows/wasm-build-validation.yml
vendored
8
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,11 +19,11 @@ jobs:
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
@@ -44,11 +44,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
|
||||
@@ -462,9 +462,13 @@ checksum:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||
|
||||
release:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||
|
||||
@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
ProfileName: string(activeProf.ID),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -279,9 +279,11 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
// Cancel the client context before stopping: Engine.Start blocks on the
|
||||
// signal stream while holding the engine mutex and only unblocks on
|
||||
// cancellation. Stopping first would deadlock on that mutex.
|
||||
// ConnectClient.Stop now cancels its own run context and waits for the
|
||||
// run loop to tear the engine down, so this cancel() is no longer
|
||||
// required to break the deadlock and could be removed. It is kept as a
|
||||
// defensive belt-and-suspenders: cancelling the parent context first
|
||||
// guarantees the run loop is unblocked even if Stop's contract regresses.
|
||||
cancel()
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
@@ -54,6 +55,10 @@ var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath
|
||||
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
runCancel context.CancelFunc
|
||||
runExited chan struct{}
|
||||
runOnce sync.Once
|
||||
runStarted atomic.Bool
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -70,8 +75,14 @@ func NewConnectClient(
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
) *ConnectClient {
|
||||
// Derive the run context here so Stop owns the cancel that unblocks the run
|
||||
// loop. runCancel is set once at construction, so Stop can call it without
|
||||
// racing the run loop's startup. Callers therefore need not cancel before Stop.
|
||||
runCtx, runCancel := context.WithCancel(ctx)
|
||||
return &ConnectClient{
|
||||
ctx: ctx,
|
||||
ctx: runCtx,
|
||||
runCancel: runCancel,
|
||||
runExited: make(chan struct{}),
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
engineMutex: sync.Mutex{},
|
||||
@@ -135,6 +146,11 @@ func (c *ConnectClient) RunOniOS(
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||
// Mark the loop as started and signal exit on return so Stop can wait for
|
||||
// the loop to finish (and skip the wait if the loop never ran).
|
||||
c.runStarted.Store(true)
|
||||
defer c.runOnce.Do(func() { close(c.runExited) })
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rec := c.statusRecorder
|
||||
@@ -290,7 +306,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Debug(err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
state.Set(StatusNeedsLogin)
|
||||
_ = c.Stop()
|
||||
c.runCancel()
|
||||
return backoff.Permanent(wrapErr(err)) // unrecoverable error
|
||||
}
|
||||
return wrapErr(err)
|
||||
@@ -410,14 +426,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
c.engine = nil
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
// todo: consider to remove this condition. Is not thread safe.
|
||||
// We should always call Stop(), but we need to verify that it is idempotent
|
||||
if engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||
log.Infof("ensuring wg interface is removed, Netbird engine context cancelled")
|
||||
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
c.statusRecorder.ClientTeardown()
|
||||
|
||||
@@ -433,12 +445,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
|
||||
c.statusRecorder.ClientStart()
|
||||
err = backoff.Retry(operation, backOff)
|
||||
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
|
||||
if err != nil {
|
||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
state.Set(StatusNeedsLogin)
|
||||
_ = c.Stop()
|
||||
c.runCancel()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -516,11 +528,9 @@ func (c *ConnectClient) Status() StatusType {
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Stop() error {
|
||||
engine := c.Engine()
|
||||
if engine != nil {
|
||||
if err := engine.Stop(); err != nil {
|
||||
return fmt.Errorf("stop engine: %w", err)
|
||||
}
|
||||
c.runCancel()
|
||||
if c.runStarted.Load() {
|
||||
<-c.runExited
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -51,13 +51,20 @@ type cachedRecord struct {
|
||||
}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
|
||||
// guarded by mutex.
|
||||
type Resolver struct {
|
||||
records map[dns.Question]*cachedRecord
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
|
||||
// failedResolves records the last failed initial resolve per domain so a
|
||||
// domain that never resolves isn't retried on every server-domains update
|
||||
// until refreshBackoff elapses. Entries are cleared on success and pruned
|
||||
// to the current server-domains set.
|
||||
failedResolves map[domain.Domain]time.Time
|
||||
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
@@ -76,9 +83,10 @@ type Resolver struct {
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
failedResolves: make(map[domain.Domain]time.Time),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,7 +181,9 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype.
|
||||
// entry for that qtype. When one family hard-errors while the other succeeds,
|
||||
// the resolved family is still cached but AddDomain returns an error so the
|
||||
// caller retries the incomplete resolve rather than treating it as complete.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
@@ -203,6 +213,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
if errA != nil || errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -462,6 +476,7 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
delete(m.records, qAAAA)
|
||||
delete(m.refreshing, qA)
|
||||
delete(m.refreshing, qAAAA)
|
||||
delete(m.failedResolves, d)
|
||||
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
@@ -505,6 +520,7 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
|
||||
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
|
||||
currentDomains := m.GetCachedDomains()
|
||||
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
||||
m.pruneFailedResolves(allDomains)
|
||||
}
|
||||
|
||||
m.addNewDomains(ctx, newDomains)
|
||||
@@ -577,13 +593,85 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
||||
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
||||
}
|
||||
|
||||
// addNewDomains resolves and caches all domains from the update
|
||||
// addNewDomains resolves and caches domains that are not yet in the cache,
|
||||
// running the lookups concurrently. Domains already cached are skipped and left
|
||||
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
|
||||
// synchronously: once NetBird owns the OS resolver the resolve runs through the
|
||||
// handler chain and would otherwise dial the managed upstreams under the engine
|
||||
// sync lock on every update.
|
||||
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
||||
var wg sync.WaitGroup
|
||||
seen := make(map[domain.Domain]struct{}, len(newDomains))
|
||||
for _, newDomain := range newDomains {
|
||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||
} else {
|
||||
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
||||
if _, dup := seen[newDomain]; dup {
|
||||
continue
|
||||
}
|
||||
seen[newDomain] = struct{}{}
|
||||
|
||||
if !m.needsResolve(newDomain) {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(d domain.Domain) {
|
||||
defer wg.Done()
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
m.markResolveFailed(d)
|
||||
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
|
||||
return
|
||||
}
|
||||
m.clearResolveFailed(d)
|
||||
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||
}(newDomain)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// needsResolve reports whether d should be resolved now. A recent failed or
|
||||
// incomplete resolve gates retries on the backoff even when one family is
|
||||
// already cached, so a transiently-failed family is retried instead of being
|
||||
// treated as fully resolved. Otherwise a domain with any cached record is left
|
||||
// to the stale-while-revalidate refresh path.
|
||||
func (m *Resolver) needsResolve(d domain.Domain) bool {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
if failedAt, ok := m.failedResolves[d]; ok {
|
||||
return time.Since(failedAt) >= refreshBackoff
|
||||
}
|
||||
|
||||
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
if _, ok := m.records[q]; ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Resolver) markResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
m.failedResolves[d] = time.Now()
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Resolver) clearResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
delete(m.failedResolves, d)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// pruneFailedResolves drops failure markers for domains no longer present in
|
||||
// the server-domains set, keeping the map bounded to the current set (a
|
||||
// failed-only domain has no cached record, so RemoveDomain never sees it).
|
||||
func (m *Resolver) pruneFailedResolves(domains domain.List) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
for d := range m.failedResolves {
|
||||
if !slices.Contains(domains, d) {
|
||||
delete(m.failedResolves, d)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ type fakeChain struct {
|
||||
mu sync.Mutex
|
||||
calls map[string]int
|
||||
answers map[string][]dns.RR
|
||||
qErr map[string]error
|
||||
err error
|
||||
hasRoot bool
|
||||
onLookup func()
|
||||
@@ -30,6 +31,7 @@ func newFakeChain() *fakeChain {
|
||||
return &fakeChain{
|
||||
calls: map[string]int{},
|
||||
answers: map[string][]dns.RR{},
|
||||
qErr: map[string]error{},
|
||||
hasRoot: true,
|
||||
}
|
||||
}
|
||||
@@ -47,6 +49,9 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
|
||||
f.calls[key]++
|
||||
answers := f.answers[key]
|
||||
err := f.err
|
||||
if err == nil {
|
||||
err = f.qErr[key]
|
||||
}
|
||||
onLookup := f.onLookup
|
||||
f.mu.Unlock()
|
||||
|
||||
@@ -75,6 +80,12 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
|
||||
}
|
||||
|
||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
183
client/internal/dns/mgmt/mgmt_resolve_test.go
Normal file
183
client/internal/dns/mgmt/mgmt_resolve_test.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// A domain already in the cache must not be re-resolved on a subsequent server
|
||||
// domains update; it is left to the stale-while-revalidate refresh path.
|
||||
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must resolve the domain")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"cached domain must not be re-resolved on a subsequent update")
|
||||
}
|
||||
|
||||
// New domains in a single update must resolve concurrently rather than serially.
|
||||
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
|
||||
var inflight, maxInflight atomic.Int32
|
||||
chain.onLookup = func() {
|
||||
n := inflight.Add(1)
|
||||
for {
|
||||
old := maxInflight.Load()
|
||||
if n <= old || maxInflight.CompareAndSwap(old, n) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
inflight.Add(-1)
|
||||
}
|
||||
|
||||
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
|
||||
for _, d := range relays {
|
||||
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
|
||||
}
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
start := time.Now()
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
|
||||
require.NoError(t, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
|
||||
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
|
||||
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
|
||||
}
|
||||
|
||||
// A domain that fails to resolve must not be retried on every update; the
|
||||
// failure backoff suppresses re-resolution until it expires.
|
||||
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must attempt the resolve")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"failed resolve must back off and not retry on the next update")
|
||||
}
|
||||
|
||||
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
|
||||
// the same host) must be resolved once per update, not once per occurrence.
|
||||
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{
|
||||
Stuns: []domain.Domain{"dup.example.com"},
|
||||
Turns: []domain.Domain{"dup.example.com"},
|
||||
}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
|
||||
"a domain appearing under multiple server-domain types must resolve once")
|
||||
}
|
||||
|
||||
// A failure marker must be dropped once its domain leaves the server-domains set
|
||||
// so the map stays bounded to the current set.
|
||||
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, marked, "failed resolve must be recorded")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
|
||||
}
|
||||
|
||||
// When one family hard-errors while the other resolves, the domain is cached
|
||||
// for the working family but recorded as incomplete so the failed family is
|
||||
// retried under backoff instead of being treated as fully resolved forever.
|
||||
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
|
||||
d := domain.Domain("relay.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, aCached, "the working family must still be cached")
|
||||
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
|
||||
|
||||
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
|
||||
|
||||
r.mutex.Lock()
|
||||
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
|
||||
r.mutex.Unlock()
|
||||
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
|
||||
}
|
||||
|
||||
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
|
||||
// not a failure: the domain must not be marked for retry, otherwise it would be
|
||||
// re-resolved on every sync.
|
||||
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
|
||||
d := domain.Domain("v4only.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
|
||||
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
|
||||
}
|
||||
@@ -207,3 +207,35 @@ func FormatAnswers(answers []dns.RR) string {
|
||||
}
|
||||
return "[" + strings.Join(parts, ", ") + "]"
|
||||
}
|
||||
|
||||
// StripOPT removes any OPT pseudo-RRs from the message's Extra section. Per
|
||||
// RFC 6891 a responder must not include an OPT RR toward a client that did not
|
||||
// advertise EDNS0.
|
||||
func StripOPT(msg *dns.Msg) {
|
||||
if len(msg.Extra) == 0 {
|
||||
return
|
||||
}
|
||||
out := msg.Extra[:0]
|
||||
for _, rr := range msg.Extra {
|
||||
if _, ok := rr.(*dns.OPT); ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
msg.Extra = out
|
||||
}
|
||||
|
||||
// ExtractEDE returns the first Extended DNS Error (RFC 8914) option carried in
|
||||
// the message, if present.
|
||||
func ExtractEDE(msg *dns.Msg) (*dns.EDNS0_EDE, bool) {
|
||||
opt := msg.IsEdns0()
|
||||
if opt == nil {
|
||||
return nil, false
|
||||
}
|
||||
for _, o := range opt.Option {
|
||||
if ede, ok := o.(*dns.EDNS0_EDE); ok {
|
||||
return ede, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -120,3 +120,42 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestStripOPT(t *testing.T) {
|
||||
rm := &dns.Msg{
|
||||
Extra: []dns.RR{
|
||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
}
|
||||
StripOPT(rm)
|
||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||
}
|
||||
|
||||
func TestExtractEDE(t *testing.T) {
|
||||
t.Run("no edns", func(t *testing.T) {
|
||||
_, ok := ExtractEDE(&dns.Msg{})
|
||||
assert.False(t, ok, "message without OPT has no EDE")
|
||||
})
|
||||
|
||||
t.Run("edns without ede", func(t *testing.T) {
|
||||
rm := &dns.Msg{}
|
||||
rm.SetEdns0(4096, false)
|
||||
_, ok := ExtractEDE(rm)
|
||||
assert.False(t, ok, "OPT without EDE option returns false")
|
||||
})
|
||||
|
||||
t.Run("with ede", func(t *testing.T) {
|
||||
rm := &dns.Msg{}
|
||||
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: 49152, ExtraText: "upstream timeout"})
|
||||
rm.Extra = append(rm.Extra, opt)
|
||||
|
||||
ede, ok := ExtractEDE(rm)
|
||||
assert.True(t, ok, "EDE option should be found")
|
||||
assert.Equal(t, uint16(49152), ede.InfoCode)
|
||||
assert.Equal(t, "upstream timeout", ede.ExtraText)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -457,7 +457,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
|
||||
// problems: fail over for a better answer but keep the upstream healthy.
|
||||
if code, ok := nonRetryableEDE(rm); ok {
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
resutil.StripOPT(rm)
|
||||
}
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
||||
}
|
||||
@@ -466,7 +466,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
|
||||
}
|
||||
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
resutil.StripOPT(rm)
|
||||
}
|
||||
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
||||
@@ -523,22 +523,6 @@ func upstreamUDPSize() uint16 {
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
||||
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
|
||||
// the response complies with RFC 6891 when the client did not advertise EDNS0.
|
||||
func stripOPT(rm *dns.Msg) {
|
||||
if len(rm.Extra) == 0 {
|
||||
return
|
||||
}
|
||||
out := rm.Extra[:0]
|
||||
for _, rr := range rm.Extra {
|
||||
if _, ok := rr.(*dns.OPT); ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
rm.Extra = out
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||
|
||||
@@ -985,19 +985,6 @@ func TestEDEName(t *testing.T) {
|
||||
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
|
||||
}
|
||||
|
||||
func TestStripOPT(t *testing.T) {
|
||||
rm := &dns.Msg{
|
||||
Extra: []dns.RR{
|
||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
}
|
||||
stripOPT(rm)
|
||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
||||
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
@@ -26,6 +26,15 @@ import (
|
||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||
const upstreamTimeout = 15 * time.Second
|
||||
|
||||
// EDE info codes the forwarder emits on upstream failures so the querying
|
||||
// client can see the reason without inspecting this peer's logs. They live in
|
||||
// the RFC 8914 Private Use range (49152-65535); the Go resolver never exposes a
|
||||
// real upstream EDE here, so these cannot collide with a genuine code.
|
||||
const (
|
||||
edeNetbirdUpstreamTimeout uint16 = 49152
|
||||
edeNetbirdUpstreamFailure uint16 = 49153
|
||||
)
|
||||
|
||||
type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
@@ -220,7 +229,7 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
|
||||
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||
if result.Err != nil {
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -333,6 +342,7 @@ func (f *DNSForwarder) handleDNSError(
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
result resutil.LookupResult,
|
||||
reqHasEdns bool,
|
||||
startTime time.Time,
|
||||
) {
|
||||
qType := question.Qtype
|
||||
@@ -374,6 +384,10 @@ func (f *DNSForwarder) handleDNSError(
|
||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||
}
|
||||
|
||||
if reqHasEdns {
|
||||
attachEDE(resp, edeCodeFor(dnsErr), edeText(dnsErr))
|
||||
}
|
||||
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
}
|
||||
|
||||
@@ -414,3 +428,33 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
|
||||
|
||||
return selectedResId, matches
|
||||
}
|
||||
|
||||
// edeCodeFor maps an upstream lookup error to the NetBird EDE info code.
|
||||
func edeCodeFor(dnsErr *net.DNSError) uint16 {
|
||||
if dnsErr != nil && dnsErr.IsTimeout {
|
||||
return edeNetbirdUpstreamTimeout
|
||||
}
|
||||
return edeNetbirdUpstreamFailure
|
||||
}
|
||||
|
||||
// edeText builds the EDE extra-text describing the class of upstream failure.
|
||||
// It deliberately omits the upstream server address, which may be an internal
|
||||
// resolver and is exposed to any client permitted to use the route; the full
|
||||
// detail stays in the forwarder's local log.
|
||||
func edeText(dnsErr *net.DNSError) string {
|
||||
if dnsErr != nil && dnsErr.IsTimeout {
|
||||
return "netbird forwarder: upstream timeout"
|
||||
}
|
||||
return "netbird forwarder: upstream failure"
|
||||
}
|
||||
|
||||
// attachEDE adds an Extended DNS Error (RFC 8914) option to the response,
|
||||
// creating the OPT pseudo-record if the response does not already carry one.
|
||||
func attachEDE(resp *dns.Msg, code uint16, text string) {
|
||||
opt := resp.IsEdns0()
|
||||
if opt == nil {
|
||||
resp.SetEdns0(dns.DefaultMsgSize, false)
|
||||
opt = resp.IsEdns0()
|
||||
}
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: code, ExtraText: text})
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -617,6 +618,85 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lookupErr error
|
||||
reqEdns bool
|
||||
wantEDE bool
|
||||
wantCode uint16
|
||||
wantTextHas string
|
||||
}{
|
||||
{
|
||||
name: "timeout with edns0",
|
||||
lookupErr: &net.DNSError{Err: "i/o timeout", Server: "10.0.0.53:53", IsTimeout: true},
|
||||
reqEdns: true,
|
||||
wantEDE: true,
|
||||
wantCode: edeNetbirdUpstreamTimeout,
|
||||
wantTextHas: "netbird forwarder: upstream timeout",
|
||||
},
|
||||
{
|
||||
name: "server failure with edns0",
|
||||
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||
reqEdns: true,
|
||||
wantEDE: true,
|
||||
wantCode: edeNetbirdUpstreamFailure,
|
||||
wantTextHas: "netbird forwarder: upstream failure",
|
||||
},
|
||||
{
|
||||
name: "no edns0 in request omits ede",
|
||||
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||
reqEdns: false,
|
||||
wantEDE: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
require.NoError(t, err)
|
||||
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
|
||||
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||
Return([]netip.Addr(nil), tt.lookupErr).Once()
|
||||
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion("example.com.", dns.TypeA)
|
||||
if tt.reqEdns {
|
||||
query.SetEdns0(dns.DefaultMsgSize, false)
|
||||
}
|
||||
|
||||
var writtenResp *dns.Msg
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writtenResp = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
mockResolver.AssertExpectations(t)
|
||||
|
||||
require.NotNil(t, writtenResp, "expected a response")
|
||||
assert.Equal(t, dns.RcodeServerFailure, writtenResp.Rcode, "upstream failure must be SERVFAIL")
|
||||
|
||||
ede, ok := resutil.ExtractEDE(writtenResp)
|
||||
if !tt.wantEDE {
|
||||
assert.False(t, ok, "response must not carry EDE")
|
||||
return
|
||||
}
|
||||
require.True(t, ok, "response must carry EDE")
|
||||
assert.Equal(t, tt.wantCode, ede.InfoCode, "EDE info code")
|
||||
assert.Contains(t, ede.ExtraText, tt.wantTextHas, "EDE extra-text")
|
||||
assert.NotContains(t, ede.ExtraText, "10.0.0.53", "must not leak upstream server address")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
// Test that large UDP responses are truncated with TC bit set
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
@@ -86,6 +86,8 @@ const (
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
|
||||
var ErrEngineAlreadyStarted = errors.New("engine already started")
|
||||
|
||||
type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIfaceName string
|
||||
@@ -199,6 +201,8 @@ type Engine struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
started bool
|
||||
|
||||
wgInterface WGIface
|
||||
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
@@ -279,9 +283,15 @@ func NewEngine(
|
||||
services EngineServices,
|
||||
mobileDep MobileDependency,
|
||||
) *Engine {
|
||||
// The engine is single-use: a fresh instance is built per connection
|
||||
// cycle (see Client.run), so the run context is created once here rather
|
||||
// than in Start.
|
||||
ctx, cancel := context.WithCancel(clientCtx)
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
@@ -314,8 +324,34 @@ func (e *Engine) Stop() error {
|
||||
log.Debugf("tried stopping engine that is nil")
|
||||
return nil
|
||||
}
|
||||
e.cancel()
|
||||
e.syncMsgMux.Lock()
|
||||
|
||||
e.stopLocked()
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopLocked tears down everything Start may have brought up, in the order
|
||||
// teardown requires (DNS before the interface goes down, flow manager after).
|
||||
// The caller must hold syncMsgMux. It is shared by Stop and by Start's failure
|
||||
// path, so a partially-initialized engine is cleaned up the same way; every
|
||||
// step is nil-guarded. It does not wait on shutdownWg — the caller does that
|
||||
// after releasing the lock, since the goroutines also take syncMsgMux.
|
||||
func (e *Engine) stopLocked() {
|
||||
if e.connMgr != nil {
|
||||
e.connMgr.Close()
|
||||
}
|
||||
@@ -366,10 +402,6 @@ func (e *Engine) Stop() error {
|
||||
// so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
e.jobExecutorWG.Wait() // block until job goroutines finish
|
||||
|
||||
e.close()
|
||||
@@ -388,21 +420,6 @@ func (e *Engine) Stop() error {
|
||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
||||
@@ -440,18 +457,38 @@ func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
|
||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||
// Connections to remote peers are not established here.
|
||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
|
||||
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) (err error) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if err := iface.ValidateMTU(e.config.MTU); err != nil {
|
||||
// The engine is single-use. Reject a duplicate start and a start on an
|
||||
// already-stopped engine (run context cancelled).
|
||||
if e.started {
|
||||
return ErrEngineAlreadyStarted
|
||||
}
|
||||
|
||||
if ctxErr := e.ctx.Err(); ctxErr != nil {
|
||||
return fmt.Errorf("engine already stopped: %w", ctxErr)
|
||||
}
|
||||
|
||||
e.started = true
|
||||
|
||||
// Tear down any partially-initialized state on a failed start. Cancel the
|
||||
// run context first so goroutines started before the failure (connMgr,
|
||||
// srWatcher, monitors) unwind, then stopLocked mirrors Stop's teardown (we
|
||||
// already hold syncMsgMux), cleaning up route/DNS/flow/state managers too,
|
||||
// not just what close() covers.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
e.cancel()
|
||||
e.stopLocked()
|
||||
}
|
||||
}()
|
||||
|
||||
if err = iface.ValidateMTU(e.config.MTU); err != nil {
|
||||
return fmt.Errorf("invalid MTU configuration: %w", err)
|
||||
}
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
@@ -485,13 +522,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("read initial settings: %w", err)
|
||||
}
|
||||
|
||||
dnsServer, err := e.newDnsServer(dnsConfig)
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("create dns server: %w", err)
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
@@ -526,7 +561,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
if err = e.wgInterfaceCreate(); err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
e.close()
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
}
|
||||
|
||||
@@ -535,7 +569,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
}
|
||||
|
||||
if err := e.createFirewall(); err != nil {
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -547,7 +580,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.udpMux, err = e.wgInterface.Up()
|
||||
if err != nil {
|
||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||
e.close()
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
@@ -572,9 +604,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.acl = acl.NewDefaultManager(e.firewall)
|
||||
}
|
||||
|
||||
err = e.dnsServer.Initialize()
|
||||
if err != nil {
|
||||
e.close()
|
||||
if err := e.dnsServer.Initialize(); err != nil {
|
||||
return fmt.Errorf("initialize dns server: %w", err)
|
||||
}
|
||||
|
||||
@@ -586,7 +616,9 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start(peer.IsForceRelayed())
|
||||
|
||||
e.receiveSignalEvents()
|
||||
if err = e.receiveSignalEvents(); err != nil {
|
||||
return err
|
||||
}
|
||||
e.receiveManagementEvents()
|
||||
e.receiveJobEvents()
|
||||
|
||||
@@ -638,7 +670,6 @@ func (e *Engine) createFirewall() error {
|
||||
|
||||
func (e *Engine) initFirewall() error {
|
||||
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("set firewall: %w", err)
|
||||
}
|
||||
|
||||
@@ -1698,7 +1729,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
}
|
||||
|
||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||
func (e *Engine) receiveSignalEvents() {
|
||||
func (e *Engine) receiveSignalEvents() error {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
@@ -1769,7 +1800,12 @@ func (e *Engine) receiveSignalEvents() {
|
||||
}
|
||||
}()
|
||||
|
||||
e.signal.WaitStreamConnected()
|
||||
// todo: consider to remove this blocker. I do not see benefit to block the Start operations
|
||||
e.signal.WaitStreamConnected(e.ctx)
|
||||
if err := e.ctx.Err(); err != nil {
|
||||
return fmt.Errorf("wait for signal stream: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
|
||||
@@ -247,7 +247,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
@@ -426,7 +426,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
@@ -638,7 +638,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
// feed updates to Engine via mocked Management client
|
||||
@@ -817,7 +817,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
@@ -1024,7 +1024,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
|
||||
@@ -251,6 +251,14 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
// Advertise EDNS0 to the forwarder so it may return an Extended DNS Error
|
||||
// describing why a lookup failed. The OPT is stripped from the reply when
|
||||
// the original client did not request EDNS0.
|
||||
hadEdns := r.IsEdns0() != nil
|
||||
if !hadEdns {
|
||||
r.SetEdns0(dns.DefaultMsgSize, false)
|
||||
}
|
||||
|
||||
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
@@ -260,6 +268,13 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
if ede, ok := resutil.ExtractEDE(reply); ok {
|
||||
resutil.SetMeta(w, "ede", fmt.Sprintf("%d %s", ede.InfoCode, ede.ExtraText))
|
||||
}
|
||||
if !hadEdns {
|
||||
resutil.StripOPT(reply)
|
||||
}
|
||||
|
||||
resutil.SetMeta(w, "peer", peerKey)
|
||||
|
||||
reply.Id = r.Id
|
||||
|
||||
@@ -36,6 +36,7 @@ type URLOpener interface {
|
||||
// Auth can register or login new client
|
||||
type Auth struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
config *profilemanager.Config
|
||||
cfgPath string
|
||||
}
|
||||
@@ -51,8 +52,19 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use a cancellable context so Stop() can abort an in-progress interactive
|
||||
// login. The PKCE flow's WaitToken blocks (and keeps its loopback HTTP server
|
||||
// bound to a port) until the OAuth callback arrives or the flow expires;
|
||||
// cancelling the context unblocks WaitToken, which then shuts that server down
|
||||
// and frees the port for the next login attempt. iOS runs login in the main-app
|
||||
// process (decoupled from the network extension), so without this the server
|
||||
// lingers after the user dismisses the browser and the next connect stalls
|
||||
// trying to bind the same port.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &Auth{
|
||||
ctx: context.Background(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
config: cfg,
|
||||
cfgPath: cfgPath,
|
||||
}, nil
|
||||
@@ -60,12 +72,24 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||
|
||||
// NewAuthWithConfig instantiate Auth based on existing config
|
||||
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Auth{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// Stop aborts an in-progress interactive login started via Login/LoginWithDeviceName.
|
||||
// It cancels the auth context, which unblocks the PKCE WaitToken and shuts down its
|
||||
// loopback HTTP server, freeing the redirect port. Safe to call multiple times and
|
||||
// safe to call when no login is running.
|
||||
func (a *Auth) Stop() {
|
||||
if a.cancel != nil {
|
||||
a.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
||||
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
||||
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||
|
||||
@@ -993,6 +993,10 @@ func (s *Server) cleanupConnection() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: consider calling s.connectClient.Stop() instead of engine.Stop().
|
||||
// actCancel() lets the run loop stop the engine too, so both stop it
|
||||
// concurrently; ConnectClient.Stop cancels and waits for the run loop,
|
||||
// making the run loop the sole owner of engine shutdown.
|
||||
if engine != nil {
|
||||
if err := engine.Stop(); err != nil {
|
||||
return err
|
||||
|
||||
@@ -418,7 +418,14 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
||||
case args.showProfiles:
|
||||
s.showProfilesUI()
|
||||
case args.showQuickActions:
|
||||
s.showQuickActionsUI()
|
||||
// Suppress the on-boot Quick Actions popup when the daemon
|
||||
// reports DisableAutoConnect=true — that flag carries both the
|
||||
// user's "Connect on Startup = off" preference AND any MDM-
|
||||
// enforced override (applyMDMPolicy writes the policy value
|
||||
// into the same Config field). See netbirdio/netbird#5744.
|
||||
if !s.disableAutoConnectFromDaemon() {
|
||||
s.showQuickActionsUI()
|
||||
}
|
||||
case args.showUpdate:
|
||||
s.showUpdateProgress(ctx, args.showUpdateVersion)
|
||||
}
|
||||
@@ -1338,6 +1345,40 @@ func (s *serviceClient) getFeatures() (*proto.GetFeaturesResponse, error) {
|
||||
return features, nil
|
||||
}
|
||||
|
||||
// disableAutoConnectFromDaemon returns true when the daemon reports
|
||||
// the active profile has DisableAutoConnect=true. Used by the
|
||||
// --quick-actions startup path to suppress the on-boot popup when the
|
||||
// user (or an MDM admin) opted out of auto-connecting; both cases
|
||||
// converge on the same Config field because applyMDMPolicy writes the
|
||||
// policy value into it. Returns false on any RPC / lookup failure so a
|
||||
// daemon hiccup does not silently swallow the popup.
|
||||
func (s *serviceClient) disableAutoConnectFromDaemon() bool {
|
||||
activeProf, err := s.profileManager.GetActiveProfile()
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get active profile: %v", err)
|
||||
return false
|
||||
}
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get current user: %v", err)
|
||||
return false
|
||||
}
|
||||
conn, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get daemon client: %v", err)
|
||||
return false
|
||||
}
|
||||
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: GetConfig RPC: %v", err)
|
||||
return false
|
||||
}
|
||||
return srvCfg.GetDisableAutoConnect()
|
||||
}
|
||||
|
||||
// getSrvConfig from the service to show it in the settings window.
|
||||
func (s *serviceClient) getSrvConfig() {
|
||||
s.managementURL = profilemanager.DefaultManagementURL
|
||||
|
||||
109
docs/agent-networks/00-overview.md
Normal file
109
docs/agent-networks/00-overview.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# Agent Networks — overview
|
||||
|
||||
Single-entry point. Feature scope, the module map, and the cross-cutting
|
||||
topics worth keeping in mind, with links into every per-module guide.
|
||||
|
||||
## TL;DR
|
||||
|
||||
Agent Networks introduces an **LLM-aware reverse-proxy middleware system**
|
||||
plus **account-level controls** (budget rules, log collection toggles,
|
||||
PII redaction). The management server synthesises a per-peer middleware
|
||||
chain that the proxy executes on every LLM request; the chain enforces
|
||||
quotas, injects identity, redacts PII, parses tokens/cost, and emits
|
||||
access-log entries. The dashboard exposes the surface as a single **AI
|
||||
Observability** page with four tabs.
|
||||
|
||||
- **Backend** lives in this repo, primarily under
|
||||
`management/server/agentnetwork`, `proxy/internal/middleware`, and
|
||||
`proxy/internal/llm`, with wire contracts in `shared/management`.
|
||||
- **Dashboard** lives in the dashboard repo under
|
||||
`src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`.
|
||||
|
||||
## Reading order
|
||||
|
||||
| # | Doc | Why |
|
||||
|---|-----|-----|
|
||||
| 1 | [01-end-to-end-flows.md](01-end-to-end-flows.md) | Get the three big diagrams in your head first. |
|
||||
| 2 | [modules/10-shared-api.md](modules/10-shared-api.md) | Wire contracts — every other module either produces or consumes these. |
|
||||
| 3 | [modules/21-management-agentnetwork.md](modules/21-management-agentnetwork.md) | The largest module; everything the proxy executes originates here. |
|
||||
| 4 | [modules/30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md) | The generic plugin system on the proxy side. |
|
||||
| 5 | [modules/31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md) | The 8 LLM middlewares that ride on the framework. |
|
||||
| 6 | Everything else in any order. | |
|
||||
|
||||
## Module map
|
||||
|
||||
11 modules. Each is described in detail in its own file under
|
||||
[`modules/`](modules/).
|
||||
|
||||
| # | Module | Risk | BC impact |
|
||||
|---|--------|------|-----------|
|
||||
| 10 | [shared/api](modules/10-shared-api.md) — proto + OpenAPI | Low | Additive only |
|
||||
| 20 | [management/store](modules/20-management-store.md) — SQL persistence | Medium | Auto-migrate (additive) |
|
||||
| 21 | [management/agentnetwork](modules/21-management-agentnetwork.md) — domain layer + synthesizer | **High** | Additive |
|
||||
| 22 | [management/handlers + wiring](modules/22-management-handlers-wiring.md) — HTTP API + gRPC delivery | Medium | Additive |
|
||||
| 30 | [proxy/middleware-framework](modules/30-proxy-middleware-framework.md) — generic plugin system | High | Additive |
|
||||
| 31 | [proxy/middleware-builtin](modules/31-proxy-middleware-builtin.md) — 8 LLM middlewares | High | Additive |
|
||||
| 32 | [proxy/llm-parsers](modules/32-proxy-llm-parsers.md) — SDK adapters + pricing | Medium | Additive |
|
||||
| 33 | [proxy/runtime](modules/33-proxy-runtime.md) — translate + serve + access-log | High | Additive (touches hot path) |
|
||||
| 40 | [dashboard](modules/40-dashboard.md) — UI for everything above | Medium | Sidebar reshape |
|
||||
| 50 | [path-routed-providers](modules/50-path-routed-providers.md) — Vertex AI + Bedrock | Medium | Additive (new catalog entries) |
|
||||
|
||||
The largest and highest-risk module is `management/agentnetwork`: it is
|
||||
the single writer of the middleware chain the proxy executes.
|
||||
|
||||
## Cross-cutting topics
|
||||
|
||||
These are the items most likely to bite production. Each is fully
|
||||
documented in the linked module guide.
|
||||
|
||||
1. **Capture-pointer semantics** (`*bool` for `capture_prompt` and
|
||||
`capture_completion`): nil = legacy emit, false = suppress, true =
|
||||
emit. nil-vs-false must be handled at every JSON hop. See
|
||||
[21-management-agentnetwork.md](modules/21-management-agentnetwork.md)
|
||||
and [31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md).
|
||||
2. **`ProxyMapping.Private` preservation** on per-proxy live updates.
|
||||
Failure mode: `auth` skips `ValidateTunnelPeer` →
|
||||
`CapturedData.UserGroups` empty → `llm_router` denies. See
|
||||
[33-proxy-runtime.md](modules/33-proxy-runtime.md).
|
||||
3. **respInput carrying `UserEmail`/`UserGroups`/`UserGroupNames` onto
|
||||
the response leg** in `reverseproxy.go`. Load-bearing wire that lets
|
||||
`llm_limit_record` ship non-empty `group_ids` on `RecordLLMUsage`. See
|
||||
[33-proxy-runtime.md](modules/33-proxy-runtime.md).
|
||||
4. **Min-wins all-must-pass budget rule semantics**. Every matching
|
||||
rule's remaining quota must be > 0 for the request to proceed; one
|
||||
exhausted rule blocks the whole call. Documented in
|
||||
[21-management-agentnetwork.md](modules/21-management-agentnetwork.md)
|
||||
and the `llm_limit_check` middleware in
|
||||
[31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md).
|
||||
5. **body-tap memory bounds**: per-direction 1 MiB cap, shared 256 MiB
|
||||
budget, `LimitReader(r.Body, limit+1)` for truncation detection with
|
||||
`replayReadCloser` fallback so upstream still sees the full body.
|
||||
`cloneInputFor` deep-copies the body up to 16 times per chain — a
|
||||
perf hot-spot. See
|
||||
[30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md).
|
||||
6. **UpstreamRewrite.AuthHeader bypasses the header denylist**
|
||||
deliberately. The runtime consumer only unpacks it via the
|
||||
trusted upstream-build path. See
|
||||
[30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md).
|
||||
7. **`disable_access_log` default-false semantics**: the synth target
|
||||
sets it true, all other targets leave it false. See
|
||||
[10-shared-api.md](modules/10-shared-api.md).
|
||||
8. **String-typed `decision` / `deny_code`** on
|
||||
`CheckLLMPolicyLimitsResponse` — would benefit from enum pinning
|
||||
before external consumers integrate. See
|
||||
[10-shared-api.md](modules/10-shared-api.md).
|
||||
|
||||
## Explicit non-goals
|
||||
|
||||
- **Reaper / GC pass over stale synth services** — designed but cut from
|
||||
scope.
|
||||
- **URL-sync for tab state on AI Observability** — read path is wired
|
||||
(`?tab=`) but write path isn't. Future work.
|
||||
- **CI golden-file regen-and-diff for `types.gen.go` /
|
||||
`proxy_service.pb.go`** — would catch codegen drift; not yet in place.
|
||||
|
||||
## Where to read the code
|
||||
|
||||
Per-module file scopes are listed in each module guide. Behaviour is
|
||||
covered by Go tests co-located with each package (and an end-to-end
|
||||
chain integration test under `proxy/internal/proxy`).
|
||||
217
docs/agent-networks/01-end-to-end-flows.md
Normal file
217
docs/agent-networks/01-end-to-end-flows.md
Normal file
@@ -0,0 +1,217 @@
|
||||
# End-to-end flows
|
||||
|
||||
Three cross-module mermaid diagrams. Each per-module guide repeats the
|
||||
slice that's relevant to its own scope — these are the canonical
|
||||
top-down views.
|
||||
|
||||
- [Flow A — Config → runtime (synth + deliver)](#flow-a--config--runtime-synth--deliver)
|
||||
- [Flow B — Request lifecycle through the LLM chain](#flow-b--request-lifecycle-through-the-llm-chain)
|
||||
- [Flow C — Budget rule feedback loop](#flow-c--budget-rule-feedback-loop)
|
||||
|
||||
---
|
||||
|
||||
## Flow A — Config → runtime (synth + deliver)
|
||||
|
||||
How an operator's change to a Provider, Policy, Guardrail, Budget Rule,
|
||||
or Settings record ends up as live middleware on a peer's proxy.
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
actor Op as Operator
|
||||
participant UI as Dashboard
|
||||
participant HTTP as management/handlers
|
||||
participant Mgr as agentnetwork.Manager
|
||||
participant Store as management/store (SQL)
|
||||
participant Ctl as network_map.Controller
|
||||
participant Synth as agentnetwork.SynthesizeServices
|
||||
participant Grpc as management gRPC
|
||||
participant Proxy as netbird-proxy
|
||||
participant Xlate as middleware_translate
|
||||
participant Chain as middleware.Chain
|
||||
|
||||
Op->>UI: edit provider/policy/budget/settings
|
||||
UI->>HTTP: REST PUT/POST /api/agent-network/*
|
||||
HTTP->>Mgr: SaveProvider / SavePolicy / SaveBudgetRule / SaveSettings
|
||||
Mgr->>Store: persist (gorm)
|
||||
Mgr-->>Ctl: account change event (Network-Map dirty)
|
||||
loop per connected peer
|
||||
Ctl->>Synth: SynthesizeServices(ctx, store, accountID)
|
||||
Synth->>Store: load providers, policies, guardrails, budget rules, settings
|
||||
Synth-->>Synth: build per-peer Service list
|
||||
Note over Synth: each Service has a middleware<br/>chain with capture_prompt /<br/>capture_completion / redact_pii<br/>baked from account settings
|
||||
Synth-->>Ctl: []rpservice.Service
|
||||
Ctl->>Grpc: NetworkMap push (services + middleware configs)
|
||||
end
|
||||
Grpc-->>Proxy: NetworkMap stream
|
||||
Proxy->>Xlate: translate proto MiddlewareConfig → runtime Spec
|
||||
Xlate->>Chain: register / replace per-service chain
|
||||
Note over Chain: chain replacement is live<br/>(no proxy restart, in-flight<br/>requests unaffected)
|
||||
```
|
||||
|
||||
**Notes on the diagram**
|
||||
|
||||
- The `network_map.Controller` synthesises on every push, not on a
|
||||
timer. A single config change costs O(connected peers × policies ×
|
||||
providers) per push. See [`modules/22-management-handlers-wiring.md`](modules/22-management-handlers-wiring.md).
|
||||
- `SynthesizeServices` is the single source of truth for the wire
|
||||
format the proxy executes. Anything the proxy does that the
|
||||
synthesiser didn't request is a bug. See
|
||||
[`modules/21-management-agentnetwork.md`](modules/21-management-agentnetwork.md).
|
||||
- The translate step (step 13) is the only place that knows the
|
||||
middleware-ID strings on the proxy side. It must reject unknown IDs;
|
||||
silently dropping middlewares would create a security gap (e.g.
|
||||
missing `llm_limit_check` ⇒ unbounded spend). See
|
||||
[`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md).
|
||||
|
||||
---
|
||||
|
||||
## Flow B — Request lifecycle through the LLM chain
|
||||
|
||||
What happens when an agent on the client peer sends a chat-completion /
|
||||
messages request through the synthesised reverse-proxy.
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
actor Agent as Agent (local)
|
||||
participant Px as netbird-proxy
|
||||
participant Auth as auth middleware
|
||||
participant Map as service-mapping
|
||||
participant Req as llm_request_parser
|
||||
participant Rt as llm_router
|
||||
participant Chk as llm_limit_check
|
||||
participant Inj as llm_identity_inject
|
||||
participant Grd as llm_guardrail
|
||||
participant Up as upstream LLM
|
||||
participant Resp as llm_response_parser
|
||||
participant Cost as cost_meter
|
||||
participant Rec as llm_limit_record
|
||||
participant Log as access-log
|
||||
participant MgmtGrpc as management gRPC
|
||||
|
||||
Agent->>Px: POST /v1/chat/completions (OpenAI / Anthropic)
|
||||
Px->>Auth: identify peer (user, groups)
|
||||
Auth->>Map: resolve service from Host + path
|
||||
Map-->>Req: dispatch chain in slot order
|
||||
|
||||
Req->>Req: parse body → provider, model, prompt, token estimate
|
||||
Note over Req: capture_prompt gates raw_prompt<br/>capture (nil = legacy emit,<br/>false = drop, true = emit)
|
||||
Req->>Rt: pass metadata
|
||||
Rt->>Chk: route to upstream candidate
|
||||
|
||||
Chk->>MgmtGrpc: CheckLLMPolicyLimits(provider, model, est_tokens, groups, user)
|
||||
MgmtGrpc-->>Chk: decision = allow / deny + deny_code
|
||||
alt decision == deny
|
||||
Chk-->>Log: emit access-log with deny_code<br/>(if EnableLogCollection)
|
||||
Chk-->>Agent: 429 (or 403 per deny_code)
|
||||
else decision == allow
|
||||
Chk->>Inj: continue
|
||||
Inj->>Inj: inject NetBird identity headers per provider config
|
||||
Inj->>Grd: continue
|
||||
Grd->>Grd: enforce model allowlist
|
||||
Grd->>Up: forward (over WireGuard)
|
||||
Up-->>Resp: response (JSON or SSE stream)
|
||||
Resp->>Resp: parse usage tokens, completion
|
||||
Note over Resp: capture_completion gates raw<br/>completion capture
|
||||
Resp->>Cost: tokens
|
||||
Cost->>Cost: lookup pricing.yaml + compute cost
|
||||
Cost->>Rec: tokens + cost
|
||||
Rec->>MgmtGrpc: RecordLLMUsage(provider, model, prompt_t, completion_t, cost, groups, user)
|
||||
Rec-->>Log: emit access-log entry<br/>(if EnableLogCollection)
|
||||
Log-->>Agent: 200 + body (streamed if SSE)
|
||||
end
|
||||
```
|
||||
|
||||
**Notes on the diagram**
|
||||
|
||||
- The chain runs in synth-defined order. Re-ordering middlewares
|
||||
changes invariants — `llm_limit_check` must precede `llm_router` so
|
||||
a denied request never hits upstream, and `llm_limit_record` must
|
||||
pair with `llm_limit_check` so a successful check is always recorded
|
||||
(or the rate-limit semantics break). See
|
||||
[`modules/31-proxy-middleware-builtin.md`](modules/31-proxy-middleware-builtin.md).
|
||||
- `llm_guardrail` is also where PII redaction happens
|
||||
(`redact_pii = settings.RedactPii`). Phones, emails, credit cards,
|
||||
PII names — see `redact.go` for the full set. See
|
||||
[`modules/31-proxy-middleware-builtin.md`](modules/31-proxy-middleware-builtin.md).
|
||||
- SSE streaming requires special handling on the response side; the
|
||||
parser must handle partial chunks without buffering the whole
|
||||
stream. See [`modules/32-proxy-llm-parsers.md`](modules/32-proxy-llm-parsers.md).
|
||||
- Access-log emission is gated on `settings.EnableLogCollection`. With
|
||||
it OFF, neither the deny nor the allow leg writes an entry — the
|
||||
chain still runs (budget rules are still enforced) but no audit trail
|
||||
is kept. See
|
||||
[`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md).
|
||||
|
||||
---
|
||||
|
||||
## Flow C — Budget rule feedback loop
|
||||
|
||||
How an account's budget rules tighten ceilings on every request and how
|
||||
consumption flows back into the dashboard.
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph Operator
|
||||
DashBud[Dashboard Budget Settings tab]
|
||||
end
|
||||
subgraph Mgmt[Management]
|
||||
Save[POST/PUT /api/agent-network/budget-rules]
|
||||
Store[(SQL store)]
|
||||
Synth[SynthesizeServices]
|
||||
Check[CheckLLMPolicyLimits RPC]
|
||||
Rec[RecordLLMUsage RPC]
|
||||
Cons[/api/agent-network/consumption]
|
||||
end
|
||||
subgraph Proxy[Proxy]
|
||||
Chk[llm_limit_check]
|
||||
RecMw[llm_limit_record]
|
||||
end
|
||||
subgraph DashView[Dashboard Budget Dashboard tab]
|
||||
Panel[AgentConsumptionPanel]
|
||||
end
|
||||
|
||||
DashBud -->|create / update rules| Save
|
||||
Save --> Store
|
||||
Store --> Synth
|
||||
Synth -->|push synth-services to peer| Proxy
|
||||
|
||||
Chk -->|per request| Check
|
||||
Check -->|aggregate matching rules<br/>min-wins all-must-pass| Store
|
||||
Check -->|allow / deny| Chk
|
||||
|
||||
RecMw -->|post-response| Rec
|
||||
Rec -->|tokens + cost + groups + user| Store
|
||||
|
||||
Store -->|read counters| Cons
|
||||
Cons --> Panel
|
||||
```
|
||||
|
||||
**Notes on the diagram**
|
||||
|
||||
- **min-wins all-must-pass** is the core semantic. A budget rule binds
|
||||
to (group set, user set) with a (window, ceiling). At check time,
|
||||
every rule that matches the caller is evaluated; if ANY rule has
|
||||
zero remaining quota the request is denied. This is the most
|
||||
surprising semantic for operators — see the invariants section of
|
||||
[`modules/21-management-agentnetwork.md`](modules/21-management-agentnetwork.md).
|
||||
- The proxy never makes its own budget decisions. It always asks
|
||||
management via `CheckLLMPolicyLimits` and reports back via
|
||||
`RecordLLMUsage`. This keeps account-wide accounting in one place
|
||||
and avoids per-proxy drift.
|
||||
- `RecordLLMUsage` must carry `group_ids` and `user_id` so the
|
||||
decrement hits the right rule(s). The wire that carries those
|
||||
fields onto the response leg is `respInput` in `reverseproxy.go`. See
|
||||
[`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md).
|
||||
- The dashboard's Budget Dashboard tab polls
|
||||
`/api/agent-network/consumption` — not gRPC, not WebSocket. Poll
|
||||
interval lives in `AgentConsumptionPanel.tsx`. See
|
||||
[`modules/40-dashboard.md`](modules/40-dashboard.md).
|
||||
|
||||
---
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Per-module guides: [`modules/`](modules/)
|
||||
- Overview + module map: [`00-overview.md`](00-overview.md)
|
||||
66
docs/agent-networks/README.md
Normal file
66
docs/agent-networks/README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# Agent Networks — architecture documentation
|
||||
|
||||
A self-contained set of documents describing the agent-networks feature:
|
||||
an LLM-aware reverse-proxy middleware system plus account-level controls
|
||||
(budget rules, log collection toggles, PII redaction). The management
|
||||
server synthesises a per-peer middleware chain that the proxy executes on
|
||||
every LLM request.
|
||||
|
||||
## What to read first
|
||||
|
||||
1. **[00-overview.md](00-overview.md)** — the single entry point. Feature
|
||||
scope, the module map, and the cross-cutting topics worth keeping in
|
||||
mind, with links to every per-module guide.
|
||||
2. **[01-end-to-end-flows.md](01-end-to-end-flows.md)** — three
|
||||
high-level mermaid diagrams: config-to-runtime synth/delivery,
|
||||
per-request lifecycle through the LLM chain, and the budget-rule
|
||||
feedback loop.
|
||||
3. **Per-module guides** under `modules/` — one file per package. Each
|
||||
describes the module boundary, the file-level layout, its own flow
|
||||
diagrams, the public contracts, the invariants it relies on, and the
|
||||
areas worth the closest attention.
|
||||
|
||||
## Directory layout
|
||||
|
||||
```
|
||||
docs/agent-networks/
|
||||
├── README.md # you are here
|
||||
├── 00-overview.md # feature summary + module map
|
||||
├── 01-end-to-end-flows.md # cross-module mermaid diagrams
|
||||
└── modules/
|
||||
├── 10-shared-api.md # proto + OpenAPI wire contracts
|
||||
├── 20-management-store.md # SQL persistence layer
|
||||
├── 21-management-agentnetwork.md # domain layer + synthesizer (largest)
|
||||
├── 22-management-handlers-wiring.md # HTTP API + gRPC delivery
|
||||
├── 30-proxy-middleware-framework.md # generic plugin system
|
||||
├── 31-proxy-middleware-builtin.md # 8 LLM-aware middlewares
|
||||
├── 32-proxy-llm-parsers.md # OpenAI/Anthropic/Bedrock SDKs + pricing
|
||||
├── 33-proxy-runtime.md # translate + serve + access-log
|
||||
├── 40-dashboard.md # UI for everything above (lives in the dashboard repo)
|
||||
└── 50-path-routed-providers.md # Vertex AI + Bedrock (path-routed, keyfile:: creds, /bedrock prefix)
|
||||
```
|
||||
|
||||
The `40-dashboard.md` module documents code that lives in the **dashboard
|
||||
repo**, not in this repo. The guide is co-located here so backend readers
|
||||
see the full picture in one place.
|
||||
|
||||
## How the per-module guides are structured
|
||||
|
||||
Every `modules/*.md` follows the same template so the docs are easy to
|
||||
scan:
|
||||
|
||||
- **Module boundary** — what this package owns; where it sits in the stack.
|
||||
- **Files** — path / role.
|
||||
- **Architecture & flow** — one or more mermaid diagrams.
|
||||
- **Public contracts** — function signatures, gRPC messages, JSON shapes.
|
||||
- **Invariants** — semantic guarantees the module relies on or enforces.
|
||||
- **Things to scrutinize** — split by correctness / security /
|
||||
concurrency / backward-compat / performance / observability.
|
||||
- **Test coverage** — the test files that lock down behaviour in this
|
||||
module.
|
||||
- **Known limitations / non-goals** — what is intentionally out of scope.
|
||||
- **Cross-references** — upstream/downstream module links + the
|
||||
end-to-end flow + the overview.
|
||||
|
||||
See [00-overview.md](00-overview.md) for the module map and the
|
||||
cross-cutting topics.
|
||||
105
docs/agent-networks/modules/10-shared-api.md
Normal file
105
docs/agent-networks/modules/10-shared-api.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# shared/api — wire contracts (proto + OpenAPI)
|
||||
|
||||
> **Risk level:** Medium — wire-format surface that every other module pins against; backward-compat hinges on field-number discipline more than on logic correctness.
|
||||
> **Backward-compat impact:** Additive only (new proto fields use unallocated numbers, new RPCs default to `Unimplemented`, new OpenAPI schemas/paths are append-only; no existing field/RPC/schema removed or renumbered).
|
||||
|
||||
## Module boundary
|
||||
This module owns the cross-process contract surface between management, proxy, and dashboard. Two artefacts: `shared/management/proto/proxy_service.proto` (management↔proxy gRPC) and `shared/management/http/api/openapi.yml` (dashboard/CLI↔management REST). Both have generated companions checked in (`proxy_service.pb.go`, `proxy_service_grpc.pb.go`, `types.gen.go`) which must travel in lockstep with their sources. `shared/management/status/error.go` is in scope only for the four new typed `NotFound` constructors that the new HTTP handlers return.
|
||||
|
||||
Everything downstream — `management/agentnetwork`, `management/server/http/handlers/*`, `proxy/internal/*`, the dashboard SDK — consumes these types verbatim. The concern here is wire stability and codegen reproducibility, not behaviour: behaviour is covered in the management and proxy module guides.
|
||||
|
||||
`management.proto` and `signalexchange.proto` are unchanged. `status/error.go` only receives four additive constructors (lines 208-227); no existing error types are reshaped.
|
||||
|
||||
## Files
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `shared/management/proto/proxy_service.proto` | Source of truth: 2 new RPCs, 1 new message group (`MiddlewareConfig` + slot enum), additive fields on `PathTargetOptions`, `AccessLog`, `RecordLLMUsageRequest` |
|
||||
| `shared/management/proto/proxy_service.pb.go` | Generated (protoc-gen-go) |
|
||||
| `shared/management/proto/proxy_service_grpc.pb.go` | Generated; adds `CheckLLMPolicyLimits` + `RecordLLMUsage` client/server stubs and `UnimplementedProxyServiceServer` defaults |
|
||||
| `shared/management/http/api/openapi.yml` | 15 new `AgentNetwork*` schemas, 9 new path groups under `/api/agent-network/*` |
|
||||
| `shared/management/http/api/types.gen.go` | Generated (oapi-codegen; see codegen note below) |
|
||||
| `shared/management/status/error.go` | Four `NotFound` constructors for the new resource kinds (lines 208-227) |
|
||||
|
||||
## Architecture & flow
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant Dash as Dashboard / CLI
|
||||
participant Mgmt as management (HTTP+gRPC)
|
||||
participant Px as proxy
|
||||
|
||||
Note over Dash,Mgmt: REST (OpenAPI / types.gen.go)
|
||||
Dash->>Mgmt: PUT /api/agent-network/providers (AgentNetworkProviderRequest)
|
||||
Dash->>Mgmt: PUT /api/agent-network/settings (AgentNetworkSettingsRequest)
|
||||
Dash->>Mgmt: GET /api/agent-network/consumption -> [AgentNetworkConsumption]
|
||||
|
||||
Note over Mgmt,Px: gRPC ProxyService (proxy_service.proto)
|
||||
Mgmt-->>Px: SyncMappingsResponse{ ProxyMapping.path[*].options.middlewares,<br/>agent_network, disable_access_log, capture_* }
|
||||
Px->>Mgmt: CheckLLMPolicyLimits(account, user, groups, provider, model)
|
||||
Mgmt-->>Px: decision=allow|deny + selected_policy_id + attribution_group_id + window_seconds
|
||||
Px->>Mgmt: RecordLLMUsage(account, user, group_id, group_ids, window_seconds, tokens, cost)
|
||||
Px->>Mgmt: SendAccessLog(AccessLog{ agent_network=true })
|
||||
```
|
||||
|
||||
The proto changes split into three independent slices: (1) **mapping enrichment** — `PathTargetOptions` grows fields 8-13 so management can ship middleware configs, capture limits, and the agent-network / log-suppression flags down to the proxy without a second RPC; (2) **two new request/response RPCs** (`CheckLLMPolicyLimits`, `RecordLLMUsage`) for per-LLM-request budget arbitration; (3) **observability tag** — `AccessLog.agent_network` so management can route logs to the right surface.
|
||||
|
||||
The OpenAPI side is a thin CRUD surface — every resource (`Provider`, `Policy`, `Guardrail`, `BudgetRule`, `Settings`) follows the same `GET-list / POST / GET / PUT / DELETE` pattern, plus a read-only `/consumption` listing and a catalog endpoint. The `*Request` variants drop server-controlled fields (id, timestamps). `AgentNetworkBudgetRule` deliberately reuses `AgentNetworkPolicyLimits` to keep wire-shape parity with policies.
|
||||
|
||||
## Public contracts added
|
||||
- gRPC RPCs (`proxy_service.proto:52-57`): `CheckLLMPolicyLimits(CheckLLMPolicyLimitsRequest) → CheckLLMPolicyLimitsResponse`, `RecordLLMUsage(RecordLLMUsageRequest) → RecordLLMUsageResponse`. Both unary; default `UnimplementedProxyServiceServer` returns `codes.Unimplemented` (`proxy_service_grpc.pb.go:283-289`).
|
||||
- New messages (`proxy_service.proto:145-175,448-502`): `MiddlewareConfig`, `MiddlewareSlot` enum, `CheckLLMPolicyLimitsRequest`/`Response`, `RecordLLMUsageRequest`/`Response`.
|
||||
- New `PathTargetOptions` fields 8-13 (`proxy_service.proto:124-140`): `capture_max_request_bytes`, `capture_max_response_bytes`, `capture_content_types`, `middlewares`, `agent_network`, `disable_access_log`. All default-false / zero; pre-existing fields 1-7 byte-for-byte unchanged.
|
||||
- `AccessLog.agent_network = 18` (`proxy_service.proto:258-261`).
|
||||
- `RecordLLMUsageRequest.group_ids = 8` (`proxy_service.proto:496-498`) — so the record path can fan out to every applicable budget rule's window without a re-lookup.
|
||||
- 15 new OpenAPI component schemas (`openapi.yml:5072-5829`): `AgentNetworkProvider[Request|Model]`, `AgentNetworkCatalog{Model,Provider,IdentityInjection,HeaderPairInjection,JSONMetadataInjection,ExtraHeader}`, `AgentNetworkPolicy[Request|TokenLimit|BudgetLimit|Limits]`, `AgentNetworkGuardrail[Checks|Request]`, `AgentNetworkConsumption`, `AgentNetworkSettings[Request]`, `AgentNetworkBudgetRule[Request]`.
|
||||
- 9 new path groups (`openapi.yml:12797-13460`): `/api/agent-network/{consumption,settings,budget-rules,budget-rules/{ruleId},catalog/providers,providers,providers/{providerId},policies,policies/{policyId},guardrails,guardrails/{guardrailId}}`.
|
||||
- Four typed NotFound errors (`shared/management/status/error.go:208-227`).
|
||||
|
||||
## Invariants
|
||||
- **Field-number monotonicity.** Every new proto field uses a previously-unallocated number in its message: `PathTargetOptions` 8-13 (was 1-7), `AccessLog` 18 (was 1-17), `RecordLLMUsageRequest` 8. `SendStatusUpdateRequest.inbound_listener = 50` (pre-existing) reserves 50+ for observability extensions, so 8 on `RecordLLMUsageRequest` doesn't conflict.
|
||||
- **Old proxies stay compatible.** Old management never sends `disable_access_log`/`middlewares`/`agent_network` (zero value → existing behaviour); old proxies that don't decode these fields just drop them silently (proto3 unknown-field semantics) — log emission stays on. No pre-existing field number changed: the proto change is insertions only.
|
||||
- **Old management stays compatible.** The two new RPCs are registered on the same `management.ProxyService` descriptor; old proxies hitting them get `codes.Unimplemented` from the unimplemented embed (`proxy_service_grpc.pb.go:283-289`), which is the same fallback pattern `SyncMappings` already documents (`proxy_service.proto:20-21`).
|
||||
- **OpenAPI shapes are append-only.** New schemas are placed at the end of `components.schemas` (line 5072+); new paths at the end of `paths` (line 12797+). No existing schema's `required` list, enum, or property type was changed.
|
||||
- **`*Request` vs response asymmetry.** Read shapes (`AgentNetworkProvider`, `AgentNetworkPolicy`, `AgentNetworkGuardrail`, `AgentNetworkSettings`, `AgentNetworkBudgetRule`) require `created_at`/`updated_at`; the matching `*Request` shapes do not — server fills them. `AgentNetworkProviderRequest.api_key` is write-only (`openapi.yml:5158-5161` "never returned in responses"); reviewers should confirm the response schema (5072-5138) actually omits `api_key`.
|
||||
|
||||
## Things to scrutinize
|
||||
### Correctness
|
||||
- `RecordLLMUsageRequest` carries both `group_id` (singular, the attribution group — field 3) and `group_ids` (plural, full membership — field 8). `b22d5a181` adds field 8 to drive account-budget fan-out; double-check that consumers can't accidentally key counters on the wrong one. Field comments at `proxy_service.proto:489-491` and `496-498` distinguish them but it's the kind of subtle thing a follow-up commit might collapse.
|
||||
- `PathTargetOptions.disable_access_log` is the only field whose default-false meaning **changes semantics** on the proxy side: false → log (status quo), true → suppress. Synthesizer sets `DisableAccessLog = !settings.EnableLogCollection`, so a missing/default settings row yields `EnableLogCollection=false → DisableAccessLog=true → suppressed`. Worth confirming downstream (`agentnetwork.synthesizer`) that operator-defined private services never inherit this flag — the proto field default protects them, but only if synth code is explicit.
|
||||
- `CheckLLMPolicyLimitsResponse.decision` is a free-form `string` (`proxy_service.proto:471`) rather than an enum. Only documented values are "allow" / "deny". An enum would prevent typo drift; consider before this RPC ships to external consumers.
|
||||
- `deny_code` (`proxy_service.proto:478-481`) is documented as "a stable label" but is also a free string. Pin the allowed set somewhere observable to the proxy.
|
||||
|
||||
### Security
|
||||
- `AgentNetworkProvider.api_key` MUST be write-only. Schema split (request has it at line 5158; response omits it) looks correct, but a regression here leaks the upstream provider credential to every dashboard reader. Check that the handler explicitly zeros it on the response path.
|
||||
- `extra_values` / `identity_header_*` headers on `AgentNetworkProvider` get stamped onto upstream requests. Description at `openapi.yml:5099` says "values not declared by the catalog are ignored at synth time" — a contract this module documents but the synthesizer must enforce. Confirm the synth module honours it.
|
||||
- Cluster + subdomain on `AgentNetworkSettings` are documented immutable (`openapi.yml:5686-5694`) and the `AgentNetworkSettingsRequest` (lines 5733-5752) doesn't accept them. Verify the `PUT /api/agent-network/settings` handler can't be tricked by extra JSON keys (oapi-codegen's `additionalProperties: false` is not declared here; spec defaults to permissive).
|
||||
|
||||
### Backward compatibility
|
||||
- The proto change is field-number additive: every previously numbered field keeps the same name + type, and the change is insertions only (no deletions in `proxy_service.proto`), so this holds at the source-text level.
|
||||
- `proxy_service_grpc.pb.go` adds two RPC handlers and registers them in `ProxyService_ServiceDesc.Methods` (lines 543-552). The existing entries are unchanged and order-preserving — gRPC method dispatch is name-keyed, so order doesn't matter, but reviewing the diff (no method renamed/dropped) is still worth a glance.
|
||||
- OpenAPI 3.0 doesn't have a built-in deprecation flow for paths; if any client tooling iterates `paths.*`, the additive routes shouldn't break it, but generated SDKs (especially the dashboard's) need a regen to gain access to `AgentNetwork*`.
|
||||
|
||||
### Codegen pinning
|
||||
- `generate.sh` (`shared/management/http/api/generate.sh:14`) installs `oapi-codegen@latest` rather than a pinned version. **This is a reproducibility gap** — re-running the script later may produce a different `types.gen.go`. Either pin the version in `generate.sh` (e.g. `@v2.7.0`) or document the pin in a `tools.go`.
|
||||
- proto codegen has the protoc / protoc-gen-go version stamped in the generated file header (`proxy_service.pb.go:3-4`).
|
||||
- Regenerate locally and confirm zero diff against the committed `types.gen.go` / `proxy_service.pb.go`.
|
||||
|
||||
## Test coverage
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| None in this scope | The proto and OpenAPI sources are tested transitively by the handler tests (`shared/management/http/handlers/agentnetwork/...`) and by the synthesizer/manager tests (`management/server/agentnetwork/...`). No round-trip serialisation test exists in the `proto/` or `api/` packages themselves. |
|
||||
| `shared/management/proto/*_test.go` | (absent) |
|
||||
| `shared/management/http/api/*_test.go` | (absent) |
|
||||
|
||||
Acceptable for codegen artefacts, but a single golden-file test that re-runs `oapi-codegen` and `protoc` in CI and diffs against the checked-in files would close the reproducibility gap noted above.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
- **No deprecation surface.** Old fields/RPCs are kept silently; there is no `[deprecated = true]` annotation on anything. Acceptable here because nothing is being removed.
|
||||
- **No proto-side validation.** Numeric ranges (e.g. `window_seconds >= 60`, `cost_usd >= 0`, capture-byte clamps) are enforced in the OpenAPI schema via `minimum:` and inside Go code by the proxy/management, but `proto3` itself can't express them; downstream is expected to validate every message.
|
||||
- **`MiddlewareConfig.config_json` is `bytes`** (`proxy_service.proto:163`) — opaque to the proto layer. Schema validity is the middleware factory's problem. This is a deliberate tradeoff (per the comment at 161-162) but worth flagging: a corrupted/malicious config_json can only fail at proxy apply time, not at the wire-decode step.
|
||||
- **No catalog endpoint schema for the catalog itself** — the catalog data ships as a `GET /api/agent-network/catalog/providers` returning `[AgentNetworkCatalogProvider]` (`openapi.yml:13024`), but the catalog source-of-truth lives in `management/server/agentnetwork/catalog`, not here.
|
||||
- The reaper / GC design was cut from scope; no reaper-related types appear here.
|
||||
|
||||
## Cross-references
|
||||
- Downstream: [management/store](20-management-store.md), [management/agentnetwork](21-management-agentnetwork.md), [management/handlers + wiring](22-management-handlers-wiring.md), [proxy/runtime](33-proxy-runtime.md)
|
||||
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level: [../00-overview.md](../00-overview.md)
|
||||
112
docs/agent-networks/modules/20-management-store.md
Normal file
112
docs/agent-networks/modules/20-management-store.md
Normal file
@@ -0,0 +1,112 @@
|
||||
# management/store — persistence for agent-network entities
|
||||
|
||||
> **Risk level:** Medium — six brand-new tables behind AutoMigrate, one upsert-counter table that runs on the request hot path, and one column carrying an encrypted secret.
|
||||
> **Backward-compat impact:** Additive (six new tables created by AutoMigrate; the `Store` interface gains 23 methods, but no existing column/index is touched).
|
||||
|
||||
## Module boundary
|
||||
|
||||
This module is the persistence layer for the Agent Network feature. Everything the management server stores about LLM proxying — providers, policies, guardrails, the per-account settings row, a usage-counter table written on every proxied LLM request, and the account-budget rules — flows through the methods added to `store.Store`. The module owns six tables, six entity types from `management/server/agentnetwork/types`, and a single hot-path upsert (`IncrementAgentNetworkConsumption`) consumed by the proxy fleet.
|
||||
|
||||
Out of scope here: the catalog of provider definitions (compiled-in, no DB), the synthesizer/manager built on top of these CRUDs (covered in [21-management-agentnetwork.md](21-management-agentnetwork.md)), and the HTTP handlers that translate API requests into Save/Delete calls.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `management/server/store/sql_store_agentnetwork.go` | gorm implementations of all 23 store methods |
|
||||
| `management/server/store/sql_store_agentnetwork_budgetrule_test.go` | round-trip + account-scoping coverage against a real sqlite store |
|
||||
| `management/server/store/sql_store.go` | one import, six entities appended to the `AutoMigrate` slice (sql_store.go:40, sql_store.go:141-142) |
|
||||
| `management/server/store/store.go` | 23 methods added to the `Store` interface (store.go:328-354) |
|
||||
| `management/server/store/store_mock_agentnetwork.go` | mockgen output for the new interface surface |
|
||||
|
||||
## Tables added / migrations
|
||||
|
||||
All six tables are created by `db.AutoMigrate` invoked from `NewSqlStore` at sql_store.go:133-143. There is no hand-rolled SQL migration script — the schema is whatever GORM derives from the struct tags.
|
||||
|
||||
- `agent_network_providers` — `Provider.TableName()` at provider.go:76. PK `id`, index on `account_id`, named index `idx_agent_network_provider` on `provider_id`. Carries an at-rest-encrypted `api_key` and ed25519 `session_private_key` (provider.go:35,56). `extra_values` and `models` are JSON blobs (`serializer:json`).
|
||||
- `agent_network_policies` — `Policy.TableName()` at policy.go:70. PK `id`, index on `account_id`. JSON columns: `source_groups`, `destination_provider_ids`, `guardrail_ids`, `limits`.
|
||||
- `agent_network_guardrails` — `Guardrail.TableName()` at guardrail.go:41. PK `id`, index on `account_id`. JSON `checks`.
|
||||
- `agent_network_settings` — `Settings.TableName()` at settings.go:33. PK `account_id` (one row per account), named index `idx_agent_network_settings_cluster_subdomain` on `subdomain` only — the index name implies a composite, but only one column is tagged.
|
||||
- `agent_network_consumption` — `Consumption.TableName()` at consumption.go:46. Composite PK across `(account_id, dim_kind, dim_id, window_seconds, window_start_utc)` — the same tuple the upsert keys on.
|
||||
- `agent_network_budget_rules` — `AccountBudgetRule.TableName()` at budgetrule.go:35. PK `id`, index on `account_id`. JSON `target_groups`, `target_users`, `limits`.
|
||||
|
||||
## CRUD surface added
|
||||
|
||||
Provider, Policy, Guardrail, BudgetRule follow the same pattern: `Get<Kind>ByID`, `GetAccount<Kind>` (list), `Save<Kind>` (upsert), `Delete<Kind>`, with account-scoping enforced by the existing `accountAndIDQueryCondition` / `accountIDCondition` constants (sql_store.go:59-62). Provider additionally exposes `GetAllAgentNetworkProviders` (cross-account, used by the synthesizer). Settings exposes `Get`/`GetByCluster`/`Save` (no delete — one row per account, created on first save). Consumption exposes the upsert `Increment`, a point `Get`, and a cross-window `List`.
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
handlers["HTTP handlers<br/>(management/server/agentnetwork)"] -->|Save/Delete| iface["Store interface<br/>store.go:328-354"]
|
||||
manager["agentnetwork.Manager"] -->|Get*| iface
|
||||
synth["synthesizer<br/>(global)"] -->|GetAllAgentNetworkProviders| iface
|
||||
proxy["proxy fleet<br/>(hot path)"] -->|IncrementAgentNetworkConsumption| iface
|
||||
iface --> sql["SqlStore methods<br/>sql_store_agentnetwork.go"]
|
||||
iface -.gomock.-> mock["MockStore<br/>store_mock_agentnetwork.go"]
|
||||
sql --> gorm["gorm.DB"]
|
||||
gorm --> tables[("6 tables<br/>agent_network_*")]
|
||||
sql --> enc["crypt.FieldEncrypt<br/>(provider only)"]
|
||||
```
|
||||
|
||||
Reads decrypt provider secrets in-place; writes do `provider.Copy().EncryptSensitiveData(...)` before `db.Save` so the caller's in-memory object keeps the plaintext `api_key` (sql_store_agentnetwork.go:88-102). Every list/get takes a `LockingStrength` and applies `clause.Locking{Strength: ...}` when non-`None` — matching the rest of the store. The upsert path uses `clause.OnConflict` with `gorm.Expr` server-side increments so concurrent proxy nodes converge without read-modify-write races (sql_store_agentnetwork.go:321-335).
|
||||
|
||||
## Invariants enforced at the store layer
|
||||
|
||||
- **Account scoping.** Every entity-by-ID method keys on `account_id = ? and id = ?`; no cross-tenant leak path through the API is reachable as long as callers always pass the auth'd `accountID` (sql_store_agentnetwork.go:70,141,201,429).
|
||||
- **NotFound mapping.** `gorm.ErrRecordNotFound` is translated to typed `status.NewAgentNetwork*NotFoundError`; `Delete*` returns NotFound when `RowsAffected == 0` (sql_store_agentnetwork.go:111-113,171-173,231-233,461-463).
|
||||
- **Provider secret encryption at rest.** `SaveAgentNetworkProvider` always encrypts before persist; `Get*` always decrypts after read. The plaintext `api_key` never reaches the DB through this layer (sql_store_agentnetwork.go:31,54,80,90).
|
||||
- **Consumption monotonicity.** The upsert only ever issues `col = col + ?` for the three counter columns — no decrement path exists (sql_store_agentnetwork.go:330-332).
|
||||
- **Window alignment is the caller's responsibility.** The store stamps `WindowStartUTC` as-passed; alignment to epoch happens in `types.WindowStart` at consumption.go:51-58.
|
||||
- **Settings has no Delete.** Intentional — one row per account, created on first save; the row sticks around for the account lifetime.
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
- `SaveAgentNetworkProvider` saves the copy (sql_store_agentnetwork.go:95). The caller's in-memory pointer therefore keeps plaintext `api_key` and any `CreatedAt`/`UpdatedAt` gorm autofills land on the copy, not the original. Callers that need synced timestamps must re-fetch.
|
||||
- `IncrementAgentNetworkConsumption`'s `Create` provides initial counter values (`TokensInput: tokensIn`, etc.) in the row, and on conflict the assignments add the same deltas to the existing values. The insert-vs-update arithmetic is consistent. Cross-check that no engine in use (sqlite, postgres, mysql) silently rejects the `OnConflict` clause — GORM emits engine-specific SQL but `ON DUPLICATE KEY UPDATE` (mysql) vs `ON CONFLICT (...)` (sqlite/postgres) need their unique constraint to match the composite PK on `agent_network_consumption`; it does, by construction.
|
||||
- `IncrementAgentNetworkConsumption` writes `updated_at: time.Now().UTC()` literally inside the assignments map (sql_store_agentnetwork.go:333) — fine, but it's a Go-side timestamp captured at call time, not a DB-side `now()`. Acceptable for an audit field.
|
||||
- `GetAgentNetworkConsumption` returns a zero-valued non-nil row on `ErrRecordNotFound` (sql_store_agentnetwork.go:364-371). Document or rename — a typed sentinel error would be more orthodox; callers must know not to error-check.
|
||||
|
||||
### Concurrency / transactions
|
||||
- Hot-path `IncrementAgentNetworkConsumption` runs outside any explicit transaction; concurrency safety relies entirely on the DB serialising the `ON CONFLICT` upsert against the composite PK. This is correct for postgres and mysql; for sqlite it serialises behind the single writer.
|
||||
- `SaveAgentNetworkSettings` is a blind upsert with no version/etag — concurrent writes from two operators last-write-wins on the collection-toggle flags (settings.go:23-25). Acceptable for admin-curated state but worth flagging.
|
||||
- `Save*Provider` uses `db.Save` on a struct with a PK already set — GORM emits UPDATE or INSERT based on row existence. No upsert clause is attached, so a race between two creates with the same generated `xid` (vanishingly unlikely) would surface as a PK violation.
|
||||
|
||||
### Migration safety
|
||||
- All six tables ride `AutoMigrate` (sql_store.go:141-142). AutoMigrate is additive: new columns get added, but it never drops columns nor narrows types. Three `bool` columns on `agent_network_settings` (`EnableLogCollection`, `EnablePromptCollection`, `RedactPii`) default to false at the GORM/DDL layer for existing rows; the test at sql_store_agentnetwork_budgetrule_test.go:83-112 locks that down on a fresh sqlite. Verify postgres/mysql produce the same default.
|
||||
- The named index `idx_agent_network_settings_cluster_subdomain` on settings.go:15 is declared on only `subdomain`. Either the cluster column also needs `gorm:"index:idx_agent_network_settings_cluster_subdomain"` to make it composite, or the name is misleading.
|
||||
- The named index `idx_agent_network_provider` on `Provider.ProviderID` (provider.go:30) is *not* unique and not scoped to account — two providers in the same account with the same `provider_id` are permitted at the DB layer; uniqueness, if any, must live above the store.
|
||||
|
||||
### Backward compatibility
|
||||
- Net additive. No removed methods, no renamed columns, no schema change to existing tables. Existing deployments running a prior binary continue to work; the first boot of the new binary creates the six tables.
|
||||
- The `Store` interface grows by 23 methods (store.go:330-354); any non-mock external implementer of `store.Store` will fail to compile. The repo only has `SqlStore` + `MockStore`, both updated.
|
||||
|
||||
### Performance (indexes, N+1)
|
||||
- All by-account list queries hit the `idx_account_id` per-table index. No N+1: list methods return the full slice in one query.
|
||||
- `GetAgentNetworkSettingsByCluster` (sql_store_agentnetwork.go:263-277) does a tablescan on `cluster` — no index. Tolerable for the bootstrap label generator (one-shot at provisioning) but worth noting if the call moves onto a hot path.
|
||||
- `ListAgentNetworkConsumption` returns every row ever recorded for the account (sql_store_agentnetwork.go:382-400) — unbounded growth, no `LIMIT`, no time filter. With one row per (dim, window) per request burst, this table grows fastest of the six; a retention job + a paginated list method are obvious follow-ups.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkBudgetRule_RealStore_RoundTrip` | full save → reload of `AccountBudgetRule` including the JSON-serialised `PolicyLimits`, target slices, double-delete returns NotFound (lines 18-59) |
|
||||
| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkBudgetRule_RealStore_ScopedByAccount` | cross-account isolation for budget rules (lines 63-78) |
|
||||
| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkSettings_RealStore_CollectionTogglesRoundTrip` | collection toggles default off, survive save/reload at the set values (lines 83-112) |
|
||||
|
||||
Gap: there is no store-level test for providers (encryption round-trip), policies, guardrails, or `IncrementAgentNetworkConsumption` (concurrent upsert, window-key uniqueness). The consumption upsert is the most performance-sensitive method in this module and the only one without a real-sqlite test.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- No retention / GC for `agent_network_consumption`.
|
||||
- No `Delete` for `Settings` (one row per account, cleared with the account).
|
||||
- No DB-engine-specific tuning — the same struct tags drive sqlite, mysql, postgres.
|
||||
- Provider `extra_values` and `models` are JSON blobs; querying inside them is not supported by design.
|
||||
- `GetAgentNetworkConsumption` "not-found = zero row" contract is convenient but unconventional.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Upstream: [shared/api](10-shared-api.md), [management/agentnetwork](21-management-agentnetwork.md)
|
||||
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level: [../00-overview.md](../00-overview.md)
|
||||
225
docs/agent-networks/modules/21-management-agentnetwork.md
Normal file
225
docs/agent-networks/modules/21-management-agentnetwork.md
Normal file
@@ -0,0 +1,225 @@
|
||||
# management/agentnetwork — domain layer + synth pipeline
|
||||
|
||||
> **Risk level:** High — central business logic + budget enforcement + the source of every middleware-chain change the proxy executes.
|
||||
> **Backward-compat impact:** Additive within the agent-network surface; one **behavioural difference for opted-out accounts** in parser capture (the capture flag is stamped explicitly false instead of being absent — see capture-pointer semantics below). Non-agent-network proxy services are untouched (the synth chain only ships on `agent-net-svc-*` targets).
|
||||
|
||||
## Module boundary
|
||||
|
||||
`management/server/agentnetwork` owns every agent-network entity (providers, policies, guardrails, account budget rules, per-account settings, consumption rows) and **translates them into the in-memory `*rpservice.Service` that the reverse-proxy controller turns into `proto.ProxyMapping`s and pushes to clusters**. It is the *only* writer of the agent-network middleware chain.
|
||||
|
||||
Inside the package: `manager.go` is the CRUD + permissions-gated facade; `synthesizer.go` walks settings + providers + policies + guardrails and emits the per-account service plus every middleware's JSON config; `policyselect.go` runs per-request attribution (min-wins account ceiling, then "drain bigger pool first"); `reconcile.go` diffs successive synth outputs and emits precise Create/Update/Delete proxy-mapping updates plus a peer-map refresh. `labelgen/` mints DNS-safe subdomain labels; `catalog/` is the static provider catalogue; `types/` carries gorm entity structs. The `_realstack_test.go` files in the parent `management/server/` directory exercise the manager + network-map controller end-to-end with no mocks.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `agentnetwork/manager.go` | Manager interface + CRUD + permission gates + bootstrap-settings + reconcile trigger |
|
||||
| `agentnetwork/synthesizer.go` | Settings/policy → wire-format synthesis; sole writer of the proxy middleware chain |
|
||||
| `agentnetwork/policyselect.go` | Per-request policy attribution + account-budget ceiling (min-wins) |
|
||||
| `agentnetwork/reconcile.go` | Per-account synth diff vs in-memory cache → Create/Update/Delete |
|
||||
| `agentnetwork/catalog/catalog.go` | Static provider catalogue (auth headers, identity-injection shapes) |
|
||||
| `agentnetwork/labelgen/{labelgen,words}.go` | DNS-safe subdomain picker + curated wordlist |
|
||||
| `agentnetwork/types/provider.go` | Provider entity + APIKey + Models + ExtraValues + SessionKeys |
|
||||
| `agentnetwork/types/policy.go` | Policy entity + `PolicyLimits` (token + budget) |
|
||||
| `agentnetwork/types/guardrail.go` | Guardrail entity (`ModelAllowlist`, `PromptCapture`) |
|
||||
| `agentnetwork/types/budgetrule.go` | `AccountBudgetRule` (reuses `PolicyLimits`) |
|
||||
| `agentnetwork/types/settings.go` | Per-account `Settings` (Cluster, Subdomain, 3 toggles) |
|
||||
| `agentnetwork/types/consumption.go` | `Consumption` row + `WindowStart` aligner |
|
||||
| `agentnetwork/{synthesizer,policyselect,reconcile,wire_shape}_*test.go` | See test coverage table |
|
||||
| `agentnetwork/types/consumption_test.go` | `WindowStart` alignment proofs |
|
||||
| `agentnetwork/labelgen/labelgen_test.go` | Deterministic picks + exhaustion + fallback |
|
||||
| `management/server/agentnetwork_realstack_test.go` | No-mock provider CRUD → network-map fan-out |
|
||||
| `management/server/agentnetwork_budgetrule_realstack_test.go` | No-mock budget-rule CRUD + settings preserve-immutable |
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### Synthesis (settings/policy → wire format)
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Mutation: provider/policy/guardrail/settings] --> B[managerImpl.reconcile accountID]
|
||||
B --> C{proxyController nil?}
|
||||
C -- yes --> D[accountManager.UpdateAccountPeers only]
|
||||
C -- no --> E[SynthesizeServices]
|
||||
E --> F[loadSettings — NotFound returns ok=false, no synth]
|
||||
F --> G[filterEnabledProviders sorted by CreatedAt]
|
||||
G --> H[filterEnabledPolicies]
|
||||
H --> I[backfillProviderSessionKeys if missing]
|
||||
I --> J[indexProviderGroups: providerID -> sorted source groups]
|
||||
J --> K[buildRouterConfigJSON drops orphan providers]
|
||||
J --> L[buildIdentityInjectConfigJSON per catalog entry]
|
||||
H --> M[mergeGuardrails: union allowlist, OR redact]
|
||||
M --> N[applyAccountCollectionControls account toggle = SOLE capture control]
|
||||
N --> O[marshalGuardrailConfig]
|
||||
K --> P[buildMiddlewareChain 8 middleware entries]
|
||||
L --> P
|
||||
O --> P
|
||||
P --> Q[buildAccountService: AccessGroups=union source groups, noop.invalid target]
|
||||
Q --> R[reconcile.diffMappings vs cache]
|
||||
R --> S[SendServiceUpdateToCluster CREATE/MODIFY/REMOVE]
|
||||
R --> T[accountManager.UpdateAccountPeers — fans synth ACLs into network map]
|
||||
```
|
||||
|
||||
### Budget rule resolution (min-wins, group+user bound)
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[SelectPolicyForRequest in] --> B[checkAccountBudget — runs FIRST, independent of policies]
|
||||
B --> C[GetAccountAgentNetworkBudgetRules]
|
||||
C --> D{for each enabled rule}
|
||||
D --> E{budgetRuleApplies?}
|
||||
E -- no --> D
|
||||
E -- yes --> F[attrGroup = lowestIntersect TargetGroups, in.GroupIDs]
|
||||
F --> G{Token cap enabled?}
|
||||
G -- yes --> H[evalTokenCap user dim + group dim]
|
||||
H --> I{exhausted?}
|
||||
I -- yes --> J[DENY: llm_account.token_cap_exceeded - STOP]
|
||||
I -- no --> K{Budget cap enabled?}
|
||||
G -- no --> K
|
||||
K -- yes --> L[evalBudgetCap user dim + group dim]
|
||||
L --> M{exhausted?}
|
||||
M -- yes --> N[DENY: llm_account.budget_cap_exceeded - STOP]
|
||||
M -- no --> D
|
||||
K -- no --> D
|
||||
D --> O[All rules passed -> fall through to per-policy selection]
|
||||
```
|
||||
|
||||
Key invariant: **rules are checked sequentially and ANY exhausted rule denies (all-must-pass / min-wins).** Untargeted rules (`len(TargetGroups)==0 && len(TargetUsers)==0`) apply to every caller (`policyselect.go:393`).
|
||||
|
||||
### Policy selection (per-peer, per-request)
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Account-budget gate passed] --> B[GetAccountAgentNetworkPolicies]
|
||||
B --> C[filterApplicablePolicies enabled + provider match + group intersect]
|
||||
C --> D{candidates empty?}
|
||||
D -- yes --> E[Allow, empty SelectedPolicyID]
|
||||
D -- no --> F[scoreCandidates -> scoreOne per policy]
|
||||
F --> G[scoreOne: attrGroup + window]
|
||||
G --> H{any cap exhausted?}
|
||||
H -- yes --> I[Drop policy; record last deny code]
|
||||
H -- no --> K[Keep as live candidate]
|
||||
F --> L{live candidates exist?}
|
||||
L -- no --> M[Deny with last exhaustion code]
|
||||
L -- yes --> N[Sort: uncapped wins -> larger group token -> group budget -> user token -> user budget -> oldest CreatedAt]
|
||||
N --> O[winner = scored 0]
|
||||
O --> P[Allow + SelectedPolicyID + AttributionGroupID + WindowSeconds]
|
||||
```
|
||||
|
||||
End-to-end: a mutation calls `managerImpl.reconcile(ctx, accountID)` (`manager.go:205,239,...`). Reconcile defers an `accountManager.UpdateAccountPeers` so the network-map controller re-runs and `injectAllProxyPolicies` picks up the new access groups; with a `proxyController` wired, it re-synthesizes the service, diffs against `reconcileCache[accountID]` (guarded by `reconcileMu`), and emits proto mappings to the cluster derived from the mapping's domain (`reconcile.go:120`). Synthesis is stateless and idempotent. Sole persistent side effect: `backfillProviderSessionKeys` (`synthesizer.go:249`) mints ed25519 keys on legacy provider rows and writes them back.
|
||||
|
||||
At request time the path is independent: the proxy calls `SelectPolicyForRequest` (`policyselect.go:56`); account-budget ceiling first, then per-policy scoring. Token + budget caps share `evalTokenCap` / `evalBudgetCap` — same primitive for account rules and policy limits, `label` differentiates the deny reason. After a served request, `RecordAccountBudgetUsage` (`policyselect.go:415`) fans deltas to every applicable rule's distinct `(dim_kind, dim_id, window)` tuple, deduplicating to prevent double-count when two rules share target+window.
|
||||
|
||||
## Public contracts
|
||||
|
||||
- **Manager interface** (`manager.go:48-80`): CRUD for `Providers/Policies/Guardrails/BudgetRules`; `GetSettings/UpdateSettings` (cluster + subdomain immutable, only the three toggles mutate); `ListConsumption/RecordConsumption(account, kind, dimID, windowSec, in, out, USD)`; `RecordAccountBudgetUsage(account, user, groups, in, out, USD)`; `SelectPolicyForRequest(ctx, PolicySelectionInput) → *PolicySelectionResult{Allow, SelectedPolicyID, AttributionGroupID, WindowSeconds, DenyCode, DenyReason}`.
|
||||
- **`PolicySelectionInput`** (`manager.go:85-90`): `{AccountID, UserID, GroupIDs, ProviderID}` — populated by the proxy from CapturedData + `llm_router` resolution.
|
||||
- **Synthesized middleware chain** (`synthesizer.go:576-657`), order load-bearing — response slot runs reverse-of-slice:
|
||||
|
||||
| Slot | Idx | ID | ConfigJSON shape | CanMutate |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| on_request | 0 | `llm_request_parser` | `{"capture_prompt": <bool>, "redact_pii"?: true}` | – |
|
||||
| on_request | 1 | `llm_router` | `{"providers":[{id, models[], upstream_*, auth_header_*, allowed_group_ids[]}]}` | **true** |
|
||||
| on_request | 2 | `llm_limit_check` | `{}` | – |
|
||||
| on_request | 3 | `llm_identity_inject` | `{"providers":[{provider_id, header_pair?, json_metadata?, extra_headers?}]}` | **true** |
|
||||
| on_request | 4 | `llm_guardrail` | `{"model_allowlist"?, "prompt_capture":{enabled,redact_pii}}` | – |
|
||||
| on_response | 5 | `llm_limit_record` | `{}` (runs LAST at runtime) | – |
|
||||
| on_response | 6 | `cost_meter` | `{}` | – |
|
||||
| on_response | 7 | `llm_response_parser` | `{"capture_completion": <bool>, "redact_pii"?: true}` | – |
|
||||
- **Synthesized service shape** (`synthesizer.go:739`): `Mode=HTTP`, `Private=true`, `Domain=<subdomain>.<cluster>`, `AccessGroups=unionSourceGroups(enabledPolicies)`, one `TargetTypeCluster` target with `Host=noop.invalid:443` (router rewrites per request), `Options.{DirectUpstream,AgentNetwork}=true`, `DisableAccessLog=!settings.EnableLogCollection`, `CaptureMax{Req,Resp}Bytes=1<<20`, `CaptureContentTypes=["application/json","text/event-stream"]`.
|
||||
|
||||
## Invariants
|
||||
|
||||
- **Min-wins / all-must-pass for account budget rules** (`checkAccountBudget`, `policyselect.go:353`): every applicable enabled rule is checked; first exhausted cap denies. Untargeted rules bind every caller.
|
||||
- **Account toggle is the SOLE control for capture enablement.** `applyAccountCollectionControls` (`synthesizer.go:701`) sets `merged.PromptCapture.Enabled = settings.EnablePromptCollection` *unconditionally*.
|
||||
- **Capture-pointer semantics on parser configs** — see "Things to scrutinize" below.
|
||||
- **`EnableLogCollection` ↔ `DisableAccessLog` is the only access-log toggle** (`synthesizer.go:770`). Default off ⇒ access log suppressed.
|
||||
- **`RedactPii` flows verbatim to BOTH parsers** (`synthesizer.go:584-585`) and is OR'd into the merged guardrail (`synthesizer.go:706`).
|
||||
- **Cluster and Subdomain are immutable on Settings.** `UpdateSettings` reloads existing row and overlays only the three toggles (`manager.go:558-561`).
|
||||
- **Orphan providers (no enabled policy authorises them) NEVER reach the router** (`synthesizer.go:351-357`); skipped from `identity_inject` for symmetry.
|
||||
- **Provider creation refuses empty `api_key`** (`manager.go:175`); **deletion refuses while any policy still references it** (`manager.go:265-273`).
|
||||
- **Session keypair stability across provider edits** (`manager.go:226-228`) — server-managed, copied through every `UpdateProvider`, never API-surfaced.
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
|
||||
- **Capture-pointer semantics — `*bool` vs `bool`.** Three states, owned by separate sides:
|
||||
- **Wire JSON this module emits:** `buildParserConfigJSON` (`synthesizer.go:678-693`) *always* stamps the capture field. Agent-network targets ship `"capture_prompt": false` or `"capture_prompt": true` — never absent. Same for `"capture_completion"`. The happy-path test pins `{"capture_prompt":false}` (`synthesizer_test.go:174`).
|
||||
- **Proxy-side parser config (consumer):** parsers decode into `*bool`. Matrix:
|
||||
- `nil` (field absent) → **legacy default = emit**. Preserved for non-agent-network callers and pre-existing tests (the backward-compat hook).
|
||||
- `false` (field present, value false) → **suppress emission entirely**. The behaviour for opted-out agent-network accounts. Without this, `enable_log_collection=true` + `enable_prompt_collection=false` would leak raw user input AND raw model output to the access log.
|
||||
- `true` → emit normally.
|
||||
- **Why the synth always stamps a value:** an agent-network mapping omitting the field would hit legacy "always emit" and re-introduce the leak. The `json.Marshal` error fallback at `synthesizer.go:687` degrades to `{}` — comment-claimed unreachable, but if ever fired re-introduces the leak. Consider fail-closed (return literal `{"capture_prompt":false}`) instead.
|
||||
- **`scoreCandidates` non-cumulative deny code.** Only the *last* exhausted policy's deny code survives (`policyselect.go:188-190`). Iteration order is store's natural order. Auth signal is `len(scored)==0`, so this is informational only — verify no UI depends on "first exhausted policy" semantics.
|
||||
- **`effectiveWindowSeconds` token-wins tiebreak.** When both halves are enabled with different windows, token's window wins (`policyselect.go:482`). Verify `RecordLLMUsage` increments against the winning window only.
|
||||
- **`RecordAccountBudgetUsage` dedup.** Two rules with the same `(kind, dim_id, window)` would double-count without the `tuples` map (`policyselect.go:434-449`). Key includes all three dimensions — correct.
|
||||
- **Fail-closed on bad provider:** unknown catalog id (`synthesizer.go:794-796`) or empty API key (`synthesizer.go:801-803`) drops the **entire** account's synth, not just the bad provider. Confirm matches operator UX.
|
||||
|
||||
### Security
|
||||
|
||||
- **Redact OR-merge:** merged `RedactPii` = account OR guardrail (`synthesizer.go:706`). **Parser-side flag is `settings.RedactPii` only, NOT the OR** — a guardrail-only opt-in does not propagate to parsers. Correct because the account toggle gates capture, but worth noting on the proxy side.
|
||||
- **Group resolution must not leak across accounts.** Every store call carries `accountID` (`policyselect.go:73, 286, 298, 322, 334, 354`); `lowestIntersect` uses caller's claimed groups only (`policyselect.go:494`). Risk surface is upstream (handler populates `in.GroupIDs`).
|
||||
- **`UpdateSettings` preserves immutable Cluster + Subdomain** (`manager.go:558`). A client can't rebind the cluster.
|
||||
- **Provider session keypair backfill writes through `SaveAgentNetworkProvider`** (`synthesizer.go:256`) from a read-shaped call. Idempotent → worst case is a wasted write under concurrent reconcile + snapshot.
|
||||
|
||||
### Concurrency
|
||||
|
||||
- **`reconcileMu`** guards `reconcileCache`. Lock window is narrow — compute diff inside, send outside (`reconcile.go:56-68`).
|
||||
- **`labelRngMu`** guards `labelRng` because `math/rand.Source` is unsafe for concurrent use (`manager.go:638-640`).
|
||||
- **Real-store tests** use `store.NewTestStoreFromSQL` with `t.TempDir()` per test — no shared state, no `t.Parallel()`.
|
||||
- **`RecordAccountBudgetUsage` dedup `tuples` map is per-call;** concurrent calls fan out fully — correct (each request's tokens book once per applicable rule).
|
||||
- **Deferred `UpdateAccountPeers` runs inline after the proxy push** (`reconcile.go:28-35`); a slow call stretches CRUD response time.
|
||||
|
||||
### Backward compatibility
|
||||
|
||||
- **Capture-pointer semantics (restated):** non-agent-network callers see no field → legacy nil-default emit, identical to pre-PR. Agent-network targets always carry an explicit `capture_*` value.
|
||||
- **`TestSynthesizeServices_HappyPath` was updated:** request-parser config moved from `{}` to `{"capture_prompt":false}` (`synthesizer_test.go:174`). External snapshot tests against synth output need updating.
|
||||
- **`MergedGuardrails` retains zeroed `TokenLimits`/`Budget`/`Retention`** even though `Policy.Limits` carries the real values now; `llm_limit_check` is the authoritative enforcement. Comment at `synthesizer.go:940-948` calls this out.
|
||||
|
||||
### Performance
|
||||
|
||||
- **`SynthesizeServices` runs on every controller tick / mutation reconcile.** Cost: 4 store reads + optional per-provider keypair backfill. Sort + index + merge are O(N log N) / O(P × G); dominant cost is JSON marshalling. No nested loops escape these dimensions.
|
||||
- **`reconcile.diffMappings` is O(N + M)** with N=M=1 per account today — effectively constant.
|
||||
- **`SynthesizeServicesForCluster`** (`synthesizer.go:71`) walks every account on a cluster; per-account failures are **swallowed** (`synthesizer.go:91-93`) so a single misconfigured account doesn't drop the cluster. Runs per proxy reconnect.
|
||||
|
||||
### Observability
|
||||
|
||||
- **Activity codes:** `AgentNetwork{Provider,Policy,Guardrail,BudgetRule}{Created,Updated,Deleted}`; `AgentNetworkSettingsUpdated` with `log_collection/prompt_collection/redact_pii` payload (`manager.go:567-571`). **No activity code for `SelectPolicyForRequest` denies** — surfaced via proxy access log only (likely intentional given volume).
|
||||
- **Deny codes** namespaced: `llm_policy.{token,budget}_cap_exceeded`, `llm_account.{token,budget}_cap_exceeded` (`policyselect.go:18-26`).
|
||||
- **Reconcile failures are logged at warn and swallowed** (`reconcile.go:42-44`). Persistent synth failures (e.g. unknown catalog id) silently keep the proxy out of sync — consider a manager-level synth-health surface if this becomes a support burden.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| `synthesizer_test.go` | Mock-store: `HappyPath` (8-mw chain ordering, `{"capture_prompt":false}` baseline); `No{Settings,Providers}`; `Disabled{Provider,Policy}_NoService`; `RouterConfigOrdering`; `PolicyCheckConfig_UnionsSourceGroups`; `OrphanProvider_HasEmptyAllowedGroups`; identity-inject for LiteLLM / Bifrost (overrides + partial disable) / Cloudflare / Portkey / Vercel / OpenRouter / generic non-customizable; `GuardrailMerge_AllowlistUnion_LimitsRestrictive`; `BackfillsMissingSessionKeys`; `HTTPUpstream_KeepsExplicitPort`; `UpstreamURLPath_FlowsToRouter`; `UnknownProviderID_FailsClosed`; `EmptyAPIKey_FailsClosed`. |
|
||||
| `synthesizer_realstore_test.go` | Real-sqlite: `SurvivesStatusToggle` reproduces the disable/re-enable 403 regression; `Reconcile_RealStore_PushesPrivateAfterStatusToggle` extends through reconcile push. |
|
||||
| `synthesizer_guardrail_realstore_test.go` | `PromptCaptureAccountIsSoleControl`; `PromptCaptureFlowsWhenAccountOptsIn`; `AccountRedactWithoutGuardrailRedact`; `NoGuardrail_CaptureOff`. |
|
||||
| `synthesizer_log_collection_realstore_test.go` | `LogCollection{Off_SuppressesAccessLog,On_PermitsAccessLog}` — verifies `DisableAccessLog` propagation through `ToProtoMapping`. |
|
||||
| `synthesizer_parser_redact_realstore_test.go` | **Capture-pointer regression suite:** `ParserConfigsCarryRedactPii`; `ParserConfigsSuppressCaptureWhenLogCollectionOnly` (log=on/prompt=off ⇒ both capture flags false); `ParserConfigsOmitRedactPiiWhenOff`. |
|
||||
| `policyselect_test.go` | Mock-store: `NoApplicablePolicies`; `AllowWithLowestGroupAttribution`; `LargerPoolWinsAcrossUsageLevels`; `StaysOnLargerPoolAfterPartialDrain`; `FallsThroughToSmallerPoolWhenLargerExhausted`; `TiebreakBy{LargerGroupPool,CreatedAt}`; `DeniesWhenAllExhausted`; `UncappedPolicyAlwaysWinsAgainstCapped`; `DisabledPolicyIgnored`; `StoreErrorPropagates`; `RejectsEmptyAccount`; `SharesGroupCounterAcrossPolicies`; `AntiFallThroughOnLowestGroup`; `BudgetOnlyExhaustionDenies`; `BudgetTighterThanTokenWins`. |
|
||||
| `policyselect_realstore_test.go` | Real-sqlite regression guard: `NoApplicablePolicies`; `AllowAndLowestGroupAttribution`; `LargerPoolWins_FallsThroughWhenExhausted`; `BudgetCapDenies`; `GroupCounterSharedAcrossPolicies`; `DisabledPolicyIgnored`. |
|
||||
| `policyselect_account_realstore_test.go` | Account budget rules: `AccountCeilingBindsEvenWithUncappedPolicy` (min-wins); `AccountGroupCeiling`; `AccountTargetUsersBindsOnlyThatUser`; `AccountRuleRecordsToOwnWindow`. |
|
||||
| `reconcile_test.go` | `FirstSynth_EmitsCreate`; `NoChange_EmitsNothingExtra` (re-push as Modified — verify desired); `PolicyRemoved_EmitsDelete`; `NilProxyController_NoOp`; `EmptyAccountID_NoOp`; `ClusterFromMapping`. |
|
||||
| `wire_shape_test.go` | `TestSynthesizedService_WireShape` — proto-shape lockdown via `ToProtoMapping`. Catches "service not matching" (mapping reaches proxy but no SNI/HTTP route). Asserts ID, Domain, Mode, AuthToken, `Private`, `Auth.Oidc=false`, one path `/` + `https://noop.invalid/`, 8 middlewares with correct slot enums, router config `auth_header_value="Bearer sk-test-key"`. |
|
||||
| `labelgen/labelgen_test.go` | `PickUnique_{DeterministicWithSeededRng,AvoidsTakenWordsWhenMostAreReserved,FallsBackWhenAllReserved}`; `UniqueWords_DropsDuplicates`. |
|
||||
| `types/consumption_test.go` | `WindowStart_{AlignedToUnixEpoch,WithinWindowConverges,AcrossWindowsDiverges,DifferentWindowsHaveDifferentBuckets,SubMinuteAndMinuteAlignment,ZeroWindowReturnsInputUTC}`. Bucket alignment so multi-node reads converge. |
|
||||
| `agentnetwork_realstack_test.go` | `ProviderCRUD_FansOutToProxyAndClientPeers` — no-mock end-to-end through real account manager + network-map + agentnetwork: provider create propagates the updated map to both proxy peer and client peer with the synth DNS surface. |
|
||||
| `agentnetwork_budgetrule_realstack_test.go` | `BudgetRuleCRUD_RealManager`; `UpdateSettings_PreservesImmutableAndTogglesCollection`. |
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- **`MergedGuardrails.TokenLimits/Budget/Retention` emit at zero** (`synthesizer.go:940-948`); real enforcement is `Policy.Limits` via `llm_limit_check`. Future cleanup implied.
|
||||
- **Session keys picked from first enabled provider by created_at** (`pickServiceSessionKeys`, `synthesizer.go:270`). Existing session cookies survive provider edits only while the first-by-CreatedAt provider stays in place. Document for operators.
|
||||
- **Reconcile failures silently swallowed** (`reconcile.go:42-44`). Persistent failures keep the proxy out of sync until the next reconcile.
|
||||
- **`scoreCandidates` exposes only the LAST exhaustion's deny code** when multiple policies are exhausted.
|
||||
- **`bootstrapSettingsIfNeeded` failure is non-fatal to provider create** (`manager.go:200`): provider lands, synth is no-op until the next provider create retries the bootstrap.
|
||||
- **Budget rules do not trigger a reconcile** (`manager.go:476-477`). Request-time evaluation only; new rules take effect on the next request without a proxy push.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- **Upstream:** [shared/api](10-shared-api.md), [management/store](20-management-store.md), reverseproxy `service`/`proxy`/`sessionkey` packages, `management/server/permissions` + `activity`.
|
||||
- **Downstream:** [management/handlers (HTTP wiring)](22-management-handlers-wiring.md), [proxy/middleware-builtin](31-proxy-middleware-builtin.md), network-map controller (`injectAllProxyPolicies` fan-out).
|
||||
- **End-to-end flow:** [../01-end-to-end-flows.md](../01-end-to-end-flows.md) — "Provider create → reconcile → proxy push → peer map refresh" and "request → policy select → record" diagrams.
|
||||
- **Top-level:** [../00-overview.md](../00-overview.md)
|
||||
203
docs/agent-networks/modules/22-management-handlers-wiring.md
Normal file
203
docs/agent-networks/modules/22-management-handlers-wiring.md
Normal file
@@ -0,0 +1,203 @@
|
||||
# management/handlers + wiring — HTTP API + gRPC delivery
|
||||
|
||||
> **Risk level:** Medium — the surface is mostly additive, but two changes are load-bearing: `injectAllProxyPolicies` runs on every per-peer compute, and `shallowCloneMapping` must round-trip `Private` (a missed field silently breaks every MODIFIED).
|
||||
> **Backward-compat impact:** Additive on the wire (new routes, new RPCs, new proto fields, new gorm column on `AccessLogEntry`). One management-internal break: `nbhttp.NewAPIHandler` gains a trailing `agentNetworkManager` parameter; `nil` is tolerated and silently skips route registration.
|
||||
|
||||
## Module boundary
|
||||
|
||||
This module is the seam between the public Agent Network HTTP API and the proxy fleet that serves agent traffic. North side: a `/api/agent-network/*` surface (providers, policies, guardrails, budget rules, settings, consumption) on the existing gorilla router, delegating to `agentnetwork.Manager`. Handlers are thin — they translate `api.*` ↔ `types.*`, validate shape, forward. RBAC and event emission stay inside the manager (`manager.go:680-682`).
|
||||
|
||||
South side: `ProxyServiceServer` (`proxy.go`) learns to (a) ship synth services to a proxy on initial snapshot, (b) resolve agent-network domains in `getServiceByDomain` for OIDC/session/tunnel-peer flows, (c) gate LLM requests via `CheckLLMPolicyLimits` + `RecordLLMUsage`, (d) preserve `Private` through `shallowCloneMapping` so per-proxy live updates don't silently flip services public. The network_map controller prepends synth services to `account.Services` on every per-peer compute; `accesslogentry.go` gains an indexed `AgentNetwork` column so the dashboard can filter cheaply.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `handlers/agentnetwork/providers_handler.go` | Catalog + provider CRUD + central `AddEndpoints` |
|
||||
| `handlers/agentnetwork/policies_handler.go` | Policy CRUD + shared `validatePolicy*` |
|
||||
| `handlers/agentnetwork/guardrails_handler.go` | Guardrail CRUD |
|
||||
| `handlers/agentnetwork/budget_handler.go` | Account-level budget rule CRUD |
|
||||
| `handlers/agentnetwork/settings_handler.go` | GET (200+`null` if unbootstrapped) + PUT toggles |
|
||||
| `handlers/agentnetwork/consumption_handler.go` | Read-only consumption rows |
|
||||
| `handlers/agentnetwork/handlers_test.go` | Real-store fixture; wire round-trip + validation |
|
||||
| `handlers/agentnetwork/budget_handler_test.go` | Budget-rule + settings toggles |
|
||||
| `server/http/handler.go` | New `agentNetworkManager` arg; conditional `AddEndpoints` |
|
||||
| `server/permissions/modules/module.go` | New `AgentNetwork` module key |
|
||||
| `internals/server/boot.go` | Wires synthesiser adapter + limits service into proxy server |
|
||||
| `internals/server/modules.go` | `AgentNetworkManager()` lazy-create node |
|
||||
| `internals/controllers/network_map/controller/controller.go` | `injectAllProxyPolicies` replaces 4 `InjectProxyPolicies` calls |
|
||||
| `internals/controllers/network_map/controller/repository.go` | `SynthesizeAgentNetworkServices` repo method |
|
||||
| `internals/modules/reverseproxy/service/service.go` | `MiddlewareConfig`, capture limits, `AgentNetwork`, `DisableAccessLog` + proto |
|
||||
| `internals/modules/reverseproxy/accesslogs/accesslogentry.go` | Indexed `AgentNetwork bool` from proto |
|
||||
| `internals/shared/grpc/proxy.go` | Synth wiring, 2 RPCs, domain fallback, `Private` in clone |
|
||||
| `internals/shared/grpc/proxy_clone_test.go` | Locks every `ProxyMapping` field minus `AuthToken` |
|
||||
| `server/activity/codes.go` | 13 new activity codes (125-137) |
|
||||
|
||||
## HTTP routes added
|
||||
|
||||
All routes inherit the platform's auth middleware. Perms enforced inside `agentnetwork.Manager.requirePermission` (`manager.go:680-682`) on `modules.AgentNetwork`. Permission column shows the `op` passed to `requirePermission` — read = `Read`, etc.
|
||||
|
||||
| Method | Path | Perm | Handler |
|
||||
| ------ | ---- | ---- | ------- |
|
||||
| GET | `/agent-network/catalog/providers` | authn only | `providers_handler.go:43` |
|
||||
| GET | `/agent-network/providers` | read | `providers_handler.go:57` |
|
||||
| POST | `/agent-network/providers` | create | `providers_handler.go:97` |
|
||||
| GET | `/agent-network/providers/{providerId}` | read | `providers_handler.go:77` |
|
||||
| PUT | `/agent-network/providers/{providerId}` | update | `providers_handler.go:132` |
|
||||
| DELETE | `/agent-network/providers/{providerId}` | delete | `providers_handler.go:172` |
|
||||
| GET | `/agent-network/policies` | read | `policies_handler.go:32` |
|
||||
| POST | `/agent-network/policies` | create | `policies_handler.go:72` |
|
||||
| GET | `/agent-network/policies/{policyId}` | read | `policies_handler.go:52` |
|
||||
| PUT | `/agent-network/policies/{policyId}` | update | `policies_handler.go:102` |
|
||||
| DELETE | `/agent-network/policies/{policyId}` | delete | `policies_handler.go:142` |
|
||||
| GET | `/agent-network/guardrails` | read | `guardrails_handler.go:25` |
|
||||
| POST | `/agent-network/guardrails` | create | `guardrails_handler.go:65` |
|
||||
| GET | `/agent-network/guardrails/{guardrailId}` | read | `guardrails_handler.go:45` |
|
||||
| PUT | `/agent-network/guardrails/{guardrailId}` | update | `guardrails_handler.go:95` |
|
||||
| DELETE | `/agent-network/guardrails/{guardrailId}` | delete | `guardrails_handler.go:135` |
|
||||
| GET | `/agent-network/budget-rules` | read | `budget_handler.go:24` |
|
||||
| POST | `/agent-network/budget-rules` | create | `budget_handler.go:64` |
|
||||
| GET | `/agent-network/budget-rules/{ruleId}` | read | `budget_handler.go:44` |
|
||||
| PUT | `/agent-network/budget-rules/{ruleId}` | update | `budget_handler.go:95` |
|
||||
| DELETE | `/agent-network/budget-rules/{ruleId}` | delete | `budget_handler.go:135` |
|
||||
| GET | `/agent-network/settings` | read | `settings_handler.go:53` (200+`null` if no row) |
|
||||
| PUT | `/agent-network/settings` | update | `settings_handler.go:27` |
|
||||
| GET | `/agent-network/consumption` | read | `consumption_handler.go:21` |
|
||||
|
||||
## gRPC RPCs added (or modified)
|
||||
|
||||
| RPC | Direction | Trigger |
|
||||
| --- | --------- | ------- |
|
||||
| `CheckLLMPolicyLimits` | proxy→mgmt unary | Pre-flight gate; returns allow/deny, selected policy, attribution group, window, deny code+reason (`proxy.go:259-301`). `Unimplemented` when limits service is nil. |
|
||||
| `RecordLLMUsage` | proxy→mgmt unary | Post-flight write of tokens+cost against policy-window dimensions + every applicable account budget rule (`proxy.go:303-349`). `window_seconds==0` ⇒ no policy cap, only account fan-out runs. |
|
||||
| `GetMappingUpdate`/`SendServiceUpdate` (stream) | mgmt→proxy | Snapshot (`proxy.go:752-780`) now appends `SynthesizeServicesForCluster`. Live updates use `SendServiceUpdateToCluster` + `shallowCloneMapping`. |
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### HTTP request lifecycle
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant DB as Dashboard
|
||||
participant R as gorilla.Router (/api)
|
||||
participant H as handler (agentnetwork)
|
||||
participant M as agentnetwork.Manager
|
||||
participant S as store.Store
|
||||
participant AM as accountManager (StoreEvent)
|
||||
|
||||
DB->>R: POST /api/agent-network/providers
|
||||
R->>H: createProvider (auth mw sets UserAuth)
|
||||
H->>H: GetUserAuthFromContext + validate(req)
|
||||
H->>M: CreateProvider(userID, provider, bootstrapCluster)
|
||||
M->>M: requirePermission(AgentNetwork, Create)
|
||||
M->>S: SaveAgentNetworkProvider
|
||||
M->>AM: StoreEvent(AgentNetworkProviderCreated)
|
||||
M-->>H: created provider
|
||||
H-->>DB: 200 + api.AgentNetworkProvider JSON
|
||||
```
|
||||
|
||||
### Synth-service delivery via gRPC
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant P as Proxy
|
||||
participant G as ProxyServiceServer
|
||||
participant SM as service.Manager (persisted)
|
||||
participant SA as synthesizerAdapter
|
||||
participant AN as SynthesizeServicesForCluster
|
||||
participant ST as store.Store
|
||||
|
||||
Note over P,G: Initial snapshot
|
||||
P->>G: GetMappingUpdate (stream open)
|
||||
G->>SM: GetServicesForCluster(conn.address)
|
||||
SM-->>G: persisted []*Service
|
||||
G->>SA: SynthesizeServicesForCluster(conn.address)
|
||||
SA->>AN: SynthesizeServicesForCluster(store, clusterAddr)
|
||||
AN->>ST: walk every account; read providers/policies/settings
|
||||
AN-->>SA: in-memory []*Service
|
||||
SA-->>G: []*Service
|
||||
G->>P: response (persisted + synth)
|
||||
|
||||
Note over G,P: Per-request live update
|
||||
G->>G: SendServiceUpdateToCluster(update, clusterAddr)
|
||||
G->>G: shallowCloneMapping(update) %% Private MUST survive
|
||||
G->>P: response with single mapping
|
||||
```
|
||||
|
||||
End-to-end: HTTP write persists rows and emits an activity event; the manager then triggers `proxyController.SendServiceUpdate` so proxies re-render. **The snapshot path is the only one that calls into the synthesiser** — on stream open it pulls persisted services then appends synth services for the cluster. Synth services are never persisted. For OIDC/session/tunnel-peer flows, `getServiceByDomain` falls back to `SynthesizeServicesForCluster(clusterFromDomain(domain))` when persisted lookup misses (`proxy.go:1763-1793`). The network_map contribution is orthogonal: per-peer compute prepends the same synth services to `account.Services` before `InjectProxyPolicies`.
|
||||
|
||||
## Permissions model added
|
||||
|
||||
- `permissions/modules/module.go:22` adds `AgentNetwork Module = "agent_network"`, registered in `All` (`module.go:42`). Standard `operations.{Read,Create,Update,Delete}` matrix.
|
||||
- Handlers don't call `permissionsManager` directly — they extract `UserAuth` and delegate to `agentnetwork.Manager`, which gates every mutation through `requirePermission` (`manager.go:168, 308, 549`, etc.). Confirm your role-set provider has `agent_network` rows for owner/admin/user/billing-admin before merging.
|
||||
- `getCatalogProviders` (`providers_handler.go:43`) intentionally skips RBAC — catalog is global static data.
|
||||
|
||||
## Activity codes added
|
||||
|
||||
`activity/codes.go:244-274` adds Activities 125-137 + string/code mappings (`codes.go:428-444`), following `<domain>.<resource>.<action>` (e.g., `agent_network.provider.create`). Audit-log exporters / SIEM forwarders need to know the new codes.
|
||||
|
||||
## Invariants
|
||||
|
||||
- **Synth services are never persisted.** Snapshot appends after `serviceManager.GetServicesForCluster` (`proxy.go:761-770`); network_map prepends before `InjectProxyPolicies` (`controller.go:117-126`).
|
||||
- **`shallowCloneMapping` must round-trip every `ProxyMapping` field except `AuthToken`** — `proxy_clone_test.go:50-58` enforces via `gproto.Equal`. The bug it guards: a missing `Private` made every MODIFIED arrive `private=false`, the proxy skipped `ValidateTunnelPeer`, `UserGroups` stayed empty, `llm_router` denied `no_authorised_provider`; a restart "fixed" it because the snapshot uses the original mapping.
|
||||
- **Limit-window floor is 60s** (`policies_handler.go:189-220`); enabled cap with both per-group and per-user at zero is rejected. Budget rules reuse the same validator (`budget_handler.go:170`).
|
||||
- **Manager is optional at boot.** `NewAPIHandler` registers routes only when non-nil (`handler.go:129`); `ProxyServiceServer` returns `Unimplemented` from both RPCs when limits service is unwired (`proxy.go:262-265, 306-309`).
|
||||
- **Settings GET on an unbootstrapped account returns 200 + `null`** (`settings_handler.go:65-72`) — not 404.
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
- **`injectAllProxyPolicies` runs on every per-peer compute**: `controller.go:163, 309, 415, 681`. `sendUpdateAccountPeers` is the target of the buffered fan-out — synth runs once per debounced account-update tick **and** once per direct `UpdateAccountPeer`. Cost is O(providers + policies × users-per-group) per account under `LockingStrengthNone`. No per-account synth cache — verify it fits the buffer interval for your largest tenant.
|
||||
- **`clusterFromDomain` strips at the first `.`** (`proxy.go:1784-1792`). A zero-dot domain returns `""` and the synth call walks every account. Confirm no path reaches this with a malformed/internal domain.
|
||||
- **Account-budget `RecordConsumption` fans out even when `window_seconds == 0`** (`proxy.go:341-348`) — intentional. Verify the proxy never sends `RecordLLMUsage` for a request that wasn't actually allowed.
|
||||
|
||||
### Security
|
||||
- Every handler extracts `UserAuth` via `nbcontext.GetUserAuthFromContext` before any work. Routes live behind the standard `/api` mux; bypass list is not extended.
|
||||
- `CheckLLMPolicyLimits` / `RecordLLMUsage` ride the existing **proxy → mgmt** gRPC connection auth. No additional token check inside the RPCs — they trust the connection. Confirm the proxy-side token-verification interceptor in this package gates both.
|
||||
- `RecordLLMUsage` only validates `account_id != ""` (`proxy.go:317-319`). A compromised proxy can attribute cost to any account in its cluster — was already true for prior RPCs but is louder now that data drives denials.
|
||||
|
||||
### Concurrency
|
||||
- `SetAgentNetworkSynthesizer` / `SetAgentNetworkLimitsService` write under `s.mu.Lock`; read paths copy the interface under read lock (`proxy.go:236-247, 260-263, 304-307`). Same pattern as existing `serviceManager`/`proxyController` setters.
|
||||
- Manager writes use `LockingStrengthUpdate`; synth reads use `LockingStrengthNone` — read-after-write via the proxy snapshot can observe a stale view by up to one fan-out tick.
|
||||
- Network_map controller is single-threaded per account; cross-account is parallel.
|
||||
|
||||
### Backward compatibility
|
||||
- `proxy_clone_test.go` is the regression net; any new `ProxyMapping` field must be cloned or explicitly nulled in the test.
|
||||
- `AccessLogEntry` adds indexed `AgentNetwork bool` — implicit AutoMigrate; deploy story must handle table-rewrite cost on high-volume access-log tables.
|
||||
- `TargetOptions` gains seven `omitempty` JSON fields (`service.go:69-94`); on-wire shape stays compatible. `targetOptionsToProto` tests all fields when deciding nil (`service.go:551-556`).
|
||||
- `NewAPIHandler` signature changes — every caller must pass `agentNetworkManager`; `nil` is supported.
|
||||
|
||||
### Observability
|
||||
- 13 new activity codes via `accountManager.StoreEvent` in the manager — confirm dashboard's audit-log UI maps them.
|
||||
- `AccessLogEntry.AgentNetwork` is indexed for the dashboard's agent-network log filter.
|
||||
- New RPCs log at error level on store/selector failures (`proxy.go:284, 327, 332, 348`). Snapshot synth failures degrade to warnings — stream is not aborted (`proxy.go:765`).
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test | Locks down |
|
||||
| ---- | ---------- |
|
||||
| `handlers_test.go::TestPolicyHandler_WindowSecondsRoundTrip` | GET carries `window_seconds`; legacy `window_hours`/`window_days` absent. |
|
||||
| `handlers_test.go::TestPolicyHandler_RejectsSubMinuteWindow` | POST `<60s` returns 4xx. |
|
||||
| `handlers_test.go::TestConsumptionHandler_EmptyAccountReturnsArray` | `/consumption` returns `[]` — never null. |
|
||||
| `handlers_test.go::TestConsumptionHandler_PopulatedAccountListsRows` | RecordConsumption×2 surfaces both with correct tokens/cost/window. |
|
||||
| `budget_handler_test.go::TestBudgetRuleHandler_RoundTrip` | Targets + PolicyLimits shape round-trip. |
|
||||
| `budget_handler_test.go::TestBudgetRuleHandler_ListReturnsArray` | Empty-list shape. |
|
||||
| `budget_handler_test.go::TestBudgetRuleHandler_{RejectsMissingName,RejectsSubMinuteWindow}` | Validation rejections are 4xx. |
|
||||
| `budget_handler_test.go::TestSettingsHandler_GetExposesCollectionToggles` | All four toggles + computed `Endpoint`. |
|
||||
| `proxy_clone_test.go::TestShallowCloneMapping_PreservesAllFieldsExceptAuthToken` | Future-proofs clone; every field round-trips, `AuthToken` dropped. |
|
||||
|
||||
Handler tests use a real sqlite store + real manager + always-allow permissions mock (`handlers_test.go:53-75`). Create/update/delete success paths flow through `accountManager.StoreEvent` which the fixture doesn't wire — covered by manager-level no-mock tests outside this module.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- No pagination on any list endpoint; no bulk endpoints.
|
||||
- Synth result is not cached — every snapshot and every per-peer compute repeats the store walk.
|
||||
- `getSettings` returning `200 + null` is a deliberate dashboard concession.
|
||||
- No rate-limiting beyond the global `/api` rate limiter.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Upstream: [shared/api](10-shared-api.md), [management/agentnetwork](21-management-agentnetwork.md), [management/store](20-management-store.md)
|
||||
- Downstream: [proxy/runtime](33-proxy-runtime.md)
|
||||
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level: [../00-overview.md](../00-overview.md)
|
||||
215
docs/agent-networks/modules/30-proxy-middleware-framework.md
Normal file
215
docs/agent-networks/modules/30-proxy-middleware-framework.md
Normal file
@@ -0,0 +1,215 @@
|
||||
# proxy/middleware-framework — generic plugin system
|
||||
|
||||
> **Risk level:** **High** — every proxied request transits this chain. Budget exhaustion, panic recovery, or chain-close bugs hit the hot path for all targets, not just agent-network ones.
|
||||
> **Backward-compat impact:** Additive at the proxy. The `middleware` and `bodytap` packages are new (`proxy/internal/middleware/middleware.go:1`, `proxy/internal/middleware/bodytap/request.go:13`); existing proxy targets keep working until a chain is bound to them via `Manager.Rebuild`.
|
||||
|
||||
This module is the **framework only** — no LLM/agent-network domain knowledge is required, since every example built into it is generic.
|
||||
|
||||
## Module boundary
|
||||
|
||||
This module is the **framework only**: slots, chains, registry, dispatcher, accumulator, body-tap, output filters. No middleware *implementation* lives here — those land in `proxy/internal/middleware/builtin/*` (covered in module 31). The package contract is:
|
||||
|
||||
1. The proxy hands a `Manager` to its config-apply path. The synth pushes per-path `PathTargetBinding` lists (`proxy/internal/middleware/manager.go:26`) into `Manager.Rebuild`, which resolves each spec via the `Registry`/`Resolver` (`proxy/internal/middleware/registry.go:81-121`) and produces an immutable `Chain` keyed by `serviceID|pathID` (`proxy/internal/middleware/manager.go:410-412`).
|
||||
2. The reverse-proxy handler captures the request body via `bodytap.CaptureRequest`, calls `Chain.RunRequest`, applies returned mutations (already filtered by `chain.applyMutations`), forwards to the upstream behind a `bodytap.CapturingResponseWriter`, then calls `Chain.RunResponse` and `Chain.RunTerminal`.
|
||||
3. Middlewares are inert plugins that receive a deep-cloned `Input` and return an `Output` whose decision/mutations are clamped by the dispatcher's `filterOutput` (`proxy/internal/middleware/dispatcher.go:149-172`).
|
||||
|
||||
Everything that crosses the framework boundary in either direction is value-typed and deep-copied — middlewares cannot mutate the live request directly, and the framework cannot inadvertently leak middleware-owned slices into the request hot path.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `proxy/internal/middleware/middleware.go` | `Middleware` + `Factory` interfaces. |
|
||||
| `proxy/internal/middleware/types.go` | `Slot`, `FailMode`, `Decision`, all limit constants, `Input`/`Output`/`Mutations`/`UpstreamRewrite`/`AuthHeader` value types. |
|
||||
| `proxy/internal/middleware/spec.go` | Apply-time `Spec` (validated wire shape + runtime-injected fields) and `Clone`. |
|
||||
| `proxy/internal/middleware/registry.go` | `Registry` (factory map, RWMutex) and `Resolver` (Spec → bound `Middleware`). |
|
||||
| `proxy/internal/middleware/manager.go` | `Manager`, `chainTable` reverse index, `Rebuild`/`Invalidate*`, async chain close. |
|
||||
| `proxy/internal/middleware/chain.go` | `Chain.RunRequest`/`RunResponse`/`RunTerminal`, mutation gating, `cloneInputFor`. |
|
||||
| `proxy/internal/middleware/chain_test.go` | Metadata threading, LIFO response order, rewrite gating, UserGroups propagation, terminal accumulation. |
|
||||
| `proxy/internal/middleware/dispatcher.go` | Timeout/panic recovery, fail-mode, error classification, `filterOutput`. |
|
||||
| `proxy/internal/middleware/decision.go` | `RenderDenyResponse`, deny-code regex, status clamp. |
|
||||
| `proxy/internal/middleware/headerpolicy.go` | Compile-in header denylist + `FilterHeaderMutations`. |
|
||||
| `proxy/internal/middleware/bodypolicy.go` | `ValidateBodyReplace` / `ApplyBodyReplace` smuggling guards. |
|
||||
| `proxy/internal/middleware/keys.go` | Metadata key namespace constants. |
|
||||
| `proxy/internal/middleware/metadata.go` | `Accumulator` — allowlist, per-mw/per-request byte caps, redaction. |
|
||||
| `proxy/internal/middleware/metrics.go` | OTel instrument bundle (`proxy.middleware.*`). |
|
||||
| `proxy/internal/middleware/redaction.go` | `Scan` — PEM/JWT/AWS/bearer/Luhn-validated CC patterns. |
|
||||
| `proxy/internal/middleware/bodytap/request.go` | Capture + replay reader, `Budget` semaphore, bypass reason codes. |
|
||||
| `proxy/internal/middleware/bodytap/response.go` | `CapturingResponseWriter` (tee with `PassthroughWriter` for Flusher/Hijacker preservation). |
|
||||
|
||||
## Slot model
|
||||
|
||||
Three slots, declared per-middleware exactly once (`proxy/internal/middleware/types.go:27-41`):
|
||||
|
||||
- **`SlotOnRequest`** (`Slot=1`) — runs **before** the upstream call, in registration order. May `DecisionDeny`, may emit `Mutations` (header add/remove, body replace, `UpstreamRewrite`) when both `Spec.CanMutate` and `Middleware.MutationsSupported()` are true. May emit metadata. Each middleware in the slot sees metadata that earlier ones in the same slot just emitted (`proxy/internal/middleware/chain.go:144-178`) — this is how the framework gives middlewares an intra-slot side channel without a global bag.
|
||||
- **`SlotOnResponse`** (`Slot=2`) — runs **after** the upstream returns, in **reverse** registration order. Cannot deny (clamped in `dispatcher.filterOutput`, `proxy/internal/middleware/dispatcher.go:153-157`). May still mutate response headers in principle, but the current chain only forwards `RewriteUpstream` from on_request, so on_response mutations are observe-only in practice. Threads the same per-slot metadata view as on_request.
|
||||
- **`SlotTerminal`** (`Slot=3`) — runs **after** every on_response middleware has emitted, in registration order. Sees the full accumulated bag plus prior terminal emissions (`chain.go:221-245`). Cannot deny, cannot mutate (`dispatcher.go:168-170`). Designed for sinks (access log, metrics push, audit emitter).
|
||||
|
||||
Splitting a feature across slots (e.g. "parse on the way out, ship on terminal") is the explicit architectural choice — `types.go:7-15` and `types.go:22-25` make it clear no middleware participates in more than one slot.
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### Chain dispatch
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
participant H as proxy HTTP handler
|
||||
participant BT as bodytap.CaptureRequest
|
||||
participant CH as Chain
|
||||
participant DI as Dispatcher
|
||||
participant MW as Middleware (per slot)
|
||||
participant US as Upstream
|
||||
participant CW as CapturingResponseWriter
|
||||
|
||||
H->>BT: CaptureRequest(r, cfg, budget)
|
||||
BT-->>H: body[], truncated, release()
|
||||
H->>CH: RunRequest(ctx, r, Input, Accumulator)
|
||||
loop on_request, registration order
|
||||
CH->>CH: cloneInputFor(in, OnRequest)
|
||||
CH->>DI: Invoke(ctx, spec, mw, call)
|
||||
DI->>MW: mw.Invoke(callCtx, in)
|
||||
MW-->>DI: Output{decision, metadata, mutations?}
|
||||
DI->>DI: filterOutput (clamp deny, gate mutations)
|
||||
DI-->>CH: filtered Output
|
||||
CH->>CH: Accumulator.Emit (allowlist + caps + redact)
|
||||
alt DecisionDeny
|
||||
CH-->>H: denied, merged, rewrite
|
||||
else allow
|
||||
CH->>CH: applyMutations(r, m) and capture rewrite
|
||||
end
|
||||
end
|
||||
CH-->>H: nil, merged, rewrite
|
||||
H->>US: ProxyRequest (with rewrite/mutations applied)
|
||||
US-->>CW: bytes (streamed, tee'd into cap-bounded buf)
|
||||
CW-->>H: passthrough complete
|
||||
H->>CH: RunResponse(ctx, Input{RespBody:CW.Body(),...}, acc)
|
||||
loop on_response, REVERSE order (LIFO)
|
||||
CH->>DI: Invoke (same wrappers)
|
||||
end
|
||||
H->>CH: RunTerminal(ctx, Input{Metadata:full bag}, acc)
|
||||
H->>BT: release() + CW.Release()
|
||||
```
|
||||
|
||||
### Body-tap mechanics (request + response)
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph req[Request capture — bodytap.CaptureRequest]
|
||||
R0[r.Body] --> R1{cfg.MaxRequestBytes > 0?\nUpgrade absent?\nContent-Type allowed?\nCL <= cap?}
|
||||
R1 -- no --> R2[bypass = reason\nbody = nil\nr.Body untouched]
|
||||
R1 -- yes --> R3[Budget.Acquire(cap)]
|
||||
R3 -- denied --> R4[bypass=BypassBudget]
|
||||
R3 -- ok --> R5[io.LimitReader(r.Body, cap+1)\nio.ReadAll]
|
||||
R5 --> R6{len > cap?}
|
||||
R6 -- truncated --> R7[viewable = buf[:cap]\nr.Body = replayReadCloser{buf, tail}]
|
||||
R6 -- whole --> R8[r.Body = NopCloser(bytes.Reader(buf))\nclose original]
|
||||
R7 --> R9[(release captured\nbudget on req end)]
|
||||
R8 --> R9
|
||||
end
|
||||
|
||||
subgraph resp[Response capture — CapturingResponseWriter]
|
||||
W0[client] -.-> CW[Write(p)]
|
||||
CW --> P1[PassthroughWriter.Write(p)\n— bytes leave to client first]
|
||||
P1 --> P2{!stopped?}
|
||||
P2 -- yes --> P3{remaining = cap - buf.Len()}
|
||||
P3 --> P4[buf.Write(p[:take])\nset truncated if take<n]
|
||||
P2 -- no --> P5[silent drop into the tee\n(client write already done)]
|
||||
end
|
||||
```
|
||||
|
||||
The body-tap is the highest-leak-risk surface in this module; three details matter:
|
||||
|
||||
1. **Request capture is "read-and-replay", not "read-and-forward".** `CaptureRequest` always swaps `r.Body` for either a `bytes.Reader` (whole body fit) or a `replayReadCloser` that replays the captured prefix then drains the remaining stream from the original body (`bodytap/request.go:178-201`). This means the **upstream still sees the full body even when the tap truncates**. The original `r.Body` is **not** closed in the truncated branch — `replayReadCloser.Close()` only closes the tail (`bodytap/request.go:199-201`), which is the same reader, so close once on request end is correct, but reviewers should confirm the upstream proxy always reads to EOF (otherwise the tail is leaked).
|
||||
2. **Response capture is a write-through tee.** `CapturingResponseWriter.Write` forwards to the underlying writer **first** (`bodytap/response.go:116-117`), then tees into `buf` under its own mutex. Client never blocks on the tee. `Flusher`/`Hijacker` are preserved via the embedded `responsewriter.PassthroughWriter`. SSE/chunked streams flow through untouched; middlewares only see the bounded prefix.
|
||||
3. **Budget is a single shared semaphore.** `Manager` constructs one `bodytap.Budget` at startup (`manager.go:138-144`, default `256 MiB` from `bodytap/request.go:39`). Every capture pre-acquires its full `MaxRequestBytes` / `MaxResponseBytes` from the budget regardless of actual body size; that prevents a flood of small captures from collectively exceeding the cap, but it also means a misconfigured `MaxRequestBytes = 1 MiB` with 256 concurrent requests already exhausts the default budget. Reviewers should sanity-check the operator-facing defaults that ship with synth-service.
|
||||
|
||||
The framework explicitly aborts capture (and increments `proxy.middleware.capture_bypass_total`) before reading the first byte when `Upgrade`/`Connection: upgrade` is set (`bodytap/request.go:120-125`), when the content-type isn't in the allowlist (`bodytap/request.go:126-128`), or when the advertised `Content-Length` already exceeds the cap (`bodytap/request.go:131-133`). This is the right place to make sure WebSocket upgrades and large file uploads never reach the buffer.
|
||||
|
||||
## Public contracts
|
||||
|
||||
- **`Middleware` interface** (`middleware.go:14-36`): `ID()`, `Version()`, `Slot()`, `AcceptedContentTypes()`, `MetadataKeys()`, `MutationsSupported()`, `Invoke(ctx, *Input) (*Output, error)`, `Close()`. `MetadataKeys()` is the **closed set** the middleware is allowed to emit — the accumulator drops anything outside it (`metadata.go:71-75`). `Close` must be idempotent (called even when `Invoke` was never reached).
|
||||
- **`Factory` interface** (`middleware.go:44-47`): `ID()`, `New(rawConfig []byte) (Middleware, error)`. `RawConfig` is opaque JSON bytes on the wire (`spec.go:6-12`); each factory owns its own typed config.
|
||||
- **`Decision` type** (`types.go:59-69`): `Allow=0`, `Deny=1`, `Passthrough=2`. Default-zero is permissive — important because every middleware that omits `Decision` gets `Allow`. Dispatcher clamps `Deny` to `Passthrough` outside `SlotOnRequest` (`dispatcher.go:153-157`).
|
||||
- **`Mutations`** (`types.go:196-201`): `HeadersAdd`/`HeadersRemove` (filtered through `headerpolicy.go`), `BodyReplace` (gated through `bodypolicy.go`), and `RewriteUpstream`. `RewriteUpstream` is **last-write-wins** within the on_request slot (`chain.go:170-172`, locked down by `TestChain_RunRequest_LatestRewriteWins`).
|
||||
- **Metadata propagation keys** (`keys.go`): all keys live in a single file and follow `^[a-z][a-z0-9_-]*(\.[a-z0-9_-]*)+$` (`metadata.go:8`). Framework-injected error tagging uses `mw.<id>.error_kind` (`keys.go:81`) so operators can distinguish framework-emitted entries from middleware-emitted ones.
|
||||
|
||||
## Invariants
|
||||
|
||||
- **Per-request context isolation.** `cloneInputFor` deep-copies every mutable field (`Headers`, `RespHeaders`, `Metadata`, `Body`, `RespBody`, `UserGroups`, `UserGroupNames`) before each invocation (`chain.go:286-308`). A misbehaving middleware that mutates `in.Headers` only corrupts its own copy.
|
||||
- **Body-tap bounded by capture limit.** Request side uses `io.LimitReader(r.Body, limit+1)` (`bodytap/request.go:152`) — the `+1` is how the code detects truncation (`bodytap/request.go:160`); the surfaced buffer is sliced back down to `limit`. Response side stops teeing once `buf.Len() >= cap` (`bodytap/response.go:121-133`). Neither side can grow the buffer past the configured cap.
|
||||
- **Headers/body redaction order.** Accumulator runs `Scan(value)` **before** counting cost (`metadata.go:81-82`), so the byte budgets are computed against post-redaction sizes. `Scan` order is PEM → JWT → AWS key → bearer → Luhn-validated CC (`redaction.go:25-51`) — the comment block in `redaction.go:8-13` is explicit that this is best-effort, not DLP.
|
||||
- **No middleware can starve the chain.** Every invocation runs inside `context.WithTimeout(ctx, clampTimeout(spec.Timeout))` in a separate goroutine (`dispatcher.go:51-94`), with the deadline race-`select`ed against the result channel. A blocked middleware fires the timeout path, gets fail-mode'd, and `IncError(kind=timeout)`. Timeouts are clamped to `[10ms, 5s]` (`types.go:80-86`, `dispatcher.go:174-185`).
|
||||
- **Panic recovery.** `recover()` captures the panic, logs only the type + a 4 KiB stack prefix (no panic value — avoids leaking secrets the middleware was processing), and produces a `panicError` that flows through fail-mode (`dispatcher.go:64-76`).
|
||||
- **Chain immutability + atomic swap.** `chainTable` is cloned on every `Rebuild`/`Invalidate*` and swapped via `atomic.Pointer` (`manager.go:44-69`, `manager.go:221-300`). Readers (`ChainFor`) are lock-free; writers serialise on `writeMu`. The retired chain is `Close`-d in a background goroutine bounded by `chainCloseTimeout = 2 * MaxTimeout` (`manager.go:21-22`, `manager.go:326-346`), so in-flight invocations finish on the old chain after the swap.
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
|
||||
- **Chain ordering deterministic from synth output?** `Manager.buildChain` iterates `b.Specs` in slice order and appends to `bound` (`manager.go:366-391`); `NewChain` then partitions by slot but **preserves slice order within each slot** (`chain.go:50-60`). So order on the wire = order observed at runtime. Synth must therefore emit specs in the intended execution order — there is no per-spec `Priority` field. Worth flagging.
|
||||
- **Decision short-circuit semantics.** `RunRequest` returns immediately on `DecisionDeny` (`chain.go:164-167`) **with the metadata accumulated so far** plus the `denied.Metadata`. Callers that ignore `merged` on deny will lose framework-injected `mw.<id>.error_kind` entries. The proxy runtime is the only caller; confirm it always feeds `merged` into the access log on the deny path as well.
|
||||
- **`UpstreamRewrite` `AuthHeader` bypass** (`types.go:218-235`). The `AuthHeader`/`StripHeaders` fields *intentionally* bypass the header denylist on the basis that the proxy itself rewrites auth. The denylist still blocks middleware-emitted `HeadersAdd: Authorization=...`. This is a delicate carve-out — review the runtime consumer to confirm only the trusted upstream-build path unpacks `AuthHeader`, never the generic `applyMutations` loop.
|
||||
- **`replayReadCloser.Close` only closes the tail** (`bodytap/request.go:199-201`). The replay buffer doesn't own a resource, so this is correct, but it conflates "replay finished" with "underlying body closed". If a caller `Close()`s without reading to EOF, the original body is closed but the captured prefix is lost; harmless for the proxy path (upstream always reads to EOF) but worth a doc-comment.
|
||||
|
||||
### Security
|
||||
|
||||
- **Body-tap memory bounds.** Discussed above — bounded by `MaxBodyCapBytes = 1 MiB` per direction (`types.go:77`) and the shared `Budget` (default 256 MiB). The concerning case is the **deep-copy in `cloneInputFor`** (`chain.go:300-306`): every middleware invocation gets its **own copy** of `Body` and `RespBody`. A chain of N middlewares with a 1 MiB body allocates N MiB of transient bytes per request. With `MaxMiddlewaresPerChain = 16` (`types.go:103`) that's up to 16 MiB extra per in-flight request. Worth pricing into the budget model.
|
||||
- **Header redaction completeness.** `denyHeaders` (`headerpolicy.go:5-17`) covers the auth/forwarding family and framing (`Content-Length`, `Transfer-Encoding`, `Trailer`). `denyHeaderPrefixes` covers `X-Authenticated-*`, `X-Forwarded-*`, `X-Remote-*`, `X-NetBird-*`. Notably absent: `Range`, `If-Match`/`If-None-Match` (mutation could cause cache poisoning), `Origin`/`Referer`. Not necessarily wrong, but worth a deliberate decision.
|
||||
- **Metadata key collisions across middlewares.** The accumulator has no cross-middleware uniqueness check; two middlewares with the same key in their allowlist can both emit it, and both copies land in `merged` (`metadata.go:51-99`). Downstream consumers must tolerate duplicates. Worth documenting.
|
||||
- **Deny rendering.** `RenderDenyResponse` only allows codes matching `^[a-z][a-z0-9._-]{0,63}$` (`decision.go:9`), redacts/truncates message + detail values, caps `Details` at 8 entries (`decision.go:42-50`), clamps status to `[400,499]\{401}` (`decision.go:65-73`). The deny body type is fixed; middlewares cannot inject arbitrary JSON.
|
||||
|
||||
### Concurrency
|
||||
|
||||
- **Per-request state vs shared state in factories.** Each `Factory.New` is called once per chain build; the returned `Middleware` instance is **shared across all requests** for that chain. `Invoke` must be reentrant. The framework does not enforce this — a buggy middleware that holds per-call state on the struct will silently race. Suggest a `// Invoke must be safe for concurrent use` doc on the interface.
|
||||
- **`chainTable` clone-on-write** is correct, but `addChain`/`removeChain` mutate the *cloned* table before the swap (`manager.go:71-108`), and they're called under `writeMu`. Readers only ever see the post-swap pointer. Good.
|
||||
- **`Chain.inflight` WaitGroup**. `Run*` does `Add(1)`/`Done()` (`chain.go:142-143`, `chain.go:194-195`, `chain.go:225-226`); `Close` waits on it bounded by ctx (`chain.go:75-85`). One concern: a *new* `RunRequest` can `Add(1)` *after* `Close` started waiting if the caller still holds a stale chain pointer. `WaitGroup` does not panic on this if the count was already > 0 at `Wait` time, but it does panic if `Add` happens after `Wait` returns and another `Wait` runs. `Close` is documented one-shot, so single-`Wait` is fine, but callers must drop the chain reference before calling `Close`. Worth a code comment near `Close`.
|
||||
- **Goroutine leaks.** `Dispatcher.Invoke` spawns one goroutine per call and *always* writes to a buffered (cap=1) channel (`dispatcher.go:62-76`), so even if the timeout fires the goroutine completes its send and exits. No leak.
|
||||
- **`closeChainsAsync`** detaches retired chains into a goroutine (`manager.go:326-346`). If `Manager` is never GC'd this is fine, but there's no shutdown hook to wait on outstanding closes. Reviewers should confirm the proxy shutdown path explicitly drains in-flight requests before tearing down `Manager`, or accept that the last chain-close round may be cut short on exit.
|
||||
|
||||
### Performance
|
||||
|
||||
- **Allocations per request.** `cloneInputFor` allocates new slices for `Headers`, `RespHeaders`, `Metadata`, `Body`, `RespBody`, `UserGroups`, `UserGroupNames` — once per middleware per request. For a typical 5-middleware chain on a 1 KiB body that's ~10 small slice allocs plus one `Body` copy each. Not a hot-path crisis, but `sync.Pool` for the per-call `Input` would be a natural follow-up.
|
||||
- **Accumulator allocates a fresh `allowSet` per `Emit` call** (`metadata.go:55-58`). One per middleware per slot pass = up to 48 per request. Cheap, but worth noting.
|
||||
- **Regex cost.** `Scan` runs five regex passes on every accepted metadata value (`redaction.go:25-51`). Bounded by `MaxMetadataValueBytes = 4 KiB` so worst case is small.
|
||||
|
||||
### Observability
|
||||
|
||||
- **Per-middleware metrics.** `proxy.middleware.requests_total{middleware,target_id,outcome}` (`metrics.go:34-41`), `duration_ms`, `invocations_total`, `errors_total{kind}`, `metadata_rejected_total{reason}`, `header_mutation_blocked_total{header}`, `capture_bypass_total{reason}`. Comprehensive surface; operators can alert on `errors_total{kind=panic}` and `errors_total{kind=timeout}` separately. **Latency histogram is in milliseconds with default OTel buckets** — for a 10ms–5s timeout range default buckets cover OK, but a custom bucket set centred on 1–500ms would resolve the agent-network response-parser tail better.
|
||||
- **Decision logs.** Panic logs (`dispatcher.go:69`) include `request_id`, type, and stack but not the panic value (safe). `Chain.Close` logs middleware-close errors at debug (`chain.go:91`). `applyMutations` logs body-replace rejections at warn (`chain.go:278`). No log on the deny path itself — by design, since the access-log terminal middleware is expected to record outcomes.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| `proxy/internal/middleware/chain_test.go:77` | `RunRequest` threads metadata across on_request middlewares (regression for the "later mw can't see earlier mw's emissions" bug). |
|
||||
| `chain_test.go:110` | `RunResponse` reverse-order threading. |
|
||||
| `chain_test.go:142` | `cost_meter`-shaped scenario: response_parser registered after cost_meter still emits *before* cost_meter sees the bag (guards the `cost.skipped=missing_tokens` regression). |
|
||||
| `chain_test.go:178` | `UpstreamRewrite` last-write-wins. |
|
||||
| `chain_test.go:206` | No middleware emits → nil rewrite. |
|
||||
| `chain_test.go:224` | Rewrite filtered when `CanMutate=false`. |
|
||||
| `chain_test.go:245` | `Input.UserGroups` propagates verbatim through `cloneInputFor`. |
|
||||
| `chain_test.go:304` | Terminal middlewares see the full accumulated bag + prior terminal emissions. |
|
||||
|
||||
**Gaps** worth raising with the author:
|
||||
- No direct test for `Dispatcher.Invoke` timeout / panic / fail-mode behaviour at the framework level (covered indirectly by built-in tests, but a unit test pinning `errors_total{kind=...}` labels would be cheap insurance).
|
||||
- No test for `bodytap.CaptureRequest` truncated replay (the upstream-sees-full-body invariant is exactly the kind of thing a regression would silently break).
|
||||
- No test for `Budget` exhaustion behaviour under concurrency.
|
||||
- No test for `Manager.InvalidateMiddleware` + `LiveServiceCheck` race (the auth-revocation race the comment at `manager.go:33-38` calls out is the load-bearing reason for `LiveServiceCheck`).
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- **No middleware-to-middleware RPC.** Side-channel is metadata only.
|
||||
- **No streaming body inspection.** Middlewares see a bounded prefix; SSE / chunked parsing happens against that prefix in the response middleware.
|
||||
- **No per-spec priority.** Order is registration order in the spec slice.
|
||||
- **No retry / circuit-breaker** on middleware errors. Fail-mode is binary (open/closed) and per-spec.
|
||||
- **Mutations cannot rewrite the request URL path or query** — only `RewriteUpstream` can change scheme/host (+ optional path replacement, see `types.go:218-235`).
|
||||
- **Redaction is best-effort.** Explicitly documented in `redaction.go:8-13`. Not a DLP solution.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Upstream wire shape: [../modules/10-shared-api.md](10-shared-api.md) (Spec/RawConfig encoding from management).
|
||||
- Built-in middlewares using this framework: [../modules/31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md).
|
||||
- Runtime wiring (where `Manager`, `Chain`, and `bodytap` are consumed by the HTTP handler): [../modules/33-proxy-runtime.md](33-proxy-runtime.md).
|
||||
- End-to-end request flow including capture + chain dispatch: [../01-end-to-end-flows.md](../01-end-to-end-flows.md).
|
||||
- Top-level architecture: [../00-overview.md](../00-overview.md).
|
||||
365
docs/agent-networks/modules/31-proxy-middleware-builtin.md
Normal file
365
docs/agent-networks/modules/31-proxy-middleware-builtin.md
Normal file
@@ -0,0 +1,365 @@
|
||||
# proxy/middleware-builtin — the LLM chain
|
||||
|
||||
The registry-mounted middleware set the proxy executes on every agent-network
|
||||
LLM request. The two highest-blast-radius areas are the **capture-pointer
|
||||
semantics** and the **limit_check ⇒ limit_record** record-once invariant.
|
||||
|
||||
Sibling module: [32-proxy-llm-parsers.md](./32-proxy-llm-parsers.md) — the SDK
|
||||
adapters + pricing catalog this chain delegates to.
|
||||
|
||||
---
|
||||
|
||||
## Module boundary
|
||||
|
||||
This module is the registry-mounted middleware set the proxy executes on
|
||||
every agent-network LLM request. Each sub-package registers itself via
|
||||
`init()`
|
||||
([builtin.go:32–34](../../../proxy/internal/middleware/builtin/builtin.go));
|
||||
the proxy server anonymous-imports the set
|
||||
([all_test.go:11–19](../../../proxy/internal/middleware/builtin/all_test.go))
|
||||
so the registry is populated at boot. The chain is wired by the management
|
||||
synthesiser and executed by the framework
|
||||
(`proxy/internal/middleware/{chain,dispatcher,accumulator}.go` — both out
|
||||
of scope). Everything here reads from / writes to one envelope: the
|
||||
`middleware.KV` metadata bag plus `middleware.Mutations` for header/body
|
||||
rewrites.
|
||||
|
||||
## The 8 middlewares
|
||||
|
||||
| Name | Slot | Inputs (metadata read) | Outputs (metadata written) | Side effects |
|
||||
|---|---|---|---|---|
|
||||
| `llm_request_parser` | OnRequest | `Input.{URL,Body,BodyTruncated}` | `llm.{provider,model,stream,request_prompt_raw,capture_truncated}` | none |
|
||||
| `llm_router` | OnRequest | `llm.model`, `Input.{URL,UserGroups}` | `llm.{resolved_provider_id,authorising_groups}`, `llm_policy.{decision,reason}` | upstream rewrite + auth strip/inject |
|
||||
| `llm_limit_check` | OnRequest | `llm.{resolved_provider_id,model}`, `Input.{AccountID,UserID,UserGroups}` | `llm.{selected_policy_id,attribution_group_id,attribution_window_seconds}`, `llm_policy.{decision,reason}` | gRPC `CheckLLMPolicyLimits` |
|
||||
| `llm_identity_inject` | OnRequest | `llm.{resolved_provider_id,authorising_groups}`, `Input.{UserEmail,UserID,UserGroups,UserGroupNames}` | none | header strip/inject + optional body rewrite |
|
||||
| `llm_guardrail` | OnRequest | `llm.{model,request_prompt_raw}` | `llm_policy.{decision,reason}`, `llm.request_prompt` | none (model allowlist deny) |
|
||||
| `llm_response_parser` | OnResponse | `llm.provider`, `Input.{RespHeaders,RespBody,Status}` | `llm.{input,output,total,cached_input,cache_creation}_tokens`, `llm.response_completion` | none |
|
||||
| `cost_meter` | OnResponse | `llm.{provider,model}`, token buckets | `cost.usd_total` or `cost.skipped` | pricing lookup |
|
||||
| `llm_limit_record` | OnResponse | `llm.{attribution_group_id,attribution_window_seconds,input_tokens,output_tokens}`, `cost.usd_total` | none | gRPC `RecordLLMUsage` |
|
||||
|
||||
[all_test.go:26–40](../../../proxy/internal/middleware/builtin/all_test.go)
|
||||
locks the ID set; adding or removing one is a conscious extension.
|
||||
|
||||
## Files
|
||||
|
||||
| File | LOC | Notes |
|
||||
|---|---:|---|
|
||||
| `builtin.go` | 86 | Registry + `FactoryContext` (ctx, data dir, meter, logger, mgmt client) |
|
||||
| `all_test.go` | 41 | Locks the 8-ID registry surface |
|
||||
| `agentnetwork_chain_integration_test.go` | 319 | Live sqlite + real gRPC bufconn; gate→recorder wire path |
|
||||
| `llm_request_parser/*` | 162 / 66 / 356 | Provider detection, body parse, prompt extraction with capture-pointer gating |
|
||||
| `llm_router/*` | 385 / 84 / 586 | Three-pass route selection (model → groups → path-prefix) |
|
||||
| `llm_limit_check/*` | 196 / 38 / 182 | Pre-flight `CheckLLMPolicyLimits` (2s, fail-open) |
|
||||
| `llm_identity_inject/*` | 440 / 108 / 666 | HeaderPair (LiteLLM) + JSONMetadata (Portkey) + ExtraHeaders |
|
||||
| `llm_guardrail/*` | 176 / 82 / 75 / 219 / 217 | Model allowlist + optional prompt capture with PII redaction |
|
||||
| `llm_response_parser/*` | 258 / 222 / 43 / 433 / 169 / 111 | Buffered + SSE accumulation; AWS event-stream accumulator (`streaming_bedrock.go`) for Bedrock; capture-pointer gates completion emit |
|
||||
| `cost_meter/*` | 181 / 84 / 439 | Token → USD via `proxy/internal/llm/pricing` |
|
||||
| `llm_limit_record/*` | 144 / 35 / 191 | Post-flight `RecordLLMUsage` (5s, debug-on-error) |
|
||||
|
||||
## Per-middleware
|
||||
|
||||
### llm_request_parser
|
||||
|
||||
Detects the LLM provider via `llm.DetectParser` (URL sniff) or by name via
|
||||
`llm.ParserByName` when synthesiser stamps `provider_id`
|
||||
([middleware.go:96–99](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)).
|
||||
**Path-routed providers short-circuit first:** `parseVertexPath` and
|
||||
`parseBedrockPath` ([middleware.go:85–94](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go))
|
||||
pull the model + vendor out of the URL before parser selection runs — Vertex
|
||||
from `/v1/projects/.../publishers/{pub}/models/{model}:{action}` (publisher →
|
||||
vendor via `vertexPublisherVendor`), Bedrock from `/model/{id}/{action}` with
|
||||
`normalizeBedrockModel` stripping the region prefix + version suffix. See
|
||||
[50-path-routed-providers.md](./50-path-routed-providers.md) for the full path
|
||||
grammar. For body-routed providers it decodes the body into `RequestFacts`
|
||||
(model + stream) and extracts the prompt. On
|
||||
`capture_prompt=true` (or absent — see capture-pointer semantics below) the
|
||||
prompt is run through `llm_guardrail.RedactPII` when `redact_pii=true` and
|
||||
truncated rune-safely to 3500 bytes
|
||||
([middleware.go:109–122](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)).
|
||||
**Key invariant:** redaction is parser-side, not guardrail-side — access-log
|
||||
reads `llm.request_prompt_raw` directly.
|
||||
|
||||
### llm_router
|
||||
|
||||
Three-pass route selection in `matchRoute`
|
||||
([middleware.go:241–300](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)):
|
||||
filter by `Models` claim → vendor-pin (a vendor-tagged request never crosses to
|
||||
another vendor's route) → filter by `AllowedGroupIDs` intersection → model
|
||||
precedence over path → tie-break by longest `UpstreamPath` prefix match.
|
||||
Model-miss returns `llm_policy.model_not_routable`; known-but-unauthorised
|
||||
returns `llm_policy.no_authorised_provider`. **Key invariant:** auth-header
|
||||
strip+inject rides on `UpstreamRewrite.{StripHeaders,AuthHeader}`
|
||||
([middleware.go:606–646](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
— NOT `HeadersAdd/HeadersRemove` — because the framework's mutation gate
|
||||
blocks `Authorization` on the generic header path.
|
||||
|
||||
**Path-routed providers route before the model table.** `Invoke` checks
|
||||
`isVertexPath` / `isBedrockPath`
|
||||
([middleware.go:138–216](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
ahead of the model lookup, so a path-carried model can't be claimed by a
|
||||
same-vendor body-routed provider. `matchPathRoute` enforces the route's `Models`
|
||||
allowlist (empty = catch-all) even though the model came from the URL.
|
||||
Two path-only behaviours:
|
||||
- **Vertex unmeterable publisher** — when `llm_request_parser` emits no
|
||||
`llm.provider` (e.g. Gemini/`google`), the router denies with
|
||||
`llm_policy.unmeterable_publisher` (403) rather than forward it uncounted.
|
||||
- **GCP token minting** — when the route carries `GCPServiceAccountKeyB64`
|
||||
(set from a `keyfile::` api_key), `gcpBearer` mints + caches a short-lived
|
||||
OAuth2 token per request instead of injecting a static value; a bad key or
|
||||
unreachable token endpoint denies with `llm_policy.upstream_auth_failed`
|
||||
(502). Bedrock uses its static bearer token directly (no minting).
|
||||
- **`/bedrock` prefix** — an optional `/bedrock` gateway-namespace prefix is
|
||||
accepted and stripped via `RewriteUpstream.StripPathPrefix` so the native
|
||||
`/model/...` path reaches the upstream.
|
||||
|
||||
Full treatment in [50-path-routed-providers.md](./50-path-routed-providers.md).
|
||||
|
||||
### llm_limit_check
|
||||
|
||||
Pre-flight gate. Reads `llm.resolved_provider_id`, calls
|
||||
`CheckLLMPolicyLimits` with a 2s context timeout
|
||||
([middleware.go:24, 97–106](../../../proxy/internal/middleware/builtin/llm_limit_check/middleware.go)),
|
||||
on allow stamps `llm.selected_policy_id`, `llm.attribution_group_id`,
|
||||
`llm.attribution_window_seconds`. **Key invariant:** fail-open. Nil
|
||||
`MgmtClient`, empty provider id, or RPC error returns `allowNoAttribution()`
|
||||
— management outage doesn't take down every LLM request. Operators audit via
|
||||
the access-log; a future flag may switch this to fail-closed.
|
||||
|
||||
### llm_identity_inject
|
||||
|
||||
Dispatches per-rule between LiteLLM-shaped `HeaderPair`
|
||||
([middleware.go:169](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go))
|
||||
and Portkey-shaped `JSONMetadata`
|
||||
([middleware.go:292](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go)).
|
||||
Identity is the peer's email (or `UserID` fallback); tags are the
|
||||
**authorising-groups intersection** emitted by `llm_router`, not the full
|
||||
`UserGroups` — a peer in 5 groups authorised under 1 only tags as that 1.
|
||||
**Anti-spoof:** every `HeadersAdd` is preceded by a `HeadersRemove` of the
|
||||
same name; the framework runs `Remove` before `Add` so client-supplied
|
||||
identity never reaches the upstream. Body-level inject (`tags_in_body`,
|
||||
`end_user_id_in_body`) is skipped on empty / truncated / non-JSON bodies so
|
||||
header attribution stays intact.
|
||||
|
||||
### llm_guardrail
|
||||
|
||||
Model allowlist deny + optional prompt-capture-with-redaction. Allowlist
|
||||
match is case-insensitive via `normaliseModel`; empty allowlist disables the
|
||||
check. Prompt capture reads `llm.request_prompt_raw` and emits
|
||||
`llm.request_prompt` only when `prompt_capture.enabled`
|
||||
([middleware.go:149–165](../../../proxy/internal/middleware/builtin/llm_guardrail/middleware.go)).
|
||||
**Key invariant:** `RedactPII` is the exported function the parsers call —
|
||||
single PII contract across all three keys.
|
||||
|
||||
### llm_response_parser
|
||||
|
||||
Buffered and SSE paths share one `Invoke`
|
||||
([middleware.go:102–127](../../../proxy/internal/middleware/builtin/llm_response_parser/middleware.go)):
|
||||
content-type sniffing dispatches to `invokeBuffered` (JSON, status<400) or
|
||||
`invokeStreaming` (text/event-stream, partial bodies tolerated). Streaming
|
||||
delegates to `accumulateStream`
|
||||
([streaming.go:21–30](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go))
|
||||
using `llm.NewScanner`. A third path, `accumulateBedrockStream`
|
||||
([streaming_bedrock.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go)),
|
||||
decodes the AWS binary event-stream (`application/vnd.amazon.eventstream`)
|
||||
returned by Bedrock's `-stream` actions — InvokeModel `chunk` frames wrap a
|
||||
base64 Anthropic event, Converse frames carry text + a trailing usage block.
|
||||
Cached / cache-creation buckets emit only when non-zero, preserving the existing
|
||||
token schema.
|
||||
|
||||
### cost_meter
|
||||
|
||||
Reads `llm.provider` + `llm.model` + token buckets, looks up per-1k rate via
|
||||
`pricing.Loader`, emits `cost.usd_total` or a closed-set `cost.skipped`
|
||||
reason (`missing_provider/model/tokens`, `unparseable_tokens`, `zero_tokens`,
|
||||
`unknown_model`). Loader's hot-reload goroutine is bound to proxy-lifetime
|
||||
context via `startReloader`. **Key invariant:** provider-shape switch lives
|
||||
in `pricing.Table.Cost` (sibling doc) — `cost_meter` stays provider-agnostic.
|
||||
|
||||
### llm_limit_record
|
||||
|
||||
Post-flight write. Always returns `DecisionAllow`; response has already been
|
||||
served so RPC errors mustn't surface (logged at `Debugf`). Skip-on-no-signal
|
||||
at line 81 (zero tokens + zero cost). **Key invariant:** the
|
||||
skip-on-missing-attribution guard at line 98 is a safety net independent of
|
||||
the framework's deny short-circuit — if the gate denied and the framework
|
||||
still runs the recorder, the recorder skips on absent
|
||||
`UserID`+`groupID`+`UserGroups` and no phantom counter materialises.
|
||||
|
||||
## Full-chain diagram (canonical order)
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[HTTP request] --> B[llm_request_parser<br/>OnRequest]
|
||||
B -->|llm.provider, llm.model,<br/>llm.stream, llm.request_prompt_raw| C[llm_router<br/>OnRequest]
|
||||
C -->|llm.resolved_provider_id,<br/>llm.authorising_groups,<br/>upstream rewrite + auth| D[llm_limit_check<br/>OnRequest]
|
||||
D -->|deny path| Z1[403 llm_policy.*]
|
||||
D -->|allow + llm.selected_policy_id,<br/>llm.attribution_group_id,<br/>llm.attribution_window_seconds| E[llm_identity_inject<br/>OnRequest]
|
||||
E -->|header strip+inject<br/>+ optional body rewrite| F[llm_guardrail<br/>OnRequest]
|
||||
F -->|deny: model_blocked| Z2[403 llm_policy.model_blocked]
|
||||
F -->|allow + llm.request_prompt| G[upstream LLM call]
|
||||
G --> H[llm_response_parser<br/>OnResponse]
|
||||
H -->|llm.{input,output,total,cached_input,cache_creation}_tokens,<br/>llm.response_completion| I[cost_meter<br/>OnResponse]
|
||||
I -->|cost.usd_total or cost.skipped| J[llm_limit_record<br/>OnResponse]
|
||||
J --> K[response to client]
|
||||
```
|
||||
|
||||
## limit_check ⇒ limit_record record-once invariant
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant LC as llm_limit_check
|
||||
participant M as management gRPC
|
||||
participant U as upstream LLM
|
||||
participant LR as llm_limit_record
|
||||
participant DB as sqlite consumption table
|
||||
|
||||
LC->>M: CheckLLMPolicyLimits (2s)
|
||||
alt allow
|
||||
M-->>LC: selected_policy_id, attribution_group_id, window_s
|
||||
LC->>U: stamps attribution metadata
|
||||
U-->>LR: response + tokens (via llm_response_parser + cost_meter)
|
||||
LR->>M: RecordLLMUsage (5s, debug-on-error)
|
||||
M->>DB: increment (user, group, window) row
|
||||
else deny
|
||||
M-->>LC: llm_policy.token_cap_exceeded
|
||||
Note over LR: framework short-circuits; even if invoked,<br/>recorder skips on absent UserID+groupID+UserGroups
|
||||
else mgmt nil / rpc error
|
||||
LC-->>LC: allowNoAttribution() — fail open
|
||||
Note over LR: no window_s ⇒ recorder books only account-level<br/>budget rules (which run independently)
|
||||
end
|
||||
```
|
||||
|
||||
The integration test
|
||||
[agentnetwork_chain_integration_test.go](../../../proxy/internal/middleware/builtin/agentnetwork_chain_integration_test.go)
|
||||
exercises all three branches against a real sqlite store + bufconn gRPC —
|
||||
no mocks. Tests: `TestChain_AllowPath_StampsAttributionAndRecordsCounter`
|
||||
(line 130), `TestChain_DenyPath_GateRejectsAndNoConsumptionWritten` (line
|
||||
207), `TestChain_CapExhaustTransition` (line 265).
|
||||
|
||||
## Public contracts (per-middleware JSON config)
|
||||
|
||||
| Middleware | Config shape |
|
||||
|---|---|
|
||||
| `llm_request_parser` | `{provider_id?, redact_pii?, capture_prompt?: *bool}` ([factory.go:19–37](../../../proxy/internal/middleware/builtin/llm_request_parser/factory.go)) |
|
||||
| `llm_router` | `{providers: [{id, models, upstream_scheme, upstream_host, upstream_path?, auth_header_name, auth_header_value, allowed_group_ids}]}` |
|
||||
| `llm_limit_check` | `{}` — pulls `MgmtClient` from `FactoryContext` |
|
||||
| `llm_identity_inject` | `{providers: [{provider_id, header_pair?|json_metadata?, extra_headers?}]}` |
|
||||
| `llm_guardrail` | `{model_allowlist: []string, prompt_capture: {enabled, redact_pii}}` |
|
||||
| `llm_response_parser` | `{redact_pii?, capture_completion?: *bool}` |
|
||||
| `cost_meter` | `{pricing_path?}` (basename inside data-dir; defaults `pricing.yaml`) |
|
||||
| `llm_limit_record` | `{}` — same pattern as `llm_limit_check` |
|
||||
|
||||
All factories accept empty / null / `{}` / whitespace as zero-value config;
|
||||
only structurally invalid JSON is rejected so misconfig surfaces at chain
|
||||
build time.
|
||||
|
||||
## Invariants
|
||||
|
||||
1. **limit_check ↔ limit_record paired.** They MUST appear together. Gate
|
||||
stamps attribution metadata on the request leg; recorder reads it on the
|
||||
response leg. If a chain contains only the recorder, the
|
||||
skip-on-missing-attribution guard at
|
||||
[llm_limit_record/middleware.go:81–87, 98–103](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go)
|
||||
keeps counters consistent but no enforcement runs. Only-gate means
|
||||
counters never tick and headroom appears infinite.
|
||||
|
||||
2. **`capture_prompt` / `capture_completion` pointer semantics.** Both are
|
||||
`*bool`. `nil` = "preserve legacy emit" (back-compat default for
|
||||
non-agent-network callers and pre-toggle tests). `false` = suppress the
|
||||
key entirely (access-log row carries zero prompt / completion content).
|
||||
`true` = emit. The synthesiser sets the pointer explicitly to the
|
||||
account's `EnablePromptCollection` toggle. The handling lives
|
||||
in [llm_request_parser/factory.go:55–61](../../../proxy/internal/middleware/builtin/llm_request_parser/factory.go)
|
||||
and the symmetric [llm_response_parser/middleware.go:62–68](../../../proxy/internal/middleware/builtin/llm_response_parser/middleware.go);
|
||||
a missing pointer must not be treated as `false` (that would suppress
|
||||
capture for legacy non-agent-network callers).
|
||||
`redact_pii` is an orthogonal `bool` controlling **form** of emitted
|
||||
content, not whether it's emitted.
|
||||
|
||||
3. **`redact_pii` is parser-side.** Both parsers import
|
||||
`llm_guardrail.RedactPII` and run it BEFORE stamping the metadata bag.
|
||||
Load-bearing because the access-log sink reads `llm.request_prompt_raw`
|
||||
and `llm.response_completion` directly — by the time `llm_guardrail`
|
||||
runs its own pass on `llm.request_prompt`, the raw key has already been
|
||||
stamped. Tests: `TestInvoke_RedactPii_RedactsBeforeEmittingRawPrompt`,
|
||||
`TestInvoke_RedactPii_RedactsCompletionBeforeEmit`.
|
||||
|
||||
4. **Metadata allowlist enforcement.** Every middleware declares
|
||||
`MetadataKeys()`. The framework accumulator drops any KV outside that
|
||||
allowlist. When adding a new key, also extend the docstring in
|
||||
`middleware/keys.go`.
|
||||
|
||||
5. **Closed deny-code set.** All deny paths emit one of:
|
||||
`llm_policy.model_not_routable`, `llm_policy.no_authorised_provider`,
|
||||
`llm_policy.model_blocked`, `llm_policy.token_cap_exceeded`,
|
||||
`llm_policy.unmeterable_publisher` (path-routed Vertex publisher with no
|
||||
parser → 403), `llm_policy.upstream_auth_failed` (GCP token mint failure →
|
||||
502), or the management-supplied code on `llm_limit_check`. These surface
|
||||
verbatim; arbitrary middleware text never reaches the wire.
|
||||
|
||||
## Things to scrutinise
|
||||
|
||||
**Correctness.** `llm_router` model match treats an empty `Models` slice as
|
||||
"claim every model"
|
||||
([middleware.go:238–248](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
for gateway-style providers — confirm no real provider record ships with an
|
||||
empty `Models` by accident. Path-prefix tie-break falls back to declaration
|
||||
order when no candidate prefix-matches, so the synthesiser must emit a
|
||||
deterministic order. `llm_limit_record` discards `strconv.ParseInt` errors
|
||||
([middleware.go:78–80](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go))
|
||||
— relies on `llm_response_parser` always emitting parseable values; spot-check
|
||||
the streaming partial path on truncated bodies.
|
||||
|
||||
**Security.** Auth headers must NEVER appear on `Mutations.HeadersAdd/Remove`
|
||||
for the router — a direct headers path would bypass the framework gate. The
|
||||
capture-pointer handling is the kind of place a bug ships PII to logs
|
||||
silently; every synthesiser config path must set the pointer explicitly.
|
||||
`llm_identity_inject` body inject silently skips on a
|
||||
non-object `metadata` field
|
||||
([middleware.go:262–270](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go))
|
||||
— header path still attributes, but body-level tag-budget enforcement
|
||||
doesn't run for that request.
|
||||
|
||||
**Concurrency.** `cost_meter` shares a `pricing.Loader` via
|
||||
`atomic.Pointer[Table]`; readers always see a consistent table. Every
|
||||
middleware is a stateless value receiver. Integration test uses real bufconn
|
||||
gRPC — race detector is the meaningful bar.
|
||||
|
||||
**Perf.** Hot path is `lookupKV` linear scan over <10 KVs; `cost_meter.Cost`
|
||||
is O(1); SSE accumulation is single-pass. No map allocation per call.
|
||||
|
||||
**Observability.** Every deny stamps `llm_policy.decision=deny` and a
|
||||
matching `llm_policy.reason` — access-log can pivot on either.
|
||||
`llm_limit_record` only logs at `Debugf` on RPC failure
|
||||
([middleware.go:125–130](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go));
|
||||
operators need an alternate signal (metric on `RecordLLMUsage` failures) for
|
||||
counter accuracy.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| File | Tests | Notes |
|
||||
|---|---:|---|
|
||||
| `all_test.go` | 1 | Registry surface lock |
|
||||
| `agentnetwork_chain_integration_test.go` | 3 | Allow/deny/cap-exhaust vs live sqlite + bufconn gRPC |
|
||||
| `llm_request_parser/middleware_test.go` | 18 | `provider_id` bypass, redaction, capture-pointer, rune-safe truncation |
|
||||
| `llm_router/middleware_test.go` | 19 | Three-pass match, deny codes, path-prefix tie-break, header strip+inject |
|
||||
| `llm_limit_check/middleware_test.go` | 6 | Allow/deny, fail-open on nil mgmt / RPC error, attribution stamping |
|
||||
| `llm_identity_inject/middleware_test.go` | 28 | HeaderPair, JSONMetadata, ExtraHeaders, body inject, anti-spoof |
|
||||
| `llm_guardrail/middleware_test.go` | 15 | Allowlist case-insensitivity, prompt capture toggle, deny shape |
|
||||
| `llm_guardrail/redact_test.go` | 15 | Email, SSN, phone (E.164 + NA), bearer, IPv4; fixture-driven |
|
||||
| `llm_response_parser/middleware_test.go` | 18 | Buffered OAI+Anthro, capture-pointer, redact, truncation |
|
||||
| `llm_response_parser/streaming_test.go` | 7 | OAI usage frame, Anthro message_delta, truncated body best-effort |
|
||||
| `cost_meter/middleware_test.go` | 17 | Each skip reason, provider-shape, pricing loader integration |
|
||||
| `llm_limit_record/middleware_test.go` | 7 | Skip-on-no-signal, skip-on-missing-attribution, RPC failure swallowed |
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Sibling: [32-proxy-llm-parsers.md](./32-proxy-llm-parsers.md) — SDK adapters
|
||||
+ SSE framer + pricing loader.
|
||||
- Path-routed providers (Vertex AI + Bedrock), `keyfile::` credential, GCP
|
||||
token minting, `/bedrock` prefix:
|
||||
[50-path-routed-providers.md](./50-path-routed-providers.md).
|
||||
- Upstream config: `management/server/agentnetwork/synthesizer` (out of scope).
|
||||
- Framework: `proxy/internal/middleware/{chain,dispatcher,accumulator,registry}.go`.
|
||||
- Metadata key registry: `proxy/internal/middleware/keys.go`.
|
||||
- gRPC surface: `proto.ProxyServiceClient.{CheckLLMPolicyLimits,RecordLLMUsage}`.
|
||||
392
docs/agent-networks/modules/32-proxy-llm-parsers.md
Normal file
392
docs/agent-networks/modules/32-proxy-llm-parsers.md
Normal file
@@ -0,0 +1,392 @@
|
||||
# proxy/llm-parsers — SDK adapters + pricing + SSE
|
||||
|
||||
The runtime-agnostic LLM library: the OpenAI Responses API (`/v1/responses`)
|
||||
and the older Chat Completions API (`/v1/chat/completions`), the Anthropic
|
||||
Messages API (`/v1/messages`), the SSE wire format (`event:` / `data:` lines,
|
||||
`\n\n` framing, CRLF tolerance), and per-provider token accounting (OpenAI's
|
||||
cached-prompt **subset** vs Anthropic's cache_read **additive** model). The
|
||||
pricing table's per-provider cost formula is the highest-leverage place a
|
||||
small bug would silently mis-bill operators.
|
||||
|
||||
Sibling module: [31-proxy-middleware-builtin.md](./31-proxy-middleware-builtin.md)
|
||||
— the 8 middlewares that consume this package's parsers + pricing loader.
|
||||
|
||||
---
|
||||
|
||||
## Module boundary
|
||||
|
||||
`proxy/internal/llm` is the runtime-agnostic LLM library shared by every
|
||||
middleware that needs to understand provider-specific shapes. Zero
|
||||
proxy-framework dependencies:
|
||||
|
||||
- `parser.go` — `Parser` interface, `Provider` enum, public factories
|
||||
(`Parsers`, `DetectParser`, `ParserByName`).
|
||||
- `openai.go` / `anthropic.go` / `bedrock.go` — per-provider `Parser` impls.
|
||||
- `sse.go` — SSE scanner (`Scanner`, `Event`, `NewScanner`).
|
||||
- `errors.go` — sentinels callers branch on with `errors.Is`.
|
||||
- `pricing/` — embedded-default + hot-reload override table with
|
||||
symlink-safe Unix loader (build-tagged stub elsewhere).
|
||||
- `fixtures/` — captured request/response/stream bodies the tests replay.
|
||||
|
||||
The package carries zero proxy-framework dependencies so the same parsers can
|
||||
be reused later by a WASM adapter
|
||||
([parser.go:1–6](../../../proxy/internal/llm/parser.go)).
|
||||
|
||||
## Files
|
||||
|
||||
| File | LOC | Notes |
|
||||
|---|---:|---|
|
||||
| `parser.go` | 104 | Interface + factories + `Provider{Unknown,OpenAI,Anthropic}` enum |
|
||||
| `openai.go` | 347 | Chat Completions + Completions + Responses API; cached_tokens subset |
|
||||
| `openai_test.go` | 222 | 11 tests; fixture replay + cached/Responses-API matrix |
|
||||
| `anthropic.go` | 172 | Messages + legacy `/v1/complete`; cache_read + cache_creation additive |
|
||||
| `anthropic_test.go` | 154 | 7 tests including streaming-extraction-skipped contract |
|
||||
| `bedrock.go` | 190 | AWS Bedrock InvokeModel (snake_case) + Converse (camelCase) response shapes; model lives in URL path |
|
||||
| `bedrock_test.go` | — | InvokeModel + Converse usage shapes; AWS event-stream content-type → `ErrStreamingUnsupported` on buffered `ParseResponse` |
|
||||
| `sse.go` | 117 | `bufio`-backed scanner; CRLF normalised; trailing-event handling |
|
||||
| `sse_test.go` | 175 | 12 tests; fixture replay + multiline + size limits |
|
||||
| `parser_test.go` | 53 | `Parsers()`, `DetectParser`, provider enum values |
|
||||
| `errors.go` | 31 | 6 sentinels: `Err{Unknown,Unsupported}Provider/Model`, `Err{NotLLM,Malformed}Response`, `ErrStreamingUnsupported`, `ErrMalformedRequest` |
|
||||
| `pricing/pricing.go` | 421 | `Loader`, `Table`, `Entry`; embedded defaults + atomic swap + mtime reload |
|
||||
| `pricing/pricing_unix.go` | 69 | `O_NOFOLLOW` + fstat-from-FD + 1 MiB cap |
|
||||
| `pricing/pricing_other.go` | 21 | Stub returning "not supported on this platform" |
|
||||
| `pricing/pricing_test.go` | 432 | 21 tests — symlink rejection, reload race, path traversal, oversize |
|
||||
| `pricing/defaults_pricing.yaml` | 85 | go:embed source of truth |
|
||||
| `fixtures/*` | 21–59 | OAI chat/responses/stream + Anthro messages/stream + pricing starter |
|
||||
|
||||
## Request body → parser dispatch
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[HTTP request<br/>URL + JSON body] --> B{ParserByName?<br/>provider_id config set}
|
||||
B -- yes --> P[matched Parser]
|
||||
B -- no --> C[DetectParser]
|
||||
C --> D{loop Parsers<br/>OpenAIParser, AnthropicParser}
|
||||
D -- DetectFromURL match --> P
|
||||
D -- no match --> X[ok=false<br/>middleware skips]
|
||||
P --> E[ParseRequest body]
|
||||
E -->|err: ErrMalformedRequest| Y[middleware emits provider only]
|
||||
E --> F[RequestFacts<br/>model + stream]
|
||||
P --> G[ExtractPrompt body]
|
||||
G --> H[joinMessages<br/>extractContentParts<br/>decodeStringOrJoin]
|
||||
H --> I[prompt text<br/>or empty]
|
||||
F --> J[stamps llm.model + llm.stream]
|
||||
I --> K[stamps llm.request_prompt_raw<br/>subject to capture_prompt gate]
|
||||
```
|
||||
|
||||
OpenAI's URL hints
|
||||
([openai.go:27–33](../../../proxy/internal/llm/openai.go)) include
|
||||
both `/v1/chat/completions` and the bare `/chat/completions` — the latter
|
||||
covers Cloudflare AI Gateway, which rewrites the canonical version segment.
|
||||
Anthropic's hints are `/v1/messages` and `/v1/complete`
|
||||
([anthropic.go:14–17](../../../proxy/internal/llm/anthropic.go)).
|
||||
Both implementations use case-insensitive substring matching so a proxy prefix
|
||||
strip / rewrite doesn't defeat detection.
|
||||
|
||||
`ParserByName` ([parser.go:93–103](../../../proxy/internal/llm/parser.go))
|
||||
is the **agent-network bypass**: the synthesiser knows which parser to use
|
||||
because it built the synth service from the catalog, so it stamps
|
||||
`provider_id` on the parser config and the middleware skips URL sniffing
|
||||
entirely. This is what makes the same parser set work whether the request
|
||||
flows to OpenAI direct, to LiteLLM, to Portkey, or to any gateway with a
|
||||
non-canonical URL shape.
|
||||
|
||||
**Path-routed providers (Vertex AI, Bedrock) bypass both `ParserByName` and
|
||||
`DetectParser`.** The model and the parser surface live in the URL path, so the
|
||||
request middleware extracts them directly (`parseVertexPath` /
|
||||
`parseBedrockPath`) before the parser-selection step. For Vertex the publisher
|
||||
segment picks the parser (`anthropic` → Anthropic parser; `google`/Gemini →
|
||||
none, request denied as unmeterable). For Bedrock the dedicated `BedrockParser`
|
||||
handles the response. Full treatment in
|
||||
[50-path-routed-providers.md](./50-path-routed-providers.md).
|
||||
|
||||
## Streaming response → SSE chunker → response parser → completion + token count
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant U as upstream LLM
|
||||
participant LR as llm_response_parser<br/>(OnResponse)
|
||||
participant S as llm.NewScanner<br/>(SSE framer)
|
||||
participant P as Parser-specific accumulator<br/>(accumulateOpenAIStream<br/>or accumulateAnthropicStream)
|
||||
|
||||
U-->>LR: text/event-stream<br/>(buffered prefix in RespBody)
|
||||
LR->>S: NewScanner(bytes.NewReader(body))
|
||||
loop until EOF or [DONE]
|
||||
S-->>LR: Event{Type, Data}
|
||||
LR->>P: dispatch per event.Type<br/>(OpenAI: data-only<br/>Anthropic: named events)
|
||||
P-->>P: accumulate completion text<br/>track usage from final frame
|
||||
end
|
||||
P-->>LR: llm.Usage + completion string
|
||||
LR->>LR: appendUsage stamps<br/>llm.{input,output,total,cached_input,cache_creation}_tokens
|
||||
LR->>LR: truncateCompletion(3500 bytes, rune-safe)
|
||||
LR->>LR: redactPII if redact_pii && captureCompletion
|
||||
```
|
||||
|
||||
`Scanner.Next`
|
||||
([sse.go:44–87](../../../proxy/internal/llm/sse.go)) returns one
|
||||
event per `\n\n` boundary; multiple `data:` lines join with `\n`; comment lines
|
||||
(starting with `:`) are skipped per the SSE spec; a trailing event without a
|
||||
closing blank line is still returned before `io.EOF` so a server that closes
|
||||
the connection cleanly doesn't lose the last frame
|
||||
([sse.go:55–58](../../../proxy/internal/llm/sse.go)). CRLF is
|
||||
normalised in `trimEOL` so fixtures captured from live servers replay
|
||||
unchanged.
|
||||
|
||||
## Per-provider
|
||||
|
||||
### OpenAI
|
||||
|
||||
[openai.go:54–67](../../../proxy/internal/llm/openai.go) defines
|
||||
`openAIRequest` with three prompt fields: `messages` (Chat Completions),
|
||||
`prompt` (legacy), `input` (Responses API). The decoder uses
|
||||
`json.RawMessage` so each shape is parsed lazily.
|
||||
|
||||
`ParseResponse`
|
||||
([openai.go:117–146](../../../proxy/internal/llm/openai.go))
|
||||
accepts both naming conventions: Chat Completions returns
|
||||
`prompt_tokens`/`completion_tokens`, Responses API returns
|
||||
`input_tokens`/`output_tokens`. `pickInt64` prefers Responses-API names and
|
||||
falls back — same parser handles both endpoints without per-route config.
|
||||
`openAICachedTokens` mirrors the fallback for
|
||||
`input_tokens_details.cached_tokens` vs `prompt_tokens_details.cached_tokens`.
|
||||
|
||||
**Key invariant:** `CachedInputTokens` for OpenAI is a SUBSET of
|
||||
`InputTokens`. The cost meter clamps to guard against malformed upstream
|
||||
responses where `cached > total`.
|
||||
|
||||
### Anthropic
|
||||
|
||||
[anthropic.go:37–49](../../../proxy/internal/llm/anthropic.go)
|
||||
defines `anthropicRequest` covering Messages API (`system` + `messages[]`)
|
||||
and legacy `/v1/complete` (`prompt` string). `ExtractPrompt` emits
|
||||
`system: <text>` first when present, then per-message `role: content`.
|
||||
|
||||
`ParseResponse`
|
||||
([anthropic.go:82–104](../../../proxy/internal/llm/anthropic.go))
|
||||
fills three independent token buckets: `InputTokens`, `CacheReadInputTokens`,
|
||||
`CacheCreationInputTokens`. Latter two are **additive** (not subset).
|
||||
`TotalTokens` sums all four so downstream dashboards render one "tokens"
|
||||
number without double-counting.
|
||||
|
||||
`ExtractCompletion` walks `content[]` `{type, text}` parts and concatenates
|
||||
non-empty text with newlines, falling back to legacy `completion`.
|
||||
|
||||
### Bedrock
|
||||
|
||||
[bedrock.go](../../../proxy/internal/llm/bedrock.go) implements the
|
||||
`Parser` interface for the AWS Bedrock runtime. Bedrock is **path-routed**: the
|
||||
model lives in the URL (`/model/{id}/{action}`), so the request middleware
|
||||
extracts it (see [50-path-routed-providers.md](./50-path-routed-providers.md))
|
||||
and `ParseRequest` is a deliberate no-op. The parser's real work is on the
|
||||
response leg, covering both Bedrock body shapes:
|
||||
|
||||
- **InvokeModel** — vendor-native. Anthropic-on-Bedrock returns snake_case usage
|
||||
(`input_tokens`, `output_tokens`, `cache_read_input_tokens`,
|
||||
`cache_creation_input_tokens`) with the same additive cache buckets as
|
||||
first-party Anthropic.
|
||||
- **Converse** — unified camelCase (`inputTokens`, `outputTokens`,
|
||||
`totalTokens`). `firstNonZero` folds the two naming conventions into one
|
||||
`Usage`; when Converse omits `totalTokens` the parser sums the buckets.
|
||||
|
||||
`ProviderName()` returns `"bedrock"` — its own `defaults_pricing.yaml` block,
|
||||
keyed by the **normalised** model id (region prefix + version suffix stripped by
|
||||
the request parser). `ParseResponse` returns `ErrStreamingUnsupported` for an
|
||||
AWS binary event-stream content-type (`application/vnd.amazon.eventstream`,
|
||||
`isAWSEventStream`) so the caller routes to the streaming accumulator instead.
|
||||
|
||||
### SSE framing
|
||||
|
||||
`Scanner` is `bufio`-backed, 64 KiB read buffer, 1 MiB max line so a
|
||||
malicious upstream can't blow process memory
|
||||
([sse.go:33–38, 97–100](../../../proxy/internal/llm/sse.go)).
|
||||
`splitField` strips one space after the `:` per the SSE spec. Documented
|
||||
`not safe for concurrent use`; every consumer creates a fresh scanner per
|
||||
response body. Streaming accumulators live in the middleware package
|
||||
([llm_response_parser/streaming.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go))
|
||||
but use `llm.NewScanner` so the framing contract stays here.
|
||||
|
||||
### Pricing catalog
|
||||
|
||||
`Table.Cost`
|
||||
([pricing.go:129–174](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
is the cost formula — most security-relevant math in this module:
|
||||
|
||||
| Provider | Formula |
|
||||
|---|---|
|
||||
| `openai` | `(inTokens − clamped) × InputPer1K + clamped × CachedInputPer1K + outTokens × OutputPer1K` where `clamped = min(cachedInput, inTokens)` |
|
||||
| `anthropic`, `bedrock` | `inTokens × InputPer1K + cachedInput × CacheReadPer1K + cacheCreation × CacheCreationPer1K + outTokens × OutputPer1K` |
|
||||
| default | `inTokens × InputPer1K + outTokens × OutputPer1K` |
|
||||
|
||||
`bedrock` shares the Anthropic additive-cache formula
|
||||
([pricing.go:172-174](../../../proxy/internal/llm/pricing/pricing.go)):
|
||||
Anthropic-on-Bedrock reports the same additive cache buckets, while non-Anthropic
|
||||
Bedrock models (Nova, Llama) simply report zero in those buckets so cost reduces
|
||||
to `input + output`.
|
||||
|
||||
Each per-bucket rate falls back to `InputPer1K` when zero — operators opt in
|
||||
to discounts by setting the field.
|
||||
|
||||
`Loader`
|
||||
([pricing.go:212–268](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
overlays an optional `pricing.yaml` from data-dir on top of the go:embed
|
||||
defaults. Atomic pointer swap means readers never observe a partial update.
|
||||
The mtime-poll reloader (30s default cadence) keeps the previous table on
|
||||
parse failure so cost annotation never goes blank during a botched edit.
|
||||
|
||||
`defaults_pricing.yaml` is the source of truth for built-in pricing.
|
||||
Operator overrides only carry the entries they want to change.
|
||||
|
||||
## Public contracts
|
||||
|
||||
**`Parser` interface**
|
||||
([parser.go:50–66](../../../proxy/internal/llm/parser.go)):
|
||||
|
||||
```go
|
||||
type Parser interface {
|
||||
Provider() Provider
|
||||
ProviderName() string
|
||||
DetectFromURL(path string) bool
|
||||
ParseRequest(body []byte) (RequestFacts, error)
|
||||
ParseResponse(status int, contentType string, body []byte) (Usage, error)
|
||||
ExtractPrompt(body []byte) string
|
||||
ExtractCompletion(status int, contentType string, body []byte) string
|
||||
}
|
||||
```
|
||||
|
||||
Adding a provider means implementing this interface and appending to the
|
||||
slice returned by `Parsers()` ([parser.go:78–84](../../../proxy/internal/llm/parser.go)).
|
||||
Order matters: `DetectFromURL` ties resolve by registration order.
|
||||
`Parsers()` today returns `{OpenAIParser, AnthropicParser, BedrockParser}`.
|
||||
|
||||
**`Provider` enum**
|
||||
([parser.go:8–18](../../../proxy/internal/llm/parser.go)):
|
||||
`ProviderUnknown = 0`, `ProviderOpenAI = 1`, `ProviderAnthropic = 2`,
|
||||
`ProviderBedrock = 3`. Numeric values are persisted in nothing today but treat
|
||||
them as wire-stable — new providers must take fresh numbers.
|
||||
|
||||
**`Pricing` lookup**
|
||||
([pricing.go:129](../../../proxy/internal/llm/pricing/pricing.go)):
|
||||
|
||||
```go
|
||||
func (t *Table) Cost(provider, model string, inTokens, outTokens, cachedInput, cacheCreation int64) (float64, bool)
|
||||
```
|
||||
|
||||
Nil-safe: `t.Cost` on a nil receiver returns `(0, false)`
|
||||
([pricing.go:130–132](../../../proxy/internal/llm/pricing/pricing.go)).
|
||||
`ok=false` means provider or model is absent from the loaded table; the caller
|
||||
emits `cost.skipped=unknown_model`.
|
||||
|
||||
## Invariants
|
||||
|
||||
1. **Cross-platform pricing build.** `pricing_unix.go` carries the only
|
||||
functional `loadPricing` (uses `syscall.O_NOFOLLOW` and `f.Stat()` on an
|
||||
open descriptor — both Unix-only). `pricing_other.go` is a build-tag
|
||||
fallback that returns `"not supported on this platform"`
|
||||
([pricing_other.go:14–16](../../../proxy/internal/llm/pricing/pricing_other.go)).
|
||||
The proxy is Linux-only in production today; a Windows port needs an
|
||||
equivalent path-as-handle implementation. Reviewers building on Windows
|
||||
should expect this surface to return an error at startup if an override
|
||||
file is configured.
|
||||
|
||||
2. **SSE scanner handles partial chunks.** A buffered prefix that doesn't end
|
||||
in `\n\n` still yields its accumulated event before `io.EOF`
|
||||
([sse.go:55–58](../../../proxy/internal/llm/sse.go)). Tests:
|
||||
`TestSSEScanner_OpenAIFixture`, `TestSSEScanner_AnthropicFixture`,
|
||||
`TestSSEScanner_MultilineData`, `TestSSEScanner_CRLF`. The streaming
|
||||
accumulators ride on this: `accumulateAnthropicStream` and
|
||||
`accumulateOpenAIStream` `break` on any scanner error to return partial
|
||||
usage rather than aborting
|
||||
([streaming.go:68–73, 144–150](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go)).
|
||||
|
||||
3. **`defaults_pricing.yaml` is the source of truth.** Compiled into the
|
||||
binary via `//go:embed`
|
||||
([pricing.go:29–30](../../../proxy/internal/llm/pricing/pricing.go)).
|
||||
`DefaultTable()` parses once and panics on parse failure
|
||||
([pricing.go:42–49](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
— by design: a broken embedded YAML must not ship to production.
|
||||
|
||||
4. **Loader path validation.** `resolveMiddlewareDataPath`
|
||||
([pricing.go:370–394](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
rejects absolute paths, traversal segments, and basenames that fail
|
||||
`basenameRegex = ^[a-zA-Z0-9._-]+$`. The resolved path must remain
|
||||
inside `baseDir` even after `filepath.Clean`. Tests:
|
||||
`TestNewLoader_PathValidation`, `TestNewLoader_PathValidation_Extended`,
|
||||
`TestNewLoader_SymlinkOutsideBaseDirRejected`, `TestNewLoader_SymlinkRejected`.
|
||||
|
||||
5. **Unix loader symlink safety.** `O_NOFOLLOW` on open, `f.Stat()` on the
|
||||
open descriptor (never re-stat by path), `info.Mode().IsRegular()` check,
|
||||
`io.LimitReader(f, maxPricingBytes+1)` with a final size assertion
|
||||
([pricing_unix.go:25–57](../../../proxy/internal/llm/pricing/pricing_unix.go)).
|
||||
A mid-read symlink swap is detected because the fstat is on the original
|
||||
fd. Test: `TestNewLoader_RejectsOversizedFile_FixesM4`.
|
||||
|
||||
6. **`yaml.NewDecoder(...).KnownFields(true)`**
|
||||
([pricing.go:397–398](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
rejects YAML files that carry fields not in the schema. A typo in an
|
||||
operator override file fails loud instead of silently zeroing rates.
|
||||
|
||||
## Things to scrutinise
|
||||
|
||||
**Correctness.** Verify OpenAI cached-prompt clamp at
|
||||
[pricing.go:147–149](../../../proxy/internal/llm/pricing/pricing.go)
|
||||
short-circuits before subtraction. `Anthropic.TotalTokens` sums all four
|
||||
buckets (in + out + cache_read + cache_creation) — downstream dashboards
|
||||
need to know this differs from `input + output`.
|
||||
`OpenAIParser.ExtractPrompt` falls through `messages → input → prompt`; a
|
||||
request sending all three reports only `messages` (uncommon but worth
|
||||
noting).
|
||||
|
||||
**Security.** `Scanner.maxLine = 1 MiB`; a 2 MiB single-line `data:` event
|
||||
errors from `Scanner.Next` and both accumulators stop with partial usage.
|
||||
Pricing file 1 MiB cap is orders of magnitude larger than realistic. Confirm
|
||||
new schema additions are mirrored in both `pricingFile` and `Entry`;
|
||||
`KnownFields(true)` will reject silently-typo'd operator overrides
|
||||
otherwise.
|
||||
|
||||
**Concurrency.** `Loader.table` is `atomic.Pointer[Table]`; readers never
|
||||
block or see a torn table. `Loader.Reload` is one goroutine, cancelled via
|
||||
context (`TestLoader_ReloadBackgroundLoopCancellation`). `DefaultTable()`
|
||||
uses `sync.Once`. Per-call `Scanner` instances mean no shared state across
|
||||
concurrent response-parser calls.
|
||||
|
||||
**Perf.** `Table.Cost` is two map lookups + multiplications, O(1).
|
||||
`Scanner.Next` is one `ReadString('\n')` per line. Pricing reload poll 30s.
|
||||
|
||||
**Observability.** Reload failures count via `metric.Int64Counter` keyed
|
||||
`plugin`; warning log rate-limited at 5 min so a broken file doesn't flood.
|
||||
Parser errors return sentinels — middleware uses `errors.Is` to map to the
|
||||
right `cost.skipped` reason.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| File | Tests | Coverage highlights |
|
||||
|---|---:|---|
|
||||
| `parser_test.go` | 3 | `Parsers()` shape lock, `DetectParser` URL matrix, provider enum stability |
|
||||
| `openai_test.go` | 11 | Chat Completions + Responses API + legacy `prompt`; cached-tokens subset for both naming conventions; fixture replays |
|
||||
| `anthropic_test.go` | 7 | Messages + legacy `/v1/complete`; streaming REJECTED on `ParseResponse` (must use scanner); fixture replays |
|
||||
| `sse_test.go` | 12 | Fixture replay both providers; multiline `data:`; CRLF; comment skip; trailing-event-without-blank-line; oversize rejection |
|
||||
| `pricing/pricing_test.go` | 21 | Provider-shape switch; cached-rate fallback; cached-clamp; symlink rejection (target outside basedir + symlink to file); path validation matrix; oversize rejection; reload-keeps-previous-on-parse-error; mtime change detection; goroutine cancellation |
|
||||
|
||||
**Fixtures** ([proxy/internal/llm/fixtures/](../../../proxy/internal/llm/fixtures/)):
|
||||
`openai_chat_completion.json` (chat.completions with usage),
|
||||
`openai_responses.json` (Responses API shape),
|
||||
`openai_stream.txt` (3 deltas + usage + `[DONE]`),
|
||||
`anthropic_messages.json` (Messages API non-streaming),
|
||||
`anthropic_stream.txt` (full 7-event sequence: message_start →
|
||||
content_block_{start,delta×2,stop} → message_delta (usage) → message_stop),
|
||||
`pricing.yaml` (realistic-pricing starter for operator overrides).
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Sibling: [31-proxy-middleware-builtin.md](./31-proxy-middleware-builtin.md)
|
||||
— the chain that calls `llm.Parsers()`, `llm.ParserByName`,
|
||||
`llm.NewScanner`, `pricing.NewLoader`.
|
||||
- Path-routed providers (Vertex AI + Bedrock), credential syntax, and the
|
||||
Bedrock AWS event-stream accumulator:
|
||||
[50-path-routed-providers.md](./50-path-routed-providers.md).
|
||||
- Direct callers: `llm_request_parser/middleware.go:82–94`,
|
||||
`llm_response_parser/middleware.go:113–123`,
|
||||
`llm_response_parser/streaming.go:65, 142`, `cost_meter/factory.go:49–57`.
|
||||
- Related elsewhere: the agent-network synthesiser stamping `provider_id`
|
||||
is covered in the management-side module guide; proxy server boot +
|
||||
`FactoryContext` construction is covered in the proxy-framework guide.
|
||||
194
docs/agent-networks/modules/33-proxy-runtime.md
Normal file
194
docs/agent-networks/modules/33-proxy-runtime.md
Normal file
@@ -0,0 +1,194 @@
|
||||
# proxy/runtime — translate + serve + log
|
||||
|
||||
> **Risk level:** High — every config push from management is translated here, and the chain runs on every HTTP request to a synth target.
|
||||
> **Backward-compat impact:** Additive at the wire (`PathTargetOptions.middlewares`, `agent_network`, `disable_access_log`, capture caps) and on the proxy `Server` struct (`MiddlewareDataDir`, `MiddlewareCaptureBudgetBytes`). Non-agent-network targets stay on the no-middleware fast path.
|
||||
|
||||
## Module boundary
|
||||
|
||||
Turns the synth-service wire format from `ProxyService.SyncMappings`/`GetMappingUpdate` into in-process middleware chains and runs them on top of the existing `httputil.ReverseProxy`. Four concerns: (a) **translate** — `proto.MiddlewareConfig` → validated `middleware.Spec` (proxy/middleware_translate.go) + self-register the eight built-ins (proxy/middleware_register.go); (b) **boot + rebuild** — construct the `middleware.Manager`, share the OTel meter, install the live-service check, rebuild per-path chains on every `addMapping`/`modifyMapping` (proxy/server.go); (c) **serve** — resolve chain at request time, capture bodies under a global budget, invoke `RunRequest`/`RunResponse`/`RunTerminal`, render deny responses, apply `UpstreamRewrite` (proxy/internal/proxy/reverseproxy.go); (d) **log + tag** — emit access-log entries with the new `agent_network` flag, gate emission on `EnableLogCollection` via `DisableAccessLog` (proxy/internal/accesslog).
|
||||
|
||||
**Inert for non-agent-network targets**: nil or empty chain → existing fast path (reverseproxy.go:127-139); `SuppressAccessLog` defaults false so the access-log middleware emits unchanged.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| proxy/middleware_translate.go | proto→Spec translation; slot/failmode/timeout mapping; caps |
|
||||
| proxy/middleware_translate_test.go | translator unit tests |
|
||||
| proxy/middleware_register.go | blank-imports the eight builtins for `init()` registration |
|
||||
| proxy/server.go | `initMiddlewareManager`, `rebuildMiddlewareChains`, `isLiveService`, `buildMiddlewareBindings`, new Server fields, `protoToMapping` stamps AgentNetwork/DisableAccessLog/CaptureConfig/Middlewares |
|
||||
| proxy/internal/proxy/reverseproxy.go | `WithMiddlewareManager`, chain dispatch, body capture, `applyUpstreamRewrite`/`Headers`, `buildRequestInput`, response-leg respInput identity fields |
|
||||
| proxy/internal/proxy/reverseproxy_test.go | `TestBuildRequestInput_PropagatesIdentityAndGroups` |
|
||||
| proxy/internal/proxy/context.go | `agentNetwork`, `suppressAccessLog`, `userGroupNames` on `CapturedData` |
|
||||
| proxy/internal/proxy/servicemapping.go | new `PathTarget` fields |
|
||||
| proxy/internal/proxy/agent_network_chain_realstack_test.go | end-to-end self-contained chain test |
|
||||
| proxy/internal/accesslog/logger.go | `logEntry.AgentNetwork` → `proto.AccessLog` |
|
||||
| proxy/internal/accesslog/middleware.go | reads `GetAgentNetwork()`; gates `l.log` on `!GetSuppressAccessLog()` |
|
||||
| proxy/internal/accesslog/middleware_test.go | suppress/default/preserves-usage assertions |
|
||||
| proxy/internal/auth/middleware_test.go | tunnel-peer group propagation contract |
|
||||
| proxy/internal/metrics/metrics.go | `Meter()` getter for the middleware manager |
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### Synth-service ingestion → translate → register → serve
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Management SyncMappings/GetMappingUpdate] --> B["processMappings\nserver.go:1492"]
|
||||
B --> C{Mapping type}
|
||||
C -->|CREATED| D["addMapping → setupHTTPMapping → updateMapping"]
|
||||
C -->|MODIFIED| E["modifyMapping → cleanupMappingRoutes → setupHTTPMapping → updateMapping"]
|
||||
C -->|REMOVED| F["removeMapping → cleanupMappingRoutes → invalidateMiddlewareChains"]
|
||||
D --> G["protoToMapping\nserver.go:2181"]
|
||||
E --> G
|
||||
G --> H["translateMiddlewareConfigs\nmiddleware_translate.go:55"]
|
||||
G --> I["translateMiddlewareCaptureConfig\nmiddleware_translate.go:18"]
|
||||
H --> J["[]middleware.Spec on PathTarget"]
|
||||
I --> K["*bodytap.Config on PathTarget"]
|
||||
J --> L["proxy.AddMapping\nservicemapping.go:118"]
|
||||
K --> L
|
||||
L --> M["rebuildMiddlewareChains\nserver.go:2017 → Manager.Rebuild"]
|
||||
F --> N["Manager.Invalidate(serviceID)"]
|
||||
```
|
||||
|
||||
### Per-request lifecycle through the chain + accesslog
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
participant C as Client
|
||||
participant M as accesslog.Middleware
|
||||
participant A as auth.Middleware (Protect)
|
||||
participant RP as ReverseProxy.ServeHTTP
|
||||
participant CH as middleware.Chain
|
||||
participant U as Upstream
|
||||
C->>M: HTTP request
|
||||
M->>M: NewCapturedData(requestID), WithCapturedData(ctx)
|
||||
M->>A: next.ServeHTTP
|
||||
A->>A: Private → ValidateTunnelPeer → stamp UserID/Email/Groups/GroupNames/AuthMethod
|
||||
A->>RP: next.ServeHTTP
|
||||
RP->>RP: findTargetForRequest → targetResult
|
||||
RP->>RP: stamp ServiceID/AccountID/AgentNetwork/SuppressAccessLog on CapturedData
|
||||
RP->>RP: resolveChain via Manager.ChainFor
|
||||
alt chain == nil or Empty
|
||||
RP->>U: httputil.ReverseProxy.ServeHTTP (fast path)
|
||||
else chain non-empty
|
||||
RP->>RP: bodytap.CaptureRequest (global budget)
|
||||
RP->>CH: RunRequest
|
||||
CH-->>RP: denyOutput? requestMeta + upstreamRewrite
|
||||
alt deny
|
||||
RP->>C: RenderDenyResponse
|
||||
else allow
|
||||
RP->>RP: capturingWriter + applyUpstreamRewrite/Headers
|
||||
RP->>U: httputil.ReverseProxy.ServeHTTP(respWriter)
|
||||
U-->>RP: response
|
||||
RP->>CH: RunResponse (respInput carries UserGroups)
|
||||
RP->>CH: RunTerminal (merged request+response metadata)
|
||||
end
|
||||
end
|
||||
RP-->>M: handler returns
|
||||
M->>M: build logEntry incl. AgentNetwork
|
||||
alt SuppressAccessLog == true
|
||||
M->>M: skip l.log; still trackUsage
|
||||
else default
|
||||
M->>M: l.log → goroutine SendAccessLog
|
||||
end
|
||||
```
|
||||
|
||||
### EnableLogCollection suppression path
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
S["agentnetwork.Settings.EnableLogCollection"] --> B["synthesizer: target.DisableAccessLog = !EnableLogCollection"]
|
||||
B --> P["proto PathTargetOptions.disable_access_log (field 13)"]
|
||||
P --> T["protoToMapping reads GetDisableAccessLog()\nserver.go:2211"]
|
||||
T --> M["PathTarget.DisableAccessLog\nservicemapping.go:47"]
|
||||
M --> R["ServeHTTP: cd.SetSuppressAccessLog\nreverseproxy.go:106"]
|
||||
R --> G["accesslog middleware: if !GetSuppressAccessLog l.log\nmiddleware.go:95"]
|
||||
R --> U["trackUsage unconditional — bandwidth telemetry preserved"]
|
||||
```
|
||||
|
||||
**Ingestion** lands as a `ProxyMapping` batch on `handleSyncMappingsStream`/`handleMappingStream`. `processMappings` dispatches to `addMapping`/`modifyMapping`/`removeMapping`; HTTP goes `setupHTTPMapping → updateMapping → protoToMapping`. `protoToMapping` (server.go:2181) is the single translation surface that materialises `[]middleware.Spec`, `*bodytap.Config`, `AgentNetwork`, `DisableAccessLog` onto each `PathTarget`; `updateMapping` finishes with `s.proxy.AddMapping(m)` (atomic swap under `mappingsMux`) and `s.rebuildMiddlewareChains(svcID, m)`.
|
||||
|
||||
At **request time** the access-log middleware stamps `CapturedData`; the auth chain runs (Private services lift `peer_group_ids` from `ValidateTunnelPeer` — auth/middleware_test.go:322). `ReverseProxy.ServeHTTP` resolves the chain; nil or empty → original `httputil.ReverseProxy`, no body capture. When a chain matches, body is captured under the global budget, `RunRequest` produces an `UpstreamRewrite` (`llm_router` selects a provider, rewrites scheme/host/path, injects `Authorization`), and `RunResponse`+`RunTerminal` run after the upstream returns. The terminal slot sees the merged metadata bag — that's how `llm_limit_record` ships the consumption sample. The **access-log** addition: `logEntry.AgentNetwork` from `GetAgentNetwork()` onto `proto.AccessLog.AgentNetwork`; the gate at middleware.go:95 honors `EnableLogCollection`, skipping `l.log` but keeping `trackUsage` so bandwidth telemetry survives.
|
||||
|
||||
## Public contracts touched
|
||||
|
||||
- `proxy.Server.MiddlewareDataDir` (string) — base dir for file-backed middleware config (server.go:238-241).
|
||||
- `proxy.Server.MiddlewareCaptureBudgetBytes` (int64) — process-wide capture cap; defaults to 256 MiB (server.go:248-250).
|
||||
- `proxy/internal/proxy.WithMiddlewareManager(*middleware.Manager) Option` — new option on `NewReverseProxy`; nil keeps the fast path (reverseproxy.go:48-56).
|
||||
- `proxy/internal/proxy.PathTarget` adds `Middlewares`, `CaptureConfig`, `AgentNetwork`, `DisableAccessLog` (servicemapping.go:27-51), all zero-default.
|
||||
- `proxy/internal/proxy.CapturedData` adds `agentNetwork`, `suppressAccessLog`, `userGroupNames` behind `sync.RWMutex`; slices deep-copied (context.go:47-66, 183-258).
|
||||
- `accesslog.logEntry.AgentNetwork` + `proto.AccessLog.AgentNetwork` (logger.go:131, 268).
|
||||
- `metrics.Metrics.Meter()` exposes the OTel meter for the middleware manager (metrics.go:53-58).
|
||||
|
||||
## Invariants
|
||||
|
||||
- **Synth-service updates are live (no proxy restart).** Every `MODIFIED` flows through `modifyMapping → cleanupMappingRoutes` (invalidates chains) `→ setupHTTPMapping → updateMapping → rebuildMiddlewareChains`. **ProxyMapping.Private preservation:** the relevant logic lives in `management/internals/shared/grpc/proxy.go:shallowCloneMapping`, not this module, but it surfaces here — if a `MODIFIED` synth service arrives `private=false`, auth skips `ValidateTunnelPeer`, `CapturedData.UserGroups` stays empty, and `llm_router` denies with `llm_policy.no_authorised_provider` until a management restart re-pushes the snapshot. This module assumes `mapping.GetPrivate()` is correct on every batch.
|
||||
- **`EnableLogCollection=false` suppresses access-log writes but middleware still runs.** Gate is one `if !cd.GetSuppressAccessLog()` immediately around `l.log(entry)` (middleware.go:95); `trackUsage` runs below the gate. Locked by `TestMiddleware_SuppressAccessLog_PreservesUsageTracking` (middleware_test.go:139).
|
||||
- **`agent_network` flag on access-log entries is set when the chain processed the request.** Source `target.AgentNetwork`, stamped at reverseproxy.go:105, read at accesslog/middleware.go:86.
|
||||
- **auth → builtin group propagation.** `Protect` writes `UserGroups`/`UserGroupNames`; `buildRequestInput` (reverseproxy.go:333) copies them into `middleware.Input`. The response-leg `respInput` (reverseproxy.go:196-223) also carries `UserEmail`/`UserGroups`/`UserGroupNames` — `llm_limit_record` needs `UserGroups` to ship `group_ids` so management's group-targeted budget rules match (comment at reverseproxy.go:211-215).
|
||||
- **Empty chains stay on the fast path.** `ServeHTTP` skips body capture and the run sequence when `chain == nil || chain.Empty()` (reverseproxy.go:127).
|
||||
- **Self-registration is the only way a builtin reaches the registry.** `middleware_register.go` blank-imports each builtin; `init()` adds the factory to `mwbuiltin.DefaultRegistry()`. Missing it → translator drops the entry with a warn (translate.go:97).
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
- **Translate edge cases** — drops on nil cfg, empty ID, unknown ID, UNSPECIFIED slot; each logs one warn; volume bounded by `MaxMiddlewaresPerChain`.
|
||||
- **Re-translate without dropping in-flight requests** — `Manager.Rebuild` is the only call from `rebuildMiddlewareChains`. Reverse proxy reads `ChainFor` once per request (reverseproxy.go:327) and runs the captured `*Chain` for the whole request. Verify in module 30 that `Rebuild` swaps atomically.
|
||||
- **ProxyMapping.Private preservation** — enforced management-side in `shallowCloneMapping`. Proxy-side regression catches: `TestProtect_PrivateService_TunnelPeerGroupsPropagate` + the integration test.
|
||||
- **Body-capture cleanup** — `defer releaseBudget()` (reverseproxy.go:145) and `defer capturingWriter.Release()` (reverseproxy.go:180) must run on every return; confirm no future `return` lands between acquisition and defer.
|
||||
- **`applyUpstreamRewrite` clones the URL** — `cloned := *orig` value-copies `*url.URL`; safe because overwritten fields are strings, not slices/maps (reverseproxy.go:285-292).
|
||||
|
||||
### Security
|
||||
- **Translate validates every config** — registry membership rejects unknown IDs; UNSPECIFIED slot drops; ID-less drops; raw config copied (not aliased) at translate.go:109.
|
||||
- **`AuthHeader`/`StripHeaders` only reachable via `UpstreamRewrite`** — regular mutation surface goes through the framework denylist (`Authorization`/`Cookie` blocked); only the router middleware can replace `Authorization` (reverseproxy.go:296-304). Confirm in module 30 nothing outside the proxy-trusted path populates `UpstreamRewrite.AuthHeader`.
|
||||
- **`stampNetBirdIdentity` strips client-sent values first** (reverseproxy.go:742-743) — anti-spoof for `X-NetBird-User`/`X-NetBird-Groups`; control chars filtered; comma-bearing labels dropped (reverseproxy_test.go:1217/:1243/:1193).
|
||||
- **Auth → group propagation** — `auth/middleware_test.go:322` and `:366` cover the contract. If auth ever stops calling `ValidateTunnelPeer` for Private services, every agent-network request silently denies.
|
||||
|
||||
### Concurrency
|
||||
- **Chain replacement under in-flight requests** — `findTargetForRequest` takes `mappingsMux.RLock`; `AddMapping` writes. `resolveChain` calls `ChainFor` once; even if `Rebuild` swaps mid-request, in-flight requests keep running on the captured pointer.
|
||||
- **`CapturedData` mutation across slots** — accessors take `sync.RWMutex`; slices deep-copied on both Set and Get. Verify no caller mutates the returned slice expecting it to land back.
|
||||
- **`Manager.Invalidate` race** — `removeMapping` invalidates after `cleanupMappingRoutes`; mapping read happens before chain resolution, so requests before invalidate run captured chains; later ones fail `findTargetForRequest`.
|
||||
- **`Logger.log` goroutine** — `logSem` caps at `maxLogWorkers = 4096`; overflow → `dropped.Add(1)` + debug log. Middleware test uses a buffered channel and 150ms negative-assertion window — review whether 150ms holds on slow CI.
|
||||
|
||||
### Backward compatibility
|
||||
- **Non-agent-network services unaffected** — `protoToMapping` reads new fields only when `opts != nil`; defaults leave `Middlewares`/`CaptureConfig` nil → chain resolves nil → fast path. Existing `reverseproxy_test.go` (non-chain) still passes.
|
||||
- **`disable_access_log` is proto field 13, default false** — every existing target unset; gate is no-op. Locked by `TestMiddleware_SuppressAccessLog_DefaultEmitsLog` (middleware_test.go:104).
|
||||
- **`Server` additions optional** — 256 MiB default when `MiddlewareCaptureBudgetBytes ≤ 0` (server.go:1997-2000).
|
||||
|
||||
### Performance
|
||||
- **Translate cost per push** — O(n) with per-entry registry lookup and `config_json` copy; negligible vs. the upstream gRPC unmarshal.
|
||||
- **Empty-chain hot path** — one `ChainFor` map lookup + one `chain.Empty()` check; no allocation delta vs. pre-PR.
|
||||
- **Body capture buffer churn** — `bodytap.CaptureRequest` allocates `MaxRequestBytes` per chain-hitting request; `releaseBudget` ties allocation to the 256 MiB proxy-wide budget. Confirm in module 30 the budget is a hard cap.
|
||||
|
||||
### Observability
|
||||
- **Metrics** — `Metrics.Meter()` shared with `middleware.NewMetrics` (server.go:1990-1993) so middleware instruments land in the same prometheus exporter. No new metrics defined here.
|
||||
- **Access-log accuracy** — every entry carries `AgentNetwork`; terminal-slot metadata merged into `CapturedData.Metadata` (reverseproxy.go:238-241).
|
||||
- **Deny logs at `Infof`** (reverseproxy.go:170) — review whether `Info` is too noisy at high deny rates; consider Debug or rate-limit.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| proxy/middleware_translate_test.go | Empty/nil → nil; field preservation; unknown ID skip; nil registry permissive; timeout clamping; fail-mode + slot incl. UNSPECIFIED-drop; empty-ID drop; truncation above + at `MaxMiddlewaresPerChain` |
|
||||
| proxy/internal/proxy/reverseproxy_test.go | Rewrite host/headers/cookies/query; trusted proxy; path forwarding; classifyProxyError; X-NetBird-User/Groups anti-spoof + CSV-join + control-char/comma rejection + fallback-to-ID; `TestBuildRequestInput_PropagatesIdentityAndGroups` (UserGroups/Email/GroupNames/AgentNetwork reach `middleware.Input`) |
|
||||
| proxy/internal/proxy/agent_network_chain_realstack_test.go | **The end-to-end integration test.** Drives a real agent-network request through `ReverseProxy.ServeHTTP` with the chain the synthesizer produces, against an in-process management gRPC (bufconn) backed by a real sqlite store + real `agentnetwork.Manager`, plus an `httptest` upstream — no external infrastructure or real LLM. Guarantees: (1) response-leg `respInput` carries `UserGroups` so `llm_limit_record` ships non-empty `group_ids` and the admin-group consumption row increments; (2) `RedactPii=true` redacts both prompt and completion on captured metadata; (3) the full chain runs against a real management stack. **Line 189-211 inlines the proto→Spec mapping** instead of calling the proxy's private `translateMiddlewareConfig` — keep that inline mirror in sync with `proxy/middleware_translate.go` or the test silently diverges from production. |
|
||||
| proxy/internal/accesslog/middleware_test.go | `SuppressAccessLog=true` skips `SendAccessLog` (150ms negative wait); default emits one send (2s positive); usage tracking runs under suppression |
|
||||
| proxy/internal/auth/middleware_test.go | `TestProtect_PrivateService_TunnelPeerGroupsPropagate` proves `peer_group_ids` reach `CapturedData.UserGroups`; `TestProtect_PrivateService_TunnelPeerDenied` proves rejected peers 403 without reaching the handler |
|
||||
|
||||
The integration test runs in a few seconds with no external infrastructure — exercising the real synthesizer, `Manager.Rebuild`, `ServeHTTP` dispatch, and `llm_limit_record` writing a real consumption row through the real `agentnetwork.Manager` over real gRPC.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- **Translator does not validate `RawConfig` JSON** — factory's job at `New([]byte)`. Confirm in module 30 that a per-binding factory failure doesn't poison the rest of the chain.
|
||||
- **No throttle on management push rate** — every `MODIFIED` triggers `Manager.Rebuild`. Mitigation upstream.
|
||||
- **Streaming responses (SSE)** — body capture is streaming-aware, but response-leg middleware runs only after the response completes; long SSE streams delay `llm_limit_record` until close.
|
||||
- **OIDC-only path doesn't carry tunnel-peer groups** — agent-network synth services rely on the Private tunnel-peer path; JWT groups claim is the only carrier for non-Private OIDC.
|
||||
- **`agent_network` flag on L4 entries** not added; HTTP-only.
|
||||
- **`mw.capture.bypass_reason` metadata key** documented at reverseproxy.go:151,184; namespace this in module 30/31 to avoid collisions.
|
||||
|
||||
## Cross-references
|
||||
- Upstream: [shared/api](10-shared-api.md), [proxy/middleware-framework](30-proxy-middleware-framework.md), [proxy/middleware-builtin](31-proxy-middleware-builtin.md), [proxy/llm-parsers](32-proxy-llm-parsers.md)
|
||||
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level: [../00-overview.md](../00-overview.md)
|
||||
228
docs/agent-networks/modules/40-dashboard.md
Normal file
228
docs/agent-networks/modules/40-dashboard.md
Normal file
@@ -0,0 +1,228 @@
|
||||
# dashboard — UI for agent-networks
|
||||
|
||||
This module documents code that lives in the **dashboard repo** (under
|
||||
`src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`), not
|
||||
in this repo. It is co-located here so backend readers see the full picture.
|
||||
|
||||
> **Risk level:** Medium. The new surface is isolated under `src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`, but it also reshapes the sidebar, splits `/peers`, renames `reverse-proxy/clusters` → `self-hosted-proxies`, and overlays the Control Center graph. Regressions here would be cross-cutting.
|
||||
> **Backward-compat impact:** Additive on the API side. Breaking on URL/navigation: `/peers` redirects to `/peers/devices` (src/app/(dashboard)/peers/page.tsx:7-15), `/reverse-proxy/clusters` was renamed to `/reverse-proxy/self-hosted-proxies`, the sidebar lost Access Control / Networks / Reverse Proxy / DNS / standalone Guardrails / Consumption / Activity (Navigation.tsx:165-171 — routes still resolve via URL), and the standalone `/agent-network/{access-log,consumption,global-controls}` routes are gone in favor of `/agent-network/observability`.
|
||||
|
||||
## Module boundary
|
||||
|
||||
The dashboard is the only place an operator interacts with agent-networks: provider catalog, configured providers, policies, guardrails, account-level budget rules, account settings (collection / redaction toggles), per-request access log, and consumption rollups all render, paginate, and edit here. Data flows in via SWR (`useFetchApi`) keyed by REST URL. One big context provider (`src/modules/agent-network/AIProvidersProvider.tsx`) aggregates five resources (providers, policies, guardrails, budget rules, settings) plus the proxy access-log stream filtered to `agent_network=true`, and exposes `add* / update* / toggle* / delete*` mutators that call through `useApiCall` and re-`mutate()` SWR. Pages mount the provider once at the top and compose presentational tables and modals beneath. The control-center page additionally fetches `/agent-network/{providers,policies}` directly (control-center/page.tsx:123-130) to overlay graph nodes.
|
||||
|
||||
## What the UI delivers
|
||||
|
||||
- **AI Observability** page with four tabs: Access Logs, Budget Dashboard,
|
||||
Budget Settings, Log Settings (replaces the standalone access-log,
|
||||
consumption, and global-controls routes).
|
||||
- **Providers** page: provider catalog + connect/edit wizard with per-vendor
|
||||
copy (LiteLLM, Portkey, Bifrost, Cloudflare, Vercel, OpenRouter, custom).
|
||||
- **Policies** page: group → provider authorization with per-policy Limits
|
||||
(minute-granular windows) + guardrail attach.
|
||||
- **Guardrails** page: reusable model-allowlist + prompt-capture sets.
|
||||
- **Account controls**: Log Collection / Prompt Collection / Redact PII toggles.
|
||||
- **Budget rules**: account-level rules reusing the policy Limits UI.
|
||||
- **Control Center overlay**: provider + agent-policy nodes on the graph.
|
||||
- **Navigation + peers reshaping**: peers split into Devices / Agents,
|
||||
`reverse-proxy/clusters` renamed to `self-hosted-proxies`, sidebar
|
||||
repackaged for agent-network focus.
|
||||
|
||||
## Surface added
|
||||
|
||||
### New pages
|
||||
|
||||
| Route | Purpose | Backing module(s) |
|
||||
| ----- | ------- | ----------------- |
|
||||
| `/agent-network` | Redirect to `/agent-network/providers` | page.tsx:7-15 |
|
||||
| `/agent-network/providers` | List + connect providers; header surfaces per-account base URL | providers/page.tsx + AgentProvidersTable + AIProviderModal |
|
||||
| `/agent-network/policies` | Group → Provider authorization with per-policy Limits + Guardrail attach | policies/page.tsx + AgentPoliciesTable + AgentPolicyModal |
|
||||
| `/agent-network/guardrails` | Reusable guardrail sets (model allowlist + prompt capture) | guardrails/page.tsx + AgentGuardrailsTable + AgentGuardrailModal |
|
||||
| `/agent-network/observability` | Tabs: Access Logs / Budget Dashboard / Budget Settings / Log Settings | observability/page.tsx |
|
||||
| `/peers/devices`, `/peers/agents` | Split of `/peers`, shared via `PeersListView` keyed by `kind` | peers/{devices,agents}/page.tsx |
|
||||
| `/reverse-proxy/self-hosted-proxies` | Renamed from `clusters` | self-hosted-proxies/page.tsx |
|
||||
|
||||
Removed in favor of `/agent-network/observability`: `/agent-network/access-log`, `/agent-network/consumption`, `/agent-network/global-controls`.
|
||||
|
||||
### New modules under src/modules/agent-network
|
||||
|
||||
| File | Role |
|
||||
| ---- | ---- |
|
||||
| AIProvidersProvider.tsx (~1158 LOC) | Aggregates every agent-network resource via SWR; normalises snake↔camel; exposes mutators; holds wizard-open state |
|
||||
| AIProviderModal.tsx (~1268 LOC) | Connect / edit provider wizard with per-vendor copy (Bifrost, Portkey, LiteLLM, Cloudflare, Vercel, OpenRouter, custom) |
|
||||
| AIProviderLogo + useProviderCatalog | Catalog-driven brand swatch + SWR hook over `/agent-network/catalog/providers` |
|
||||
| AgentPoliciesTable + AgentPolicyModal + AgentPolicyGuardrailsTab + AgentPolicyLimitsTab | Policies; modal has 3 tabs (Rule, Limits, Guardrails) |
|
||||
| AgentGuardrailsTable + AgentGuardrailModal + AgentGuardrailBrowseModal + AgentGuardrailChecksCell | Guardrails CRUD + attach-from-policy |
|
||||
| AgentBudgetRulesTable + AgentBudgetRuleModal | Account-level budget rules; modal reuses AgentPolicyLimitsTab verbatim |
|
||||
| AgentAccountControlsCard | Three account-wide toggles (Log Collection / Prompt Collection / Redact PII) |
|
||||
| AgentAccessLogTable + AgentAccessLogExpandedRow | Access log on `/events/proxy?agent_network=true` |
|
||||
| AgentConsumptionPanel + AgentConsumptionTable | Token + cost panel: charts + counter table |
|
||||
| table/AgentProvidersTable + AgentProviderActionCell | Providers table + per-row actions |
|
||||
| data/mockData.ts | Domain types and a few residual `MOCK_*` constants (see scrutinize) |
|
||||
|
||||
### Touched non-agent-network areas
|
||||
|
||||
- **control-center**: agent-network overlay (provider + agent-policy nodes); removed the All Networks dropdown; hid the Networks tab in FlowSelector (FlowSelector.tsx:9-14 — enum value kept so `?tab=networks` still type-checks); wrapped `ControlCenterView` in `AIProvidersProvider` (page.tsx:73-83); `agentPolicyNode` clicks routed to a separate state slot (page.tsx:1871-1874). New node renderers: nodes/ProviderNode.tsx, nodes/AgentPolicyNode.tsx (registered at utils/nodes.ts:21-22).
|
||||
- **peers**: Split into Devices and Agents sub-routes; shared via `PeersListView` keyed by `kind` (PeersListView.tsx:24-95). New compact-toolbar `UserFilterSelector` (users/UserFilterSelector.tsx).
|
||||
- **reverse-proxy**: Folder rename `clusters/` → `self-hosted-proxies/`; deleted `ClustersFeaturesCell.tsx`, `ClusterTypeIndicator.tsx`; new ReverseProxyClusterTargetSelector for cluster target type; Private toggle on target modal; body-capture knobs removed; new ReverseProxyEventExpandedRow.
|
||||
- **events**: `ReverseProxyEventsUserCell` rewritten with user + peer fallback (ReverseProxyEventsUserCell.tsx:14-21), shared with the access-log table.
|
||||
- **navigation**: Full repackaging in Navigation.tsx — Agent Network items flattened (no collapsible parent), distinct icons per item; Access Control, Networks, Reverse Proxy, DNS, standalone Guardrails, Consumption, Activity removed (still URL-reachable, per lines 165-171).
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### Page → Provider → Table/Modal hierarchy
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
Nav[Navigation.tsx]
|
||||
Nav --> ProvidersPage[/agent-network/providers/]
|
||||
Nav --> PoliciesPage[/agent-network/policies/]
|
||||
Nav --> GuardrailsPage[/agent-network/guardrails/]
|
||||
Nav --> ObsPage[/agent-network/observability/]
|
||||
|
||||
ProvidersPage --> AIPP1[AIProvidersProvider]
|
||||
PoliciesPage --> AIPP2[AIProvidersProvider]
|
||||
GuardrailsPage --> AIPP3[AIProvidersProvider]
|
||||
ObsPage --> AIPP4[AIProvidersProvider]
|
||||
ObsPage -.wraps.-> GroupsProvider
|
||||
ObsPage -.wraps.-> PeersProvider
|
||||
|
||||
AIPP1 --> ProvTable[AgentProvidersTable]
|
||||
ProvTable --> ProvModal[AIProviderModal]
|
||||
AIPP2 --> PolTable[AgentPoliciesTable]
|
||||
PolTable --> PolModal[AgentPolicyModal]
|
||||
PolModal --> PolGuardTab[AgentPolicyGuardrailsTab]
|
||||
PolModal --> PolLimitsTab[AgentPolicyLimitsTab]
|
||||
PolGuardTab --> GuardBrowse[AgentGuardrailBrowseModal]
|
||||
PolGuardTab --> GuardModal[AgentGuardrailModal]
|
||||
AIPP3 --> GuardTable[AgentGuardrailsTable]
|
||||
GuardTable --> GuardModal
|
||||
AIPP4 --> Tabs[Tabs]
|
||||
Tabs --> AccessLog[AgentAccessLogTable]
|
||||
Tabs --> Consumption[AgentConsumptionPanel]
|
||||
Tabs --> BudgetRules[AgentBudgetRulesTable]
|
||||
Tabs --> AccountCtl[AgentAccountControlsCard]
|
||||
BudgetRules --> BudgetModal[AgentBudgetRuleModal]
|
||||
BudgetModal -.reuses.-> PolLimitsTab
|
||||
```
|
||||
|
||||
### AI Observability tab page
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
Page[AIObservabilityPage] --> RA[RestrictedAccess<br/>permission.services.read]
|
||||
RA --> GP[GroupsProvider]
|
||||
GP --> PP[PeersProvider]
|
||||
PP --> AIP[AIProvidersProvider]
|
||||
AIP --> Tabs[Tabs / TabsList]
|
||||
Tabs --> T1[Access Logs<br/>AgentAccessLogTable]
|
||||
Tabs --> T2[Budget Dashboard<br/>AgentConsumptionPanel]
|
||||
Tabs --> T3[Budget Settings<br/>AgentBudgetRulesTable]
|
||||
Tabs --> T4[Log Settings<br/>AgentAccountControlsCard]
|
||||
T1 -.GET.-> EP[/events/proxy?agent_network=true/]
|
||||
T2 -.GET poll 5s.-> CONS[/agent-network/consumption/]
|
||||
T3 -.GET/PUT.-> BR[/agent-network/budget-rules/]
|
||||
T4 -.GET/PUT.-> ST[/agent-network/settings/]
|
||||
```
|
||||
|
||||
### Data fetch path
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
Page[Page component] --> Prov[AIProvidersProvider]
|
||||
Prov -->|useFetchApi| SWR[(SWR cache<br/>key = URL)]
|
||||
SWR -.GET.-> P[/agent-network/providers/]
|
||||
SWR -.GET.-> POL[/agent-network/policies/]
|
||||
SWR -.GET.-> G[/agent-network/guardrails/]
|
||||
SWR -.GET.-> BR[/agent-network/budget-rules/]
|
||||
SWR -.GET ignoreError.-> ST[/agent-network/settings/]
|
||||
SWR -.GET.-> CAT[/agent-network/catalog/providers/]
|
||||
SWR -.GET pageSize=100.-> EVT[/events/proxy agent_network=true/]
|
||||
Prov --> Mut[useApiCall.post/put/del]
|
||||
Mut -.on success.-> MutateSWR[SWR mutate keys]
|
||||
Prov --> Children[Tables / Modals via useAIProviders]
|
||||
```
|
||||
|
||||
Every list view reaches management through SWR over `/api/agent-network/*`. The provider context maps snake-case payloads to camelCase domain types (`fromAPI`, `policyFromAPI`, `guardrailFromAPI`, `budgetRuleFromAPI`, `settingsFromAPI`, `accessLogFromAPI` — AIProvidersProvider.tsx:138-562) and back via matching `*ToRequest` adaptors. The access log piggy-backs on `/events/proxy` with `agent_network=true&page_size=100` (line 707-709) and decodes LLM-specific fields from per-event `metadata`. Group IDs on events are resolved to current names through the surrounding GroupsProvider catalog (lines 515-521, 717-731) — no extra round trip. Mutators run `*ToRequest`, await `useApiCall.post/put/del`, call SWR `mutate()`, then `notify`. Errors caught and surfaced via `notify` — no exceptions escape into render. The Connect Provider modal's open state lives in the provider itself (`isWizardOpen` at lines 732-735) so the providers-page empty-state CTA and the table's + button share one modal. Control-center re-fetches `/agent-network/{providers,policies}` directly on top of `AIProvidersProvider` — SWR de-dupes but the code path is harder to reason about.
|
||||
|
||||
## Public contracts consumed
|
||||
|
||||
- `GET/POST /api/agent-network/providers`, `PUT/DELETE /:id`
|
||||
- `GET/POST /api/agent-network/policies`, `PUT/DELETE /:id`
|
||||
- `GET/POST /api/agent-network/guardrails`, `PUT/DELETE /:id`
|
||||
- `GET/POST /api/agent-network/budget-rules`, `PUT/DELETE /:id`
|
||||
- `GET/PUT /api/agent-network/settings` (ignoreError-tolerant; 404 = not yet bootstrapped — auto-bootstrap on first provider create via `bootstrap_cluster` field — AIProvidersProvider.tsx:737-760)
|
||||
- `GET /api/agent-network/catalog/providers` (read-only declarative; backend owns vendor list, IDs, brand colors, models, extra_headers, identity_injection — useProviderCatalog.ts:6-95)
|
||||
- `GET /api/agent-network/consumption` (polled every 5s on Budget Dashboard — ConsumptionPanel.tsx:53,65-71)
|
||||
- `GET /api/events/proxy?agent_network=true&page_size=100` (shared with Proxy Events)
|
||||
- `permission?.services?.read` gates every agent-network route via RestrictedAccess.
|
||||
|
||||
`AIProviderId` is a closed union in dashboard types (data/mockData.ts:8-21) but the converter tolerates anything the backend ships — unknown ids fall through to `"custom"` (AIProvidersProvider.tsx:497-506). Catalog values are pure read-through: anything declared in `extra_headers` renders in the modal automatically, copy keyed by header name (`EXTRA_HEADER_UI` in AIProviderModal.tsx:61-89), labeled-fallback for unknown ones.
|
||||
|
||||
## Invariants
|
||||
|
||||
- Provider context wrap order on user-attribution pages: `GroupsProvider > PeersProvider > AIProvidersProvider` (observability/page.tsx:87-89). Reverse it and access-log group resolution silently drops names.
|
||||
- Every agent-network route checks `permission?.services?.read` via `RestrictedAccess` (observability/page.tsx:85, providers/page.tsx:184, policies/page.tsx:53, guardrails/page.tsx:55).
|
||||
- Modal `key={open ? 1 : 0}` pattern is used to force unmount/remount on close so internal `useState` resets between edits (AgentBudgetRuleModal.tsx:60, AgentPolicyModal.tsx:66). Removing this would leak prior-row state into a new-row session.
|
||||
- `mockData.ts` is the canonical home for ALL agent-network domain types; `MOCK_*` constants must never reach a production code path. One leak remains (below).
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
|
||||
- **Tab-state URL hand-off is one-way.** observability/page.tsx:53-58 reads `?tab=` on mount (despite the file comment at line 28 saying URL hand-off is future) but `setTab` does NOT push back, so reload preserves the chosen tab only if it came in via the link. Inconsistent with control-center (page.tsx:1817-1831).
|
||||
- **Provider overlay runs only in `applySingleGroupView` / `applyPeerView`** (control-center/page.tsx:557, 1159-1166). User view does NOT show providers — if agent-network is a primary lens, that's a gap.
|
||||
- **Two useEffects race to invalidate the control-center layout.** page.tsx:1655-1657 drops `layoutInitialized` when `agentPolicies` / `agentProviders` arrive; the main effect (1786-1799) also lists them as deps. Functional but fragile — watch for flash-of-empty-graph.
|
||||
- **`updateProvider` / `updatePolicy` / `updateBudgetRule` use `??` on `enabled`** (AIProvidersProvider.tsx:784, 859, 1018). Toggle paths are safe; any caller sending `enabled: false` thinking "leave it off" gets `existing.enabled` instead. Audit modal callers.
|
||||
- **Form validation in modals is minimal.** Window-seconds picker — mockData.ts:209-215 documents "minimum 60 — one minute" but there is no matching UI guard in PolicyLimitsTab; the backend validator is the enforcement point.
|
||||
|
||||
### Security
|
||||
|
||||
- **No client-side enforcement claims** — every cap, allowlist, and toggle is display + edit; proxy is the source of truth for deny decisions (AccessLogTable.tsx:177-191 renders backend-emitted `denyReason` as-is).
|
||||
- **Prompt display is gated by what the backend stamps.** When `enable_prompt_collection` is OFF the proxy must not put prompt/completion into event metadata; the dashboard renders whatever it gets verbatim (AccessLogTable lines 532-534, AccessLogExpandedRow.tsx:42-57). No UI filter on top of backend collection switches.
|
||||
- Account Controls disables `Redact PII` when `Prompt Collection` is off (AgentAccountControlsCard.tsx:122) and clears it on off-transition (line 100), but relies on backend to enforce the same gate at write — confirm PUT handler rejects `redact_pii=true && enable_prompt_collection=false`.
|
||||
- **Bifrost identity-header overrides**: empty-string vs nil semantics documented in AIProvidersProvider.tsx:772-781 ("omitted = preserve, empty = explicit clear"). Mishandling could leak group attribution to a header the operator thought disabled. Focused read of Bifrost code path in AIProviderModal.tsx recommended.
|
||||
|
||||
### Accessibility
|
||||
|
||||
- Observability TabsList (observability/page.tsx:96-113) uses the shared Tabs component — should inherit Radix roving-tabindex. All four TabsTriggers carry only icon + text, no `aria-label`; fine because text is visible.
|
||||
- Modal focus traps are inherited from the shared Modal; agent-network modals don't override them. Quick keyboard pass recommended.
|
||||
- `EndpointBadge` Copy button (providers/page.tsx:66-76) has an `aria-label`, good.
|
||||
|
||||
### Performance
|
||||
|
||||
- `AgentConsumptionPanel` polls `/agent-network/consumption` every 5s (ConsumptionPanel.tsx:53,70). Tab switches unmount the panel, so the poll stops — verify in network panel.
|
||||
- `AgentAccessLogTable` is hard-capped at 100 rows via `page_size=100` (AIProvidersProvider.tsx:707-709). Server-side pagination is future work; high-traffic tenants miss everything past row 100 — known limitation.
|
||||
- Observability page mounts providers ONCE at page level (observability/page.tsx:87-89); tab switches keep SWR cache hot. Moving the provider mount inside `TabsContent` would re-fetch the access log on every switch.
|
||||
|
||||
### Visual consistency
|
||||
|
||||
- The observability tab style mirrors peers/page.tsx. Outer Tabs `pt-4 pb-0 mb-0`, TabsList `px-8` (observability/page.tsx:94-96) — confirm chrome height matches so the page doesn't visually jump.
|
||||
- Sidebar: `Boxes` for Providers, `AccessControlIcon` for Policies, `TelescopeIcon` for AI Observability (Navigation.tsx:113,120,133). Reusing `AccessControlIcon` makes Policies look identical to the (now hidden) Access Control item — if Access Control ever comes back, they collide.
|
||||
- `AgentNetworkIcon` is used in breadcrumbs on every agent-network page but NOT in the sidebar (per-page icons instead). Deliberate departure — record so it doesn't get reverted.
|
||||
|
||||
## Test coverage
|
||||
|
||||
- **Cypress**: One file (`cypress/e2e/test.cy.ts`) covering only the install-page copy-to-clipboard flow. NOTHING covers agent-network UI.
|
||||
- **Component / unit tests**: `src/utils/version.test.ts` is the only `.test.*` file in the repo. The agent-network modules ship without component tests.
|
||||
- Data-cy hooks exist on key controls: `save-account-controls` (AgentAccountControlsCard.tsx:71), `enable-log-collection`, `enable-prompt-collection`, `redact-pii`, plus existing `data-cy={policy.name}` / `data-cy={provider.name}` on ActiveInactiveRow. Sufficient hooks for Cypress flows; none written yet.
|
||||
- **Tooling gap (pre-existing):** `npm run lint` (`next lint`) is broken in Next 16 — the `lint` subcommand was removed from the Next CLI in 16.x, so the dashboard effectively has no working lint gate. The fix is to add either a flat-config `eslint .` script or wire ESLint via an explicit `eslint-config-next` invocation.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- **`data/mockData.ts` still contains `MOCK_GROUPS`, `MOCK_PROVIDERS`, `MOCK_PEERS`.** Only `MOCK_GROUPS` is referenced from production — AgentPoliciesTable.tsx:45,76 uses it as a name-lookup fallback when a policy references a group ID the real GroupsProvider doesn't know about. `MOCK_PROVIDERS` / `MOCK_PEERS` are unreferenced; safe to delete. The file is `/* eslint-disable */` so dead-code warnings don't flag them.
|
||||
- **Tab-state URL hand-off on observability page is one-way** (read-only).
|
||||
- **Access log hard-capped at 100 rows**; no server-side pagination.
|
||||
- **No optimistic updates.** All mutations are round-trip; failures rollback via SWR revalidation.
|
||||
- **`FlowView.NETWORKS` retained but hidden** from FlowSelector (FlowSelector.tsx:9-14). Old `?tab=networks` links still route to the hidden view because `applyNetworksView` still runs.
|
||||
- **Redirects are not query-preserving** — `router.replace("/peers/devices")` (peers/page.tsx:13) strips any incoming filter params.
|
||||
- **Control-center cross-fetches** `/agent-network/{providers,policies}` directly on top of `AIProvidersProvider`. Could be collapsed.
|
||||
- **Sidebar permanently hides Access Control, Networks, Reverse Proxy, standalone Guardrails, DNS, Activity, Consumption.** Routes still resolve via URL (Navigation.tsx:165-171); intentional.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Upstream API contracts: [shared/api](10-shared-api.md)
|
||||
- Backend persistence: [management/store](20-management-store.md)
|
||||
- Backend handler wiring: [management/handlers + wiring](22-management-handlers-wiring.md)
|
||||
- End-to-end flow narrative: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level overview: [../00-overview.md](../00-overview.md)
|
||||
251
docs/agent-networks/modules/50-path-routed-providers.md
Normal file
251
docs/agent-networks/modules/50-path-routed-providers.md
Normal file
@@ -0,0 +1,251 @@
|
||||
# path-routed providers — Vertex AI + Bedrock
|
||||
|
||||
This guide pulls the **path-routed** provider story together in one place
|
||||
because it crosses the catalog, the synthesiser, the request parser, and the
|
||||
router. The relevant building blocks are the `llm_router` /
|
||||
`llm_request_parser` middlewares
|
||||
([31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md)), the
|
||||
per-provider parser surface ([32-proxy-llm-parsers.md](32-proxy-llm-parsers.md)),
|
||||
and the synthesiser's catalog → `ProviderRoute` mapping
|
||||
([21-management-agentnetwork.md](21-management-agentnetwork.md)).
|
||||
|
||||
Sibling modules: [31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md)
|
||||
(router + request parser) and [32-proxy-llm-parsers.md](32-proxy-llm-parsers.md)
|
||||
(Bedrock parser + pricing).
|
||||
|
||||
---
|
||||
|
||||
## What "path-routed" means
|
||||
|
||||
Most catalog providers carry the model in the request **body** (`{"model": …}`),
|
||||
so `llm_router` selects an upstream by matching the model name against each
|
||||
provider's `Models` claim. Two providers instead carry the model in the **URL
|
||||
path**, so they are routed by path before the model/vendor table is consulted:
|
||||
|
||||
| Catalog id | Style flag | Request path shape |
|
||||
|---|---|---|
|
||||
| `vertex_ai_api` | `IsVertexPathStyle` → `ProviderRoute.Vertex` | `/v1/projects/{project}/locations/{region}/publishers/{publisher}/models/{model}:{action}` |
|
||||
| `bedrock_api` | `IsBedrockPathStyle` → `ProviderRoute.Bedrock` | `/model/{modelId}/{action}` (optionally behind `/bedrock`) |
|
||||
|
||||
The catalog declares the style with
|
||||
[`catalog.IsVertexPathStyle` / `catalog.IsBedrockPathStyle`](../../../management/server/agentnetwork/catalog/catalog.go)
|
||||
and the synthesiser copies the result onto the router route as the `Vertex` /
|
||||
`Bedrock` booleans
|
||||
([synthesizer.go:450-451](../../../management/server/agentnetwork/synthesizer.go)).
|
||||
On the request leg `llm_router.Invoke` dispatches `isVertexPath` / `isBedrockPath`
|
||||
**before** the model lookup
|
||||
([llm_router/middleware.go:138-216](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
so a model the parser extracted from the path can't be claimed by a same-vendor
|
||||
*body-routed* provider (e.g. `claude-*` on `api.anthropic.com`).
|
||||
|
||||
## Google Vertex AI (`vertex_ai_api`)
|
||||
|
||||
### Catalog entry
|
||||
|
||||
`KindProvider`, parser surface left unset on the catalog entry — the request
|
||||
parser picks the parser from the URL **publisher** segment, not from
|
||||
`ParserID`. Upstream host is `<region>-aiplatform.googleapis.com`
|
||||
(`https://aiplatform.googleapis.com` for the `global` location). The catalog
|
||||
lists the Claude-on-Vertex lineup (`claude-opus-4-*`, `claude-sonnet-4-*`,
|
||||
`claude-haiku-4-5`, `claude-fable-5`) at the same per-token rates as the
|
||||
first-party Anthropic entry
|
||||
([catalog.go:333-363](../../../management/server/agentnetwork/catalog/catalog.go)).
|
||||
|
||||
### Credential — service-account OAuth (`keyfile::`)
|
||||
|
||||
Vertex does **not** accept a static API key. The operator sets the provider
|
||||
`api_key` to:
|
||||
|
||||
```
|
||||
keyfile::<base64 of the GCP service-account JSON key>
|
||||
```
|
||||
|
||||
The synthesiser recognises the `keyfile::` prefix in `providerAuthHeader`
|
||||
([synthesizer.go:897-903](../../../management/server/agentnetwork/synthesizer.go)),
|
||||
emits **no** static auth value, and carries the base64 key material on the
|
||||
route as `GCPServiceAccountKeyB64`
|
||||
([factory.go:56-61](../../../proxy/internal/middleware/builtin/llm_router/factory.go)).
|
||||
At request time the router mints a short-lived OAuth2 access token from the key
|
||||
(cloud-platform scope) and injects `Authorization: Bearer <access-token>` —
|
||||
never the key itself
|
||||
([llm_router/middleware.go:621-692](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)):
|
||||
|
||||
- One auto-refreshing `oauth2.TokenSource` is cached per key (keyed by a
|
||||
SHA-256 of the base64 material), so token minting happens once and refreshes
|
||||
amortise across requests.
|
||||
- Mint / refresh is bounded by a 10s timeout HTTP client (`gcpTokenTimeout`) so
|
||||
a slow Google token endpoint can't hang the request.
|
||||
- A malformed key or an unreachable token endpoint fails the request with
|
||||
`llm_policy.upstream_auth_failed` at HTTP **502** (an upstream problem, not a
|
||||
policy denial) — see `denyUpstreamAuth`.
|
||||
|
||||
### Metering — Anthropic-on-Vertex only
|
||||
|
||||
The request parser extracts `{publisher, model, action}` from the path
|
||||
(`parseVertexPath`, [llm_request_parser/middleware.go:237-263](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)),
|
||||
strips the `@version` suffix from the model, and maps the publisher to a parser
|
||||
surface via `vertexPublisherVendor`:
|
||||
|
||||
- `anthropic` → `llm.provider="anthropic"` → metered through the Anthropic
|
||||
parser, priced under the **`anthropic`** block in `defaults_pricing.yaml`
|
||||
(the parser emits the standard Anthropic provider label, so Vertex Claude
|
||||
reuses first-party Anthropic prices).
|
||||
- `openai` → `llm.provider="openai"` (reserved; not in the catalog lineup
|
||||
today).
|
||||
- anything else (notably `google` / Gemini) → empty vendor → **no parser**.
|
||||
|
||||
**Gemini is intentionally denied as unmeterable.** When the parser emits no
|
||||
`llm.provider` for a Vertex publisher, `llm_router` returns
|
||||
`llm_policy.unmeterable_publisher` (403) rather than forwarding the request
|
||||
uncounted — serving it would bypass token / budget metering
|
||||
([llm_router/middleware.go:144-162, 712-728](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)).
|
||||
A Gemini parser would lift this restriction; until then the `google` publisher
|
||||
is omitted from the catalog.
|
||||
|
||||
> Caveat: cross-region inference profiles in `eu` / `apac` carry a ~10% price
|
||||
> premium that the base per-token rates do **not** model — cost annotations for
|
||||
> those regions read low. Operators who need exact regional billing override
|
||||
> the affected entries in `pricing.yaml`.
|
||||
|
||||
## AWS Bedrock (`bedrock_api`)
|
||||
|
||||
### Catalog entry
|
||||
|
||||
`KindProvider`, upstream host `bedrock-runtime.<region>.amazonaws.com`. Metered
|
||||
models are the Anthropic-on-Bedrock lineup (`anthropic.claude-*`) plus Amazon
|
||||
Nova and Llama 3.3 entries
|
||||
([catalog.go:300-332](../../../management/server/agentnetwork/catalog/catalog.go)).
|
||||
Anthropic-on-Bedrock reuses the first-party Claude prices (with additive cache
|
||||
buckets); Nova / Llama report no cache, so cost is `input + output`.
|
||||
|
||||
### Credential — static bearer token
|
||||
|
||||
Bedrock uses the **AWS Bedrock API key** as a static bearer. The operator sets
|
||||
the provider `api_key` directly (no `keyfile::` prefix); the catalog template
|
||||
is `Authorization: Bearer ${API_KEY}`
|
||||
([catalog.go:306-307](../../../management/server/agentnetwork/catalog/catalog.go)).
|
||||
No token minting — the synthesiser substitutes the key into the template and
|
||||
the router injects the resulting `Authorization` header after stripping inbound
|
||||
vendor auth (including client-supplied AWS SigV4 material: `X-Amz-Date`,
|
||||
`X-Amz-Security-Token`, `X-Amz-Content-Sha256`, see `strippedAuthHeaders`).
|
||||
|
||||
### Model id form — cross-region inference profiles
|
||||
|
||||
Bedrock model ids in the request path must be the cross-region
|
||||
**inference-profile** form, e.g.
|
||||
`eu.anthropic.claude-sonnet-4-5-20250929-v1:0`. The bare
|
||||
`anthropic.claude-…` id is rejected by AWS. `normalizeBedrockModel`
|
||||
([llm_request_parser/middleware.go:398-414](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go))
|
||||
strips the region prefix (`us.` / `eu.` / `apac.` / `global.`), an optional ARN
|
||||
wrapper, and the `-YYYYMMDD-vN[:N]` version/throughput suffix so the normalised
|
||||
id (`anthropic.claude-sonnet-4-5`) matches the catalog/pricing key.
|
||||
|
||||
### Supported endpoints + actions
|
||||
|
||||
`/model/{modelId}/{action}` where action ∈ `invoke`,
|
||||
`invoke-with-response-stream`, `converse`, `converse-stream`
|
||||
([llm_request_parser/middleware.go:363-390](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)).
|
||||
`invoke` / `converse` are non-streaming; the `-stream` actions set the streaming
|
||||
flag.
|
||||
|
||||
- **InvokeModel** body uses the vendor-native shape — for Anthropic that means
|
||||
`"anthropic_version":"bedrock-2023-05-31"` and snake_case usage with additive
|
||||
cache buckets.
|
||||
- **Converse** uses the unified camelCase shape with a precomputed `totalTokens`.
|
||||
- The `BedrockParser` reads both shapes on the response leg
|
||||
([bedrock.go](../../../proxy/internal/llm/bedrock.go)); the request parser
|
||||
doesn't need to distinguish them (`ParseRequest` is a no-op — model + stream
|
||||
come from the path).
|
||||
|
||||
### Streaming — AWS binary event-stream
|
||||
|
||||
The `-stream` actions return `application/vnd.amazon.eventstream` (the AWS
|
||||
binary event-stream framing), and streaming **is metered**.
|
||||
`accumulateBedrockStream`
|
||||
([llm_response_parser/streaming_bedrock.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go))
|
||||
decodes the frames with `aws-sdk-go-v2/aws/protocol/eventstream`:
|
||||
|
||||
- InvokeModel `chunk` frames wrap a base64 `{"bytes":…}` payload carrying a
|
||||
vendor-native (Anthropic) stream event — folded through the shared Anthropic
|
||||
stream accumulator.
|
||||
- Converse `contentBlockDelta` frames carry text; the trailing `metadata` frame
|
||||
carries the final usage block.
|
||||
- A truncated stream (cut at the body-tap capture cap) decodes best-effort:
|
||||
frames up to the cut are applied and partial usage is returned.
|
||||
|
||||
### Optional `/bedrock` gateway-namespace prefix
|
||||
|
||||
Clients may place an optional `/bedrock` prefix before the native path
|
||||
(`/bedrock/model/{modelId}/{action}`) to disambiguate Bedrock from other
|
||||
providers that also use `/model/...`. Both the request parser
|
||||
(`trimBedrockNamespace`) and the router (`splitBedrockNamespace`) accept it.
|
||||
When the prefix is present, the router sets
|
||||
`RewriteUpstream.StripPathPrefix = "/bedrock"` so the **native** path
|
||||
(`/model/...`) is what reaches `bedrock-runtime.<region>.amazonaws.com`
|
||||
([llm_router/middleware.go:168-184, 320-348](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)).
|
||||
|
||||
## Model allowlist on path-routed providers
|
||||
|
||||
Because the model lives in the URL rather than the body, a path-routed provider
|
||||
credential could otherwise be used for any model the upstream supports. The
|
||||
router still enforces the route's `Models` allowlist via `matchPathRoute`
|
||||
([llm_router/middleware.go:370-416](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)):
|
||||
|
||||
1. Filter to routes of the matching style (`Vertex` / `Bedrock`).
|
||||
2. Filter to routes whose `AllowedGroupIDs` authorise the caller's groups
|
||||
(else `no_authorised_provider`).
|
||||
3. Filter to routes that **claim the requested model**. As with body-routed
|
||||
providers, an **empty `Models` list = catch-all** (serve any model);
|
||||
a non-empty list serves only the listed models (else `model_not_routable`).
|
||||
4. Multiple survivors disambiguate by longest `UpstreamPath` prefix match.
|
||||
|
||||
So an operator who lists explicit models on a Vertex/Bedrock provider gets a
|
||||
hard allowlist; an operator who leaves `Models` empty accepts every model the
|
||||
upstream serves (still subject to the unmeterable-publisher gate on Vertex).
|
||||
|
||||
Model-less OpenAI endpoints (`GET /v1/models`) are **never** routed to a
|
||||
Vertex/Bedrock provider — `matchModelless` skips path-routed routes
|
||||
([llm_router/middleware.go:427-462](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
so a model-listing call can't be rewritten onto an upstream that would 404 it.
|
||||
|
||||
## Catalog ↔ pricing cross-check
|
||||
|
||||
Catalog prices and context windows are cross-checked against LiteLLM's
|
||||
`model_prices_and_context_window.json`. The proxy's embedded
|
||||
`defaults_pricing.yaml` covers **every metered first-party model** the catalog
|
||||
enumerates — guarded by
|
||||
`TestDefaultTable_FirstPartyModelCoverage`
|
||||
([pricing/defaults_coverage_test.go](../../../proxy/internal/llm/pricing/defaults_coverage_test.go)),
|
||||
which fails if a catalog model has no embedded price. Bedrock entries are keyed
|
||||
by the **normalised** id the request parser emits (region prefix + version
|
||||
suffix stripped). Vertex Claude carries no Bedrock-style prefix, so it prices
|
||||
straight off the `anthropic` block.
|
||||
|
||||
## Things to scrutinise
|
||||
|
||||
**Security.** The Vertex service-account key is never forwarded — only a minted
|
||||
short-lived bearer. Confirm the key material stays out of access logs (it lives
|
||||
on `ProviderRoute.GCPServiceAccountKeyB64`, not in any emitted metadata key).
|
||||
The unmeterable-publisher deny is the only thing standing between an
|
||||
operator-misconfigured Vertex provider and unmetered Gemini traffic; verify
|
||||
`vertexPublisherVendor` stays conservative (deny by default for unknown
|
||||
publishers).
|
||||
|
||||
**Correctness.** `normalizeBedrockModel` is the join between the wire id and the
|
||||
pricing key — a model that normalises to something not in `defaults_pricing.yaml`
|
||||
meters at `cost.skipped=unknown_model` rather than failing the request. The
|
||||
`/bedrock` prefix strip must run on both the parser side (so the model is
|
||||
extracted) and the router side (so the upstream path is native); a regression in
|
||||
either silently breaks the other.
|
||||
|
||||
**Metering caveats.** eu/apac cross-region Bedrock + Vertex profiles carry a
|
||||
~10% premium not modelled by base pricing — flagged in both the catalog comment
|
||||
and `defaults_pricing.yaml`. Operators needing exact regional billing override
|
||||
the relevant entries.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Router + request-parser detail: [31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md)
|
||||
- Bedrock parser + pricing + SSE / event-stream: [32-proxy-llm-parsers.md](32-proxy-llm-parsers.md)
|
||||
- Catalog → route synthesis + `keyfile::` handling: [21-management-agentnetwork.md](21-management-agentnetwork.md)
|
||||
- Overview: [../00-overview.md](../00-overview.md)
|
||||
2
go.mod
2
go.mod
@@ -35,6 +35,7 @@ require (
|
||||
github.com/DeRuina/timberjack v1.4.2
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.38.3
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1
|
||||
github.com/aws/aws-sdk-go-v2/config v1.31.6
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.18.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
|
||||
@@ -156,7 +157,6 @@ require (
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
|
||||
github.com/awnumar/memcall v0.4.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6 // indirect
|
||||
|
||||
616
infrastructure_files/getting-started-enterprise.sh
Executable file
616
infrastructure_files/getting-started-enterprise.sh
Executable file
@@ -0,0 +1,616 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# NetBird Enterprise — Getting Started
|
||||
# Single-node bootstrap for a self-hosted NetBird Enterprise stack with the
|
||||
# embedded identity provider. Owner is created via first-login flow.
|
||||
|
||||
SED_STRIP_PADDING='s/=//g'
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null; then
|
||||
echo "docker-compose"
|
||||
return
|
||||
fi
|
||||
if docker compose --help &> /dev/null; then
|
||||
echo "docker compose"
|
||||
return
|
||||
fi
|
||||
echo "docker-compose is not installed or not in PATH. See https://docs.docker.com/engine/install/" > /dev/stderr
|
||||
exit 1
|
||||
}
|
||||
|
||||
check_openssl() {
|
||||
if ! command -v openssl &> /dev/null; then
|
||||
echo "openssl is not installed or not in PATH." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
rand_secret() {
|
||||
openssl rand -base64 32 | sed "$SED_STRIP_PADDING"
|
||||
}
|
||||
|
||||
rand_b64_key() {
|
||||
openssl rand -base64 32
|
||||
}
|
||||
|
||||
check_nb_domain() {
|
||||
local domain="$1"
|
||||
if [[ -z "$domain" ]]; then
|
||||
echo "The domain cannot be empty." > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
if [[ "$domain" == "netbird.example.com" ]]; then
|
||||
echo "The domain cannot be netbird.example.com" > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
if [[ "$domain" =~ ^[0-9.]+$ ]]; then
|
||||
echo "An IP address is not allowed. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
if [[ ! "$domain" =~ ^[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?(\.[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)+$ ]]; then
|
||||
echo "The value '$domain' is not a valid FQDN. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
check_domain_resolves() {
|
||||
local domain="$1"
|
||||
if command -v getent &> /dev/null && getent hosts "$domain" &> /dev/null; then return 0; fi
|
||||
if command -v host &> /dev/null && host "$domain" &> /dev/null; then return 0; fi
|
||||
if command -v dig &> /dev/null && [[ -n "$(dig +short "$domain" 2>/dev/null)" ]]; then return 0; fi
|
||||
if command -v nslookup &> /dev/null && nslookup "$domain" &> /dev/null; then return 0; fi
|
||||
return 1
|
||||
}
|
||||
|
||||
read_nb_domain() {
|
||||
local value=""
|
||||
echo -n "Enter the FQDN for NetBird (must resolve via DNS, e.g. netbird.my-domain.com): " > /dev/stderr
|
||||
read -r value < /dev/tty
|
||||
if ! check_nb_domain "$value"; then
|
||||
read_nb_domain
|
||||
return
|
||||
fi
|
||||
if ! check_domain_resolves "$value"; then
|
||||
echo "" > /dev/stderr
|
||||
echo "Warning: '$value' does not resolve via DNS from this host." > /dev/stderr
|
||||
echo "Caddy will not be able to issue TLS certificates until it does." > /dev/stderr
|
||||
local confirm=""
|
||||
echo -n "Continue anyway? [y/N]: " > /dev/stderr
|
||||
read -r confirm < /dev/tty
|
||||
if [[ ! "$confirm" =~ ^[Yy]$ ]]; then
|
||||
read_nb_domain
|
||||
return
|
||||
fi
|
||||
fi
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
read_required() {
|
||||
local prompt="$1"
|
||||
local value=""
|
||||
while [[ -z "$value" ]]; do
|
||||
echo -n "$prompt: " > /dev/stderr
|
||||
read -r value < /dev/tty
|
||||
if [[ -z "$value" ]]; then
|
||||
echo "Value cannot be empty." > /dev/stderr
|
||||
fi
|
||||
done
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
read_secret() {
|
||||
local prompt="$1"
|
||||
local value=""
|
||||
while [[ -z "$value" ]]; do
|
||||
echo -n "$prompt: " > /dev/stderr
|
||||
read -rs value < /dev/tty
|
||||
echo "" > /dev/stderr
|
||||
if [[ -z "$value" ]]; then
|
||||
echo "Value cannot be empty." > /dev/stderr
|
||||
fi
|
||||
done
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
# read_yes_no "<prompt>" [<default y|n>]
|
||||
read_yes_no() {
|
||||
local prompt="$1"
|
||||
local default="${2:-n}"
|
||||
local hint
|
||||
if [[ "$default" == "y" ]]; then
|
||||
hint="[Y/n]"
|
||||
else
|
||||
hint="[y/N]"
|
||||
fi
|
||||
echo -n "${prompt} ${hint}: " > /dev/stderr
|
||||
local ans=""
|
||||
read -r ans < /dev/tty
|
||||
if [[ -z "$ans" ]]; then
|
||||
ans="$default"
|
||||
fi
|
||||
case "$ans" in
|
||||
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
|
||||
*) echo "no" ;;
|
||||
esac
|
||||
}
|
||||
|
||||
wait_postgres() {
|
||||
set +e
|
||||
echo -n "Waiting for postgres to become ready"
|
||||
local counter=1
|
||||
while true; do
|
||||
if $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" &> /dev/null; then
|
||||
break
|
||||
fi
|
||||
if [[ $counter -eq 60 ]]; then
|
||||
echo ""
|
||||
echo "Postgres is taking too long. Recent logs:"
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
|
||||
exit 1
|
||||
fi
|
||||
echo -n " ."
|
||||
sleep 2
|
||||
counter=$((counter + 1))
|
||||
done
|
||||
echo " done"
|
||||
set -e
|
||||
}
|
||||
|
||||
init_environment() {
|
||||
check_openssl
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
|
||||
if [[ -f .env ]] || [[ -f docker-compose.yml ]] || [[ -f config.yaml ]] || [[ -f Caddyfile ]]; then
|
||||
echo "Generated files already exist in $(pwd)."
|
||||
echo "If you want to reinitialize the environment, please remove them first:"
|
||||
echo " $DOCKER_COMPOSE_COMMAND down --volumes # removes all containers and volumes"
|
||||
echo " rm -f .env docker-compose.yml Caddyfile config.yaml"
|
||||
echo "Be aware this will remove all data from the database."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "NetBird Enterprise bootstrap"
|
||||
echo ""
|
||||
echo "Traffic flow:"
|
||||
echo " Enables traffic events logging on the management server."
|
||||
echo " When enabled, the NetBird stack also runs NATS along with two"
|
||||
echo " additional containers: netbird-receiver (the traffic log receiver"
|
||||
echo " service) and netbird-enricher (the traffic log enricher service)."
|
||||
echo " It still has to be turned on from the dashboard settings afterwards."
|
||||
echo " See https://docs.netbird.io/manage/activity/traffic-events-logging"
|
||||
NETBIRD_TRAFFIC_FLOW=$(read_yes_no "Enable traffic flow" "n")
|
||||
|
||||
echo ""
|
||||
NETBIRD_DOMAIN=$(read_nb_domain)
|
||||
|
||||
echo ""
|
||||
|
||||
NETBIRD_LICENSE_KEY=$(read_secret "Enter license key (input hidden)")
|
||||
|
||||
GHCR_USERNAME="netbirdExtAccess1"
|
||||
GHCR_TOKEN=$(read_secret "Enter GHCR token (input hidden)")
|
||||
|
||||
POSTGRES_USER="netbird"
|
||||
POSTGRES_DB="netbird"
|
||||
POSTGRES_PASSWORD=$(rand_secret)
|
||||
NETBIRD_ENCRYPTION_KEY=$(rand_b64_key)
|
||||
NETBIRD_RELAY_AUTH_SECRET=$(rand_secret)
|
||||
|
||||
POSTGRES_DSN="host=postgres user=${POSTGRES_USER} password=${POSTGRES_PASSWORD} dbname=${POSTGRES_DB} port=5432 sslmode=disable TimeZone=UTC"
|
||||
NETBIRD_RELAY_ENDPOINT="rels://${NETBIRD_DOMAIN}:443"
|
||||
|
||||
echo ""
|
||||
echo "Selected:"
|
||||
echo " Traffic flow: ${NETBIRD_TRAFFIC_FLOW}"
|
||||
echo " Domain: ${NETBIRD_DOMAIN}"
|
||||
echo ""
|
||||
echo "Rendering files into $(pwd) ..."
|
||||
install -m 600 /dev/null .env
|
||||
render_env >> .env
|
||||
render_docker_compose > docker-compose.yml
|
||||
|
||||
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' docker-compose.yml && rm -f docker-compose.yml.bak
|
||||
fi
|
||||
render_caddyfile > Caddyfile
|
||||
install -m 600 /dev/null config.yaml
|
||||
render_config_yaml >> config.yaml
|
||||
|
||||
echo "Logging in to ghcr.io ..."
|
||||
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
|
||||
unset GHCR_TOKEN
|
||||
|
||||
echo ""
|
||||
echo "Pulling images ..."
|
||||
$DOCKER_COMPOSE_COMMAND pull
|
||||
|
||||
echo ""
|
||||
echo "Starting postgres ..."
|
||||
$DOCKER_COMPOSE_COMMAND up -d postgres
|
||||
sleep 2
|
||||
wait_postgres
|
||||
|
||||
echo ""
|
||||
echo "Starting remaining services ..."
|
||||
$DOCKER_COMPOSE_COMMAND up -d
|
||||
|
||||
echo ""
|
||||
echo "Done."
|
||||
echo ""
|
||||
echo "Dashboard: https://${NETBIRD_DOMAIN}"
|
||||
echo ""
|
||||
echo "Open the dashboard in a browser to complete the first-login owner setup."
|
||||
echo "All configuration and secrets are stored (mode 600) in $(pwd)/.env"
|
||||
echo ""
|
||||
echo "Tail logs:"
|
||||
echo " cd $(pwd) && $DOCKER_COMPOSE_COMMAND logs -f netbird-server caddy"
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Renderers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
render_env() {
|
||||
cat <<EOF
|
||||
# Generated by getting-started-enterprise.sh
|
||||
# Holds all configuration and secrets for the stack. Mode 600.
|
||||
|
||||
# Features (set by the script; don't edit without re-running)
|
||||
NETBIRD_TRAFFIC_FLOW_ENABLED=${NETBIRD_TRAFFIC_FLOW}
|
||||
|
||||
# Domain
|
||||
NETBIRD_DOMAIN=${NETBIRD_DOMAIN}
|
||||
|
||||
# Image tags. Default to "latest"
|
||||
NETBIRD_DASHBOARD_TAG=${NETBIRD_DASHBOARD_TAG:-latest}
|
||||
NETBIRD_SERVER_TAG=${NETBIRD_SERVER_TAG:-latest}
|
||||
EOF
|
||||
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
NETBIRD_ENRICHER_TAG=${NETBIRD_ENRICHER_TAG:-latest}
|
||||
NETBIRD_RECEIVER_TAG=${NETBIRD_RECEIVER_TAG:-latest}
|
||||
EOF
|
||||
fi
|
||||
|
||||
cat <<EOF
|
||||
|
||||
# License keys
|
||||
EOF
|
||||
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
cat <<EOF
|
||||
NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
EOF
|
||||
fi
|
||||
cat <<EOF
|
||||
NETBIRD_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||
EOF
|
||||
|
||||
cat <<EOF
|
||||
|
||||
# Postgres
|
||||
POSTGRES_USER=${POSTGRES_USER}
|
||||
POSTGRES_DB=${POSTGRES_DB}
|
||||
POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN=${POSTGRES_DSN}
|
||||
|
||||
# Relay
|
||||
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT}
|
||||
NETBIRD_RELAY_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
|
||||
|
||||
# Datastore encryption
|
||||
NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}
|
||||
|
||||
# Dashboard OIDC scopes
|
||||
NETBIRD_AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES:-openid profile email groups}
|
||||
EOF
|
||||
}
|
||||
|
||||
render_docker_compose() {
|
||||
render_compose_header
|
||||
render_compose_common
|
||||
render_compose_server
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
render_compose_flow
|
||||
fi
|
||||
render_compose_postgres
|
||||
render_compose_footer
|
||||
}
|
||||
|
||||
render_compose_header() {
|
||||
cat <<'EOF'
|
||||
x-default: &default
|
||||
restart: unless-stopped
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: '500m'
|
||||
max-file: '2'
|
||||
|
||||
services:
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_common() {
|
||||
cat <<'EOF'
|
||||
caddy:
|
||||
<<: *default
|
||||
image: caddy:2
|
||||
container_name: netbird-caddy
|
||||
networks: [netbird]
|
||||
environment:
|
||||
- CADDY_SECURE_DOMAIN=${NETBIRD_DOMAIN}
|
||||
ports:
|
||||
- '443:443'
|
||||
- '443:443/udp'
|
||||
- '80:80'
|
||||
volumes:
|
||||
- netbird_caddy_data:/data
|
||||
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||
|
||||
dashboard:
|
||||
<<: *default
|
||||
image: ghcr.io/netbirdio/dashboard-cloud:${NETBIRD_DASHBOARD_TAG}
|
||||
container_name: netbird-dashboard
|
||||
networks: [netbird]
|
||||
environment:
|
||||
- NETBIRD_MGMT_API_ENDPOINT=https://${NETBIRD_DOMAIN}
|
||||
- NETBIRD_MGMT_GRPC_API_ENDPOINT=https://${NETBIRD_DOMAIN}
|
||||
- AUTH_AUDIENCE=netbird-dashboard
|
||||
- AUTH_CLIENT_ID=netbird-dashboard
|
||||
- AUTH_CLIENT_SECRET=
|
||||
- AUTH_AUTHORITY=https://${NETBIRD_DOMAIN}/oauth2
|
||||
- USE_AUTH0=false
|
||||
- AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES}
|
||||
- AUTH_REDIRECT_URI=/nb-auth
|
||||
- AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
|
||||
- NETBIRD_TOKEN_SOURCE=accessToken
|
||||
- NGINX_SSL_PORT=443
|
||||
- LETSENCRYPT_DOMAIN=
|
||||
- LETSENCRYPT_EMAIL=
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_server() {
|
||||
cat <<'EOF'
|
||||
netbird-server:
|
||||
<<: *default
|
||||
image: ghcr.io/netbirdio/netbird-server-cloud:${NETBIRD_SERVER_TAG}
|
||||
container_name: netbird-server
|
||||
networks: [netbird]
|
||||
depends_on:
|
||||
dashboard:
|
||||
condition: service_started
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
ports:
|
||||
- '3478:3478/udp'
|
||||
volumes:
|
||||
- netbird_data:/var/lib/netbird
|
||||
- ./config.yaml:/etc/netbird/config.yaml
|
||||
command: ["--config", "/etc/netbird/config.yaml"]
|
||||
environment:
|
||||
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_flow() {
|
||||
cat <<'EOF'
|
||||
nats:
|
||||
<<: *default
|
||||
image: nats:2
|
||||
container_name: netbird-nats
|
||||
networks: [netbird]
|
||||
volumes:
|
||||
- netbird_nats_data:/data
|
||||
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
|
||||
|
||||
enricher:
|
||||
<<: *default
|
||||
image: ghcr.io/netbirdio/flow-enricher-cloud:${NETBIRD_ENRICHER_TAG}
|
||||
container_name: netbird-enricher
|
||||
networks: [netbird]
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
nats:
|
||||
condition: service_started
|
||||
volumes:
|
||||
- netbird_enricher:/var/lib/netbird
|
||||
environment:
|
||||
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
- NB_DATADIR=/var/lib/netbird
|
||||
- NB_MANAGEMENT_STORE_ENGINE=postgres
|
||||
- NB_MANAGEMENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||
- NB_TRAFFIC_EVENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||
- NB_TRAFFIC_EVENT_STORE_ENGINE=postgres
|
||||
- NB_MANAGEMENT_STORE_KEY=${NETBIRD_ENCRYPTION_KEY}
|
||||
- NB_FLOW_ADAPTER_TYPE=nats
|
||||
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
|
||||
- NB_FLOW_NATS_STREAM=traffic-events
|
||||
- NB_METRICS_PORT=9091
|
||||
- NB_PERSISTENCE_RETENTION_PERIOD=168h
|
||||
|
||||
receiver:
|
||||
<<: *default
|
||||
image: ghcr.io/netbirdio/flow-receiver-cloud:${NETBIRD_RECEIVER_TAG}
|
||||
container_name: netbird-receiver
|
||||
networks: [netbird]
|
||||
depends_on:
|
||||
nats:
|
||||
condition: service_started
|
||||
environment:
|
||||
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
- NB_FLOW_LISTEN_PORT=80
|
||||
- NB_FLOW_ADAPTER_TYPE=nats
|
||||
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
|
||||
- NB_FLOW_NATS_STREAM=traffic-events
|
||||
- NB_FLOW_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_postgres() {
|
||||
cat <<'EOF'
|
||||
postgres:
|
||||
<<: *default
|
||||
image: postgres:17
|
||||
container_name: netbird-postgres
|
||||
networks: [netbird]
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
|
||||
- POSTGRES_DB=${POSTGRES_DB}
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
volumes:
|
||||
- netbird_postgres:/var/lib/postgresql/data
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_footer() {
|
||||
cat <<'EOF'
|
||||
volumes:
|
||||
netbird_data:
|
||||
EOF
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
cat <<'EOF'
|
||||
netbird_nats_data:
|
||||
netbird_enricher:
|
||||
EOF
|
||||
fi
|
||||
cat <<'EOF'
|
||||
netbird_postgres:
|
||||
netbird_caddy_data:
|
||||
|
||||
networks:
|
||||
netbird:
|
||||
EOF
|
||||
}
|
||||
|
||||
render_caddyfile() {
|
||||
cat <<'EOF'
|
||||
{
|
||||
servers :80,:443 {
|
||||
protocols h1 h2c h2 h3
|
||||
}
|
||||
}
|
||||
|
||||
(security_headers) {
|
||||
header * {
|
||||
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
|
||||
X-Content-Type-Options "nosniff"
|
||||
X-Frame-Options "SAMEORIGIN"
|
||||
X-XSS-Protection "1; mode=block"
|
||||
-Server
|
||||
Referrer-Policy strict-origin-when-cross-origin
|
||||
}
|
||||
}
|
||||
|
||||
:80 {
|
||||
redir https://{$CADDY_SECURE_DOMAIN}{uri} permanent
|
||||
}
|
||||
|
||||
{$CADDY_SECURE_DOMAIN}:443 {
|
||||
import security_headers
|
||||
# Signal (gRPC over h2c)
|
||||
reverse_proxy /signalexchange.SignalExchange/* h2c://netbird-server:80
|
||||
# Management (gRPC over h2c + HTTP)
|
||||
reverse_proxy /management.ManagementService/* h2c://netbird-server:80
|
||||
reverse_proxy /api/* netbird-server:80
|
||||
reverse_proxy /ws-proxy/* netbird-server:80
|
||||
# Embedded IdP (OAuth2 endpoints served by netbird server)
|
||||
reverse_proxy /oauth2/* netbird-server:80
|
||||
# Relay (WebSocket multiplexed on the same port)
|
||||
reverse_proxy /relay* netbird-server:80
|
||||
EOF
|
||||
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
cat <<'EOF'
|
||||
# Flow receiver (gRPC over h2c)
|
||||
reverse_proxy /flow.FlowService/* h2c://receiver:80
|
||||
EOF
|
||||
fi
|
||||
|
||||
cat <<'EOF'
|
||||
# Dashboard
|
||||
reverse_proxy /* dashboard:80
|
||||
}
|
||||
EOF
|
||||
}
|
||||
|
||||
render_config_yaml() {
|
||||
cat <<EOF
|
||||
# NetBird Enterprise server configuration.
|
||||
# Generated by getting-started-enterprise.sh. Mode 600.
|
||||
|
||||
server:
|
||||
listenAddress: ":80"
|
||||
exposedAddress: "https://${NETBIRD_DOMAIN}:443"
|
||||
|
||||
metricsPort: 9090
|
||||
healthcheckAddress: ":9000"
|
||||
|
||||
logLevel: "info"
|
||||
logFile: "console"
|
||||
|
||||
# TLS is terminated by Caddy in front; leave this block empty.
|
||||
tls:
|
||||
certFile: ""
|
||||
keyFile: ""
|
||||
letsencrypt:
|
||||
enabled: false
|
||||
|
||||
authSecret: "${NETBIRD_RELAY_AUTH_SECRET}"
|
||||
dataDir: "/var/lib/netbird/"
|
||||
|
||||
disableAnonymousMetrics: false
|
||||
disableGeoliteUpdate: false
|
||||
|
||||
auth:
|
||||
issuer: "https://${NETBIRD_DOMAIN}/oauth2"
|
||||
localAuthDisabled: false
|
||||
signKeyRefreshEnabled: false
|
||||
dashboardRedirectURIs:
|
||||
- "https://${NETBIRD_DOMAIN}/nb-auth"
|
||||
- "https://${NETBIRD_DOMAIN}/nb-silent-auth"
|
||||
cliRedirectURIs:
|
||||
- "http://localhost:53000/"
|
||||
|
||||
store:
|
||||
engine: "postgres"
|
||||
dsn: "${POSTGRES_DSN}"
|
||||
encryptionKey: "${NETBIRD_ENCRYPTION_KEY}"
|
||||
|
||||
activityStore:
|
||||
engine: "postgres"
|
||||
dsn: "${POSTGRES_DSN}"
|
||||
EOF
|
||||
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
|
||||
trafficFlow:
|
||||
enabled: true
|
||||
address: "https://${NETBIRD_DOMAIN}:443"
|
||||
interval: "60s"
|
||||
EOF
|
||||
fi
|
||||
}
|
||||
|
||||
init_environment
|
||||
@@ -398,7 +398,42 @@ configure_domain() {
|
||||
return 0
|
||||
}
|
||||
|
||||
apply_agent_network_preset() {
|
||||
# Agent-network turnkey install: built-in Traefik + NetBird Proxy with
|
||||
# NB_PROXY_PRIVATE=true, dashboard locked to agent-network-only mode.
|
||||
# Bypasses every reverse-proxy / proxy / CrowdSec prompt. The only
|
||||
# inputs we still need from the operator are the domain (handled by
|
||||
# configure_domain via NETBIRD_DOMAIN env var or interactive prompt)
|
||||
# and the ACME email — both honor env vars first and fall back to a
|
||||
# prompt only when unset. CrowdSec is intentionally off.
|
||||
REVERSE_PROXY_TYPE="0"
|
||||
ENABLE_PROXY="true"
|
||||
ENABLE_CROWDSEC="false"
|
||||
|
||||
if [[ -n "${NETBIRD_LETSENCRYPT_EMAIL}" ]]; then
|
||||
TRAEFIK_ACME_EMAIL="${NETBIRD_LETSENCRYPT_EMAIL}"
|
||||
else
|
||||
TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email)
|
||||
fi
|
||||
|
||||
echo "" > /dev/stderr
|
||||
echo "Agent-network preset enabled (NETBIRD_AGENT_NETWORK=true):" > /dev/stderr
|
||||
echo " - reverse proxy: built-in Traefik" > /dev/stderr
|
||||
echo " - NetBird Proxy: enabled with NB_PROXY_PRIVATE=true" > /dev/stderr
|
||||
echo " - dashboard: NETBIRD_AGENT_NETWORK_ONLY=true" > /dev/stderr
|
||||
echo " - CrowdSec: disabled" > /dev/stderr
|
||||
echo " - Let's Encrypt email: ${TRAEFIK_ACME_EMAIL}" > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
}
|
||||
|
||||
configure_reverse_proxy() {
|
||||
# Short-circuit: agent-network preset locks every reverse-proxy /
|
||||
# proxy / CrowdSec choice and bypasses the interactive prompts.
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
apply_agent_network_preset
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Prompt for reverse proxy type
|
||||
REVERSE_PROXY_TYPE=$(read_reverse_proxy_type)
|
||||
|
||||
@@ -910,6 +945,15 @@ NGINX_SSL_PORT=443
|
||||
# Letsencrypt
|
||||
LETSENCRYPT_DOMAIN=none
|
||||
EOF
|
||||
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
cat <<EOF
|
||||
# Agent-network preset: dashboard hides the standard NetBird surfaces
|
||||
# and exposes only the AI Observability + agent-network configuration
|
||||
# pages. Paired with NB_PROXY_PRIVATE=true on the proxy side.
|
||||
NETBIRD_AGENT_NETWORK_ONLY=true
|
||||
EOF
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -946,6 +990,17 @@ NB_PROXY_PROXY_PROTOCOL=true
|
||||
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
|
||||
EOF
|
||||
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
cat <<EOF
|
||||
# Agent-network preset: turn the proxy into the private reverse-proxy
|
||||
# ingress for agent-network synth services. Disables the public-facing
|
||||
# surface so the proxy serves only synth-generated routes (the
|
||||
# llm_router-driven LLM endpoints) and the per-account inbound
|
||||
# listeners on the embedded netstack.
|
||||
NB_PROXY_PRIVATE=true
|
||||
EOF
|
||||
fi
|
||||
|
||||
if [[ "$ENABLE_CROWDSEC" == "true" && -n "$CROWDSEC_BOUNCER_KEY" ]]; then
|
||||
cat <<EOF
|
||||
NB_PROXY_CROWDSEC_API_URL=http://crowdsec:8080
|
||||
|
||||
638
infrastructure_files/migrate-to-enterprise.sh
Executable file
638
infrastructure_files/migrate-to-enterprise.sh
Executable file
@@ -0,0 +1,638 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# NetBird — community combined → Enterprise combined migration
|
||||
#
|
||||
# Non-destructive migration: produces docker-compose.override.yml (auto-loaded
|
||||
# by docker compose) and config.yaml.enterprise alongside the operator's
|
||||
# existing files. Original docker-compose.yml and config.yaml are never
|
||||
# modified.
|
||||
#
|
||||
# Steps (all optional, asked interactively):
|
||||
# 1. Image swap — replace community images with enterprise cloud images.
|
||||
# 2. Postgres migration — add Postgres, migrate SQLite data via migrate-store.
|
||||
# 3. Traffic flow — add NATS + flow-enricher + flow-receiver.
|
||||
#
|
||||
# To revert:
|
||||
# docker compose down
|
||||
# rm -f docker-compose.override.yml config.yaml.enterprise
|
||||
# # If Postgres migration was done, also restore the SQLite backup printed
|
||||
# # at the end of this script's run.
|
||||
# docker compose up -d
|
||||
|
||||
OVERRIDE_FILE="docker-compose.override.yml"
|
||||
ENTERPRISE_CONFIG_FILE="config.yaml.enterprise"
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null; then
|
||||
echo "docker-compose"
|
||||
return
|
||||
fi
|
||||
if docker compose --help &> /dev/null; then
|
||||
echo "docker compose"
|
||||
return
|
||||
fi
|
||||
echo "docker-compose is not installed or not in PATH." > /dev/stderr
|
||||
exit 1
|
||||
}
|
||||
|
||||
check_yq() {
|
||||
if ! command -v yq &> /dev/null; then
|
||||
cat > /dev/stderr <<'EOF'
|
||||
yq is required to parse and update YAML safely.
|
||||
|
||||
macOS: brew install yq
|
||||
Linux: https://github.com/mikefarah/yq/releases (download binary into PATH)
|
||||
Debian: apt-get install yq (Note: must be the mikefarah Go yq, not the Python wrapper.)
|
||||
|
||||
EOF
|
||||
exit 1
|
||||
fi
|
||||
if ! yq --version 2>&1 | grep -q "mikefarah"; then
|
||||
echo "yq is present but appears to be the wrong implementation. The mikefarah Go-based yq is required (https://github.com/mikefarah/yq)." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
check_openssl() {
|
||||
if ! command -v openssl &> /dev/null; then
|
||||
echo "openssl is not installed or not in PATH." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
rand_password() {
|
||||
openssl rand -hex 32
|
||||
}
|
||||
|
||||
read_required() {
|
||||
local prompt="$1"
|
||||
local value=""
|
||||
while [[ -z "$value" ]]; do
|
||||
echo -n "$prompt: " > /dev/stderr
|
||||
read -r value < /dev/tty
|
||||
if [[ -z "$value" ]]; then
|
||||
echo "Value cannot be empty." > /dev/stderr
|
||||
fi
|
||||
done
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
read_secret() {
|
||||
local prompt="$1"
|
||||
local value=""
|
||||
while [[ -z "$value" ]]; do
|
||||
echo -n "$prompt: " > /dev/stderr
|
||||
read -rs value < /dev/tty
|
||||
echo "" > /dev/stderr
|
||||
if [[ -z "$value" ]]; then
|
||||
echo "Value cannot be empty." > /dev/stderr
|
||||
fi
|
||||
done
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
read_yes_no() {
|
||||
local prompt="$1"
|
||||
local default="${2:-n}"
|
||||
local hint
|
||||
if [[ "$default" == "y" ]]; then
|
||||
hint="[Y/n]"
|
||||
else
|
||||
hint="[y/N]"
|
||||
fi
|
||||
echo -n "${prompt} ${hint}: " > /dev/stderr
|
||||
local ans=""
|
||||
read -r ans < /dev/tty
|
||||
if [[ -z "$ans" ]]; then
|
||||
ans="$default"
|
||||
fi
|
||||
case "$ans" in
|
||||
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
|
||||
*) echo "no" ;;
|
||||
esac
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection — read the operator's existing compose to find service names and
|
||||
# paths we need to override. Bail loudly if shape isn't recognised.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
detect_combined_service() {
|
||||
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/netbird-server"))) | .[0].key // ""' "$COMPOSE_FILE"
|
||||
}
|
||||
|
||||
detect_dashboard_service() {
|
||||
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/dashboard"))) | .[0].key // ""' "$COMPOSE_FILE"
|
||||
}
|
||||
|
||||
detect_config_yaml_host_path() {
|
||||
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/etc/netbird/config.yaml\")) | sub(\":/etc/netbird/config.yaml.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
|
||||
}
|
||||
|
||||
detect_data_volume() {
|
||||
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/var/lib/netbird\")) | sub(\":/var/lib/netbird.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
|
||||
}
|
||||
|
||||
detect_exposed_address() {
|
||||
yq eval '.server.exposedAddress // ""' "$CONFIG_YAML_HOST"
|
||||
}
|
||||
|
||||
detect_compose_network() {
|
||||
local tag
|
||||
tag=$(yq eval ".services[\"$COMBINED_SERVICE\"].networks | tag" "$COMPOSE_FILE" 2>/dev/null)
|
||||
case "$tag" in
|
||||
"!!seq")
|
||||
yq eval ".services[\"$COMBINED_SERVICE\"].networks[0]" "$COMPOSE_FILE"
|
||||
;;
|
||||
"!!map")
|
||||
yq eval ".services[\"$COMBINED_SERVICE\"].networks | keys | .[0]" "$COMPOSE_FILE"
|
||||
;;
|
||||
*)
|
||||
echo "default"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Renderers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Build docker-compose.override.yml from the steps the operator selected.
|
||||
# Service names match what we detected on the operator's side.
|
||||
render_override() {
|
||||
cat <<EOF
|
||||
# Generated by migrate-to-enterprise.sh. Mode 644.
|
||||
# Merged with docker-compose.yml automatically by Docker Compose.
|
||||
# Remove this file (and config.yaml.enterprise if present) to revert.
|
||||
|
||||
services:
|
||||
${DASHBOARD_SERVICE}:
|
||||
image: \${NETBIRD_DASHBOARD_IMAGE:-ghcr.io/netbirdio/dashboard-cloud:latest}
|
||||
|
||||
${COMBINED_SERVICE}:
|
||||
image: \${NETBIRD_SERVER_IMAGE:-ghcr.io/netbirdio/netbird-server-cloud:latest}
|
||||
environment:
|
||||
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
EOF
|
||||
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- ./${ENTERPRISE_CONFIG_FILE}:/etc/netbird/config.yaml.enterprise:ro
|
||||
command: ["--config", "/etc/netbird/config.yaml.enterprise"]
|
||||
|
||||
postgres:
|
||||
image: postgres:17
|
||||
container_name: netbird-postgres
|
||||
restart: unless-stopped
|
||||
networks: [${COMPOSE_NETWORK}]
|
||||
environment:
|
||||
POSTGRES_USER: netbird
|
||||
POSTGRES_PASSWORD: \${POSTGRES_PASSWORD}
|
||||
POSTGRES_DB: netbird
|
||||
volumes:
|
||||
- netbird_postgres:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U netbird -d netbird"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 20
|
||||
EOF
|
||||
fi
|
||||
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
|
||||
nats:
|
||||
image: nats:2
|
||||
container_name: netbird-nats
|
||||
restart: unless-stopped
|
||||
networks: [${COMPOSE_NETWORK}]
|
||||
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
|
||||
volumes:
|
||||
- netbird_nats_data:/data
|
||||
|
||||
flow-enricher:
|
||||
image: ghcr.io/netbirdio/flow-enricher-cloud:latest
|
||||
container_name: netbird-flow-enricher
|
||||
restart: unless-stopped
|
||||
networks: [${COMPOSE_NETWORK}]
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
nats:
|
||||
condition: service_started
|
||||
environment:
|
||||
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
NB_DATADIR: /var/lib/netbird
|
||||
NB_MANAGEMENT_STORE_ENGINE: postgres
|
||||
NB_MANAGEMENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||
NB_STORE_ENGINE_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||
NB_TRAFFIC_EVENT_STORE_ENGINE: postgres
|
||||
NB_TRAFFIC_EVENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||
NB_MANAGEMENT_STORE_KEY: \${NETBIRD_ENCRYPTION_KEY}
|
||||
NB_FLOW_ADAPTER_TYPE: nats
|
||||
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
|
||||
NB_FLOW_NATS_STREAM: traffic-events
|
||||
NB_METRICS_PORT: 9091
|
||||
NB_PERSISTENCE_RETENTION_PERIOD: 168h
|
||||
|
||||
flow-receiver:
|
||||
image: ghcr.io/netbirdio/flow-receiver-cloud:latest
|
||||
container_name: netbird-flow-receiver
|
||||
restart: unless-stopped
|
||||
networks: [${COMPOSE_NETWORK}]
|
||||
depends_on:
|
||||
nats:
|
||||
condition: service_started
|
||||
environment:
|
||||
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
NB_FLOW_LISTEN_PORT: 80
|
||||
NB_FLOW_ADAPTER_TYPE: nats
|
||||
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
|
||||
NB_FLOW_NATS_STREAM: traffic-events
|
||||
NB_FLOW_AUTH_SECRET: \${NB_FLOW_AUTH_SECRET}
|
||||
labels:
|
||||
- traefik.enable=true
|
||||
- traefik.http.routers.netbird-flow.rule=Host(\`${NETBIRD_HOSTNAME}\`) && PathPrefix(\`/flow.FlowService/\`)
|
||||
- traefik.http.routers.netbird-flow.entrypoints=websecure
|
||||
- traefik.http.routers.netbird-flow.tls=true
|
||||
- traefik.http.routers.netbird-flow.tls.certresolver=letsencrypt
|
||||
- traefik.http.routers.netbird-flow.service=netbird-flow-h2c
|
||||
- traefik.http.routers.netbird-flow.priority=100
|
||||
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.port=80
|
||||
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.scheme=h2c
|
||||
EOF
|
||||
fi
|
||||
|
||||
# Volume declarations for anything new the override introduced
|
||||
local has_volumes="no"
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]] || [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
has_volumes="yes"
|
||||
fi
|
||||
|
||||
if [[ "$has_volumes" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
|
||||
volumes:
|
||||
EOF
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo " netbird_postgres:"
|
||||
fi
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
echo " netbird_nats_data:"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# Build config.yaml.enterprise by yq-editing the operator's existing
|
||||
# config.yaml. We don't touch the original file.
|
||||
render_enterprise_config() {
|
||||
local pg_dsn="host=postgres user=netbird password=${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||
|
||||
yq eval "
|
||||
.server.store.engine = \"postgres\" |
|
||||
.server.store.dsn = \"$pg_dsn\" |
|
||||
.server.activityStore.engine = \"postgres\" |
|
||||
.server.activityStore.dsn = \"$pg_dsn\" |
|
||||
.server.authStore.engine = \"postgres\" |
|
||||
.server.authStore.dsn = \"$pg_dsn\"
|
||||
" "$CONFIG_YAML_HOST" > "$ENTERPRISE_CONFIG_FILE"
|
||||
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
local flow_addr="${NETBIRD_DOMAIN}"
|
||||
yq eval -i "
|
||||
.server.trafficFlow.enabled = true |
|
||||
.server.trafficFlow.address = \"$flow_addr\" |
|
||||
.server.trafficFlow.interval = \"60s\"
|
||||
" "$ENTERPRISE_CONFIG_FILE"
|
||||
fi
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Execution steps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
resolve_data_volume() {
|
||||
local short="$1"
|
||||
local actual
|
||||
# Resolve project-prefixed volume name from Docker Compose config first.
|
||||
actual=$($DOCKER_COMPOSE_COMMAND config 2>/dev/null | yq eval ".volumes.\"$short\".name" - 2>/dev/null)
|
||||
if [[ -n "$actual" && "$actual" != "null" ]]; then
|
||||
echo "$actual"
|
||||
return
|
||||
fi
|
||||
# Relative bind mount: docker-compose resolves it against the compose
|
||||
# file's directory, but `docker run -v` resolves it against the current
|
||||
# working directory. Normalize to an absolute path so both interpretations
|
||||
# agree (and the printed revert command works from any CWD).
|
||||
if [[ "$short" == ./* || "$short" == ../* ]]; then
|
||||
local compose_dir
|
||||
compose_dir="$(cd "$(dirname "$COMPOSE_FILE")" && pwd)"
|
||||
(
|
||||
cd "$compose_dir"
|
||||
cd "$(dirname "$short")"
|
||||
printf '%s/%s\n' "$(pwd)" "$(basename "$short")"
|
||||
)
|
||||
return
|
||||
fi
|
||||
# Not a named volume (e.g. an absolute bind-mount path) — use it as-is.
|
||||
echo "$short"
|
||||
}
|
||||
|
||||
backup_sqlite() {
|
||||
BACKUP_DIR="$(pwd)/backups/sqlite-pre-enterprise-$(date +%Y%m%d-%H%M%S)"
|
||||
mkdir -p "$BACKUP_DIR"
|
||||
local data_volume_actual
|
||||
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
|
||||
echo "Backing up SQLite store from volume '$data_volume_actual' to $BACKUP_DIR ..."
|
||||
docker run --rm \
|
||||
-v "${data_volume_actual}:/var/lib/netbird:ro" \
|
||||
-v "${BACKUP_DIR}:/backup" \
|
||||
busybox \
|
||||
sh -c 'cp -a /var/lib/netbird/. /backup/ 2>/dev/null || true'
|
||||
local copied
|
||||
copied=$(find "$BACKUP_DIR" -mindepth 1 | head -1)
|
||||
if [[ -z "$copied" ]]; then
|
||||
echo " ⚠ Backup directory is empty — the volume '$data_volume_actual' didn't contain data. Aborting." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
echo " done"
|
||||
}
|
||||
|
||||
run_migrate_store() {
|
||||
echo "Running migrate-store (SQLite → Postgres) ..."
|
||||
$DOCKER_COMPOSE_COMMAND run --rm "$COMBINED_SERVICE" migrate-store --config /etc/netbird/config.yaml.enterprise --verify
|
||||
echo " done"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
init_migration() {
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
check_yq
|
||||
check_openssl
|
||||
|
||||
COMPOSE_FILE="${COMPOSE_FILE:-docker-compose.yml}"
|
||||
|
||||
if [[ ! -f "$COMPOSE_FILE" ]]; then
|
||||
echo "$COMPOSE_FILE not found in $(pwd)." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ -f "$OVERRIDE_FILE" ]] || [[ -f "$ENTERPRISE_CONFIG_FILE" ]]; then
|
||||
echo "Migration artifacts already exist in $(pwd):"
|
||||
[[ -f "$OVERRIDE_FILE" ]] && echo " $OVERRIDE_FILE"
|
||||
[[ -f "$ENTERPRISE_CONFIG_FILE" ]] && echo " $ENTERPRISE_CONFIG_FILE"
|
||||
echo ""
|
||||
echo "Either you've already migrated, or a previous run was interrupted."
|
||||
echo "To re-run cleanly: rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
COMBINED_SERVICE=$(detect_combined_service)
|
||||
DASHBOARD_SERVICE=$(detect_dashboard_service)
|
||||
CONFIG_YAML_HOST=$(detect_config_yaml_host_path)
|
||||
DATA_VOLUME=$(detect_data_volume)
|
||||
COMPOSE_NETWORK=$(detect_compose_network)
|
||||
|
||||
if [[ -z "$COMBINED_SERVICE" ]]; then
|
||||
echo "Could not find a service running netbirdio/netbird-server* in $COMPOSE_FILE." > /dev/stderr
|
||||
echo "This script targets the community combined-server deployment." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$DASHBOARD_SERVICE" ]]; then
|
||||
echo "Could not find a service running netbirdio/dashboard* in $COMPOSE_FILE." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$CONFIG_YAML_HOST" ]]; then
|
||||
echo "Could not find a config.yaml mount on $COMBINED_SERVICE (expected to bind-mount to /etc/netbird/config.yaml)." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ ! -f "$CONFIG_YAML_HOST" ]]; then
|
||||
echo "config.yaml host file not found at $CONFIG_YAML_HOST." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$DATA_VOLUME" ]]; then
|
||||
echo "Could not find a volume mounted at /var/lib/netbird on $COMBINED_SERVICE." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Detected existing deployment:"
|
||||
echo " Combined service: $COMBINED_SERVICE"
|
||||
echo " Dashboard: $DASHBOARD_SERVICE"
|
||||
echo " config.yaml: $CONFIG_YAML_HOST"
|
||||
echo " Data volume: $DATA_VOLUME"
|
||||
echo " Network: $COMPOSE_NETWORK"
|
||||
echo ""
|
||||
|
||||
local proceed
|
||||
proceed=$(read_yes_no "Proceed with migration?" "y")
|
||||
if [[ "$proceed" != "yes" ]]; then
|
||||
echo "Aborted."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Step 1 — always (this is the point of the script)
|
||||
MIGRATE_IMAGES="yes"
|
||||
echo ""
|
||||
echo "Step 1: Image swap (community → Enterprise). License key required."
|
||||
NB_LICENSE_KEY=$(read_secret " License key")
|
||||
GHCR_USERNAME="netbirdExtAccess1"
|
||||
GHCR_TOKEN=$(read_secret " GHCR token (input hidden)")
|
||||
|
||||
# Step 2 — optional
|
||||
echo ""
|
||||
MIGRATE_POSTGRES=$(read_yes_no "Step 2: Migrate storage from SQLite to Postgres? (recommended)" "n")
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo ""
|
||||
echo " ⚠ Data will be migrated from SQLite to Postgres. The SQLite store"
|
||||
echo " will be backed up automatically. To fully revert later, restore"
|
||||
echo " that backup and delete docker-compose.override.yml +"
|
||||
echo " config.yaml.enterprise."
|
||||
local confirm
|
||||
confirm=$(read_yes_no " Continue?" "y")
|
||||
if [[ "$confirm" != "yes" ]]; then
|
||||
MIGRATE_POSTGRES="no"
|
||||
echo " Skipping Postgres migration."
|
||||
else
|
||||
POSTGRES_PASSWORD=$(rand_password)
|
||||
fi
|
||||
fi
|
||||
|
||||
# Step 3 — optional, only if Postgres is on (flow requires Postgres)
|
||||
echo ""
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
ENABLE_FLOW=$(read_yes_no "Step 3: Enable traffic flow? (requires Postgres)" "n")
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
# Auth secret MUST match server.authSecret from config.yaml
|
||||
NB_FLOW_AUTH_SECRET=$(yq eval '.server.authSecret // ""' "$CONFIG_YAML_HOST")
|
||||
if [[ -z "$NB_FLOW_AUTH_SECRET" ]] || [[ "$NB_FLOW_AUTH_SECRET" == "null" ]]; then
|
||||
echo "Could not read server.authSecret from $CONFIG_YAML_HOST." > /dev/stderr
|
||||
echo "Flow receiver auth must match the combined server's authSecret." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
|
||||
NETBIRD_DOMAIN=$(detect_exposed_address)
|
||||
if [[ -z "$NETBIRD_DOMAIN" ]] || [[ "$NETBIRD_DOMAIN" == "null" ]]; then
|
||||
NETBIRD_DOMAIN=$(read_required " Public NetBird URL (e.g. https://netbird.example.com)")
|
||||
fi
|
||||
# Strip protocol + port to leave just the hostname for the Traefik Host() rule.
|
||||
NETBIRD_HOSTNAME=$(echo "$NETBIRD_DOMAIN" | sed -E 's,^https?://,,' | sed 's,:.*,,' | sed 's,/.*,,')
|
||||
|
||||
# We need the encryption key from the existing config.yaml for the enricher
|
||||
NETBIRD_ENCRYPTION_KEY=$(yq eval '.server.store.encryptionKey // ""' "$CONFIG_YAML_HOST")
|
||||
if [[ -z "$NETBIRD_ENCRYPTION_KEY" ]] || [[ "$NETBIRD_ENCRYPTION_KEY" == "null" ]]; then
|
||||
echo "Could not read server.store.encryptionKey from $CONFIG_YAML_HOST." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
else
|
||||
ENABLE_FLOW="no"
|
||||
echo "Step 3 (traffic flow) skipped — requires Postgres."
|
||||
fi
|
||||
}
|
||||
|
||||
apply_changes() {
|
||||
echo ""
|
||||
echo "Writing $OVERRIDE_FILE ..."
|
||||
install -m 644 /dev/null "$OVERRIDE_FILE"
|
||||
render_override > "$OVERRIDE_FILE"
|
||||
|
||||
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' "$OVERRIDE_FILE" && rm -f "$OVERRIDE_FILE.bak"
|
||||
fi
|
||||
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo "Writing $ENTERPRISE_CONFIG_FILE ..."
|
||||
install -m 600 /dev/null "$ENTERPRISE_CONFIG_FILE"
|
||||
render_enterprise_config
|
||||
fi
|
||||
|
||||
# Persist secrets that the override file references via env interpolation.
|
||||
# We write them to a .env file in the current directory; docker compose
|
||||
# picks it up automatically.
|
||||
echo "Writing .env additions (mode 600) ..."
|
||||
local ENV_FILE=".env"
|
||||
touch "$ENV_FILE"
|
||||
chmod 600 "$ENV_FILE"
|
||||
{
|
||||
echo ""
|
||||
echo "# Added by migrate-to-enterprise.sh on $(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||
echo "NB_LICENSE_KEY=${NB_LICENSE_KEY}"
|
||||
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
echo "NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}"
|
||||
fi
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo "POSTGRES_PASSWORD=${POSTGRES_PASSWORD}"
|
||||
fi
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
echo "NB_FLOW_AUTH_SECRET=${NB_FLOW_AUTH_SECRET}"
|
||||
echo "NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}"
|
||||
fi
|
||||
} >> "$ENV_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Logging in to ghcr.io ..."
|
||||
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
|
||||
unset GHCR_TOKEN
|
||||
|
||||
echo ""
|
||||
echo "Pulling enterprise images ..."
|
||||
$DOCKER_COMPOSE_COMMAND pull
|
||||
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo ""
|
||||
echo "Stopping existing services (volumes preserved) ..."
|
||||
$DOCKER_COMPOSE_COMMAND down
|
||||
|
||||
backup_sqlite
|
||||
|
||||
echo ""
|
||||
echo "Starting Postgres ..."
|
||||
$DOCKER_COMPOSE_COMMAND up -d postgres
|
||||
|
||||
# Wait for healthy
|
||||
local counter=0
|
||||
echo -n "Waiting for Postgres to become ready"
|
||||
while ! $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U netbird -d netbird &> /dev/null; do
|
||||
echo -n " ."
|
||||
sleep 2
|
||||
counter=$((counter + 1))
|
||||
if [[ $counter -ge 60 ]]; then
|
||||
echo ""
|
||||
echo "Postgres did not become ready in 120s. Recent logs:"
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
echo " done"
|
||||
|
||||
run_migrate_store
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Bringing up all services ..."
|
||||
$DOCKER_COMPOSE_COMMAND up -d
|
||||
|
||||
echo ""
|
||||
echo "Migration complete."
|
||||
}
|
||||
|
||||
print_summary() {
|
||||
echo ""
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
echo " Summary"
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
echo " Images: swapped to enterprise"
|
||||
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " Storage: Postgres (data migrated from SQLite)"
|
||||
[[ "$MIGRATE_POSTGRES" != "yes" ]] && echo " Storage: SQLite (unchanged)"
|
||||
[[ "$ENABLE_FLOW" == "yes" ]] && echo " Traffic flow: enabled"
|
||||
[[ "$ENABLE_FLOW" != "yes" ]] && echo " Traffic flow: disabled"
|
||||
echo ""
|
||||
echo " Generated files (next to your docker-compose.yml):"
|
||||
echo " $OVERRIDE_FILE"
|
||||
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " $ENTERPRISE_CONFIG_FILE"
|
||||
echo " .env (license key + secrets, mode 600)"
|
||||
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " backups/sqlite-pre-enterprise-*/ (SQLite backup)"
|
||||
echo ""
|
||||
echo " Tail logs:"
|
||||
echo " $DOCKER_COMPOSE_COMMAND logs -f $COMBINED_SERVICE"
|
||||
echo ""
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
echo " To revert"
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
echo " $DOCKER_COMPOSE_COMMAND down"
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
# Resolve project-prefixed volume names now (before override is removed).
|
||||
local pg_volume data_volume_actual
|
||||
pg_volume=$(resolve_data_volume "netbird_postgres")
|
||||
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
|
||||
echo " # Remove the Postgres volume FIRST, before deleting the override file:"
|
||||
echo " docker volume rm $pg_volume"
|
||||
echo " # Restore SQLite from the backup created during this run:"
|
||||
echo " docker run --rm -v ${data_volume_actual}:/var/lib/netbird -v ${BACKUP_DIR}:/backup busybox sh -c 'cp -a /backup/. /var/lib/netbird/'"
|
||||
fi
|
||||
echo " rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
|
||||
echo " # Remove migrate-to-enterprise.sh additions from .env (search for the timestamp marker)"
|
||||
echo " $DOCKER_COMPOSE_COMMAND up -d"
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
init_migration
|
||||
apply_changes
|
||||
print_summary
|
||||
@@ -116,6 +116,24 @@ func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, p
|
||||
c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
// injectAllProxyPolicies prepares an account for the per-peer network-map
|
||||
// computation. It prepends the in-memory agent-network services synthesised
|
||||
// from the account's current provider/policy state to account.Services so
|
||||
// the existing InjectProxyPolicies + injectPrivateServicePolicies walks pick
|
||||
// them up alongside persisted reverse-proxy services. Synthesised services
|
||||
// are never persisted; the account is loaded fresh per cycle so re-prepending
|
||||
// is safe and idempotent. Accounts without agent-network providers get an
|
||||
// empty synth slice — no behaviour change.
|
||||
func (c *Controller) injectAllProxyPolicies(ctx context.Context, account *types.Account) {
|
||||
synth, err := c.repo.SynthesizeAgentNetworkServices(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("synthesise agent-network services for account %s: %v", account.Id, err)
|
||||
} else if len(synth) > 0 {
|
||||
account.Services = append(synth, account.Services...)
|
||||
}
|
||||
account.InjectProxyPolicies(ctx)
|
||||
}
|
||||
|
||||
func (c *Controller) CountStreams() int {
|
||||
return c.peersUpdateManager.CountStreams()
|
||||
}
|
||||
@@ -150,7 +168,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
@@ -281,7 +299,15 @@ func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID s
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
// The affected-peer path MUST mirror sendUpdateAccountPeers (line 171)
|
||||
// here: injectAllProxyPolicies prepends the synthesised agent-network
|
||||
// services BEFORE InjectProxyPolicies + private-service policies run.
|
||||
// Previously this path called only account.InjectProxyPolicies, which
|
||||
// skipped the synth-services prepend — so peer-level changes
|
||||
// (proxy restart, embedded peer connect/disconnect) propagated a
|
||||
// network map that omitted the synth DNS zone, and the agent kept
|
||||
// resolving against the stale or absent record.
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
@@ -399,7 +425,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return fmt.Errorf("failed to get validated peers: %v", err)
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
@@ -497,7 +523,7 @@ func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID st
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s with reason %s/%s", len(peerIDs), accountID, util.GetCallerName(), reason.Operation, reason.Resource)
|
||||
|
||||
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
|
||||
peerIDs: make(map[string]struct{}),
|
||||
@@ -603,19 +629,17 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
startPosture := time.Now()
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
@@ -876,7 +900,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
@@ -3,7 +3,9 @@ package controller
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -16,6 +18,10 @@ type Repository interface {
|
||||
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
|
||||
GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error)
|
||||
// SynthesizeAgentNetworkServices returns the in-memory reverse-proxy
|
||||
// services synthesised from the account's agent-network provider/policy
|
||||
// state. Empty for accounts without agent-network providers.
|
||||
SynthesizeAgentNetworkServices(ctx context.Context, accountID string) ([]*service.Service, error)
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
@@ -50,6 +56,10 @@ func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID s
|
||||
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
func (r *repository) SynthesizeAgentNetworkServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
return agentnetwork.SynthesizeServices(ctx, r.store, accountID)
|
||||
}
|
||||
|
||||
func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) {
|
||||
return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package agentnetwork
|
||||
|
||||
import "github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
|
||||
// init registers the agent-network service synthesiser with the affectedpeers
|
||||
// resolver. Agent-network reverse-proxy services are synthesised on demand and
|
||||
// never persisted, so the resolver can't load them from the store; without them
|
||||
// it can't fold the embedded proxy peer into the affected set on a client
|
||||
// group/peer change, and the proxy never learns a newly authorised client until
|
||||
// it reconnects. Registered here (rather than via a direct
|
||||
// affectedpeers→agentnetwork import) to avoid an import cycle
|
||||
// (agentnetwork → account → affectedpeers).
|
||||
func init() {
|
||||
affectedpeers.SetAgentNetworkSynthesizer(SynthesizeServices)
|
||||
}
|
||||
749
management/internals/modules/agentnetwork/catalog/catalog.go
Normal file
749
management/internals/modules/agentnetwork/catalog/catalog.go
Normal file
@@ -0,0 +1,749 @@
|
||||
// Package catalog defines the static set of Agent Network providers
|
||||
// recognized by the management server. The catalog is consulted both to
|
||||
// validate provider_id on create/update and to surface the available
|
||||
// providers (and their models) to the dashboard.
|
||||
package catalog
|
||||
|
||||
import "github.com/netbirdio/netbird/shared/management/http/api"
|
||||
|
||||
// Model is the in-memory representation of a catalog model.
|
||||
type Model struct {
|
||||
ID string
|
||||
Label string
|
||||
InputPer1k float64
|
||||
OutputPer1k float64
|
||||
ContextWindow int
|
||||
}
|
||||
|
||||
// ProviderKind groups catalog entries for UI presentation. The split
|
||||
// is semantic, not technical:
|
||||
// - KindProvider: the upstream is a vendor's first-party API (OpenAI,
|
||||
// Anthropic, Mistral, Bedrock, etc.) — NetBird talks straight to
|
||||
// the model provider.
|
||||
// - KindGateway: the upstream is itself a routing / aggregation layer
|
||||
// in front of multiple providers (LiteLLM, Portkey, Helicone, …).
|
||||
// These typically need NetBird identity stamped onto upstream
|
||||
// requests so the gateway's analytics and budgets attribute to the
|
||||
// real caller; that's what IdentityInjection is for.
|
||||
// - KindCustom: the catch-all "OpenAI-compatible self-hosted endpoint"
|
||||
// entry (vLLM, Ollama, custom inference servers).
|
||||
//
|
||||
// Frontend uses Kind to group the provider Select in the modal so an
|
||||
// operator can spot at a glance which catalog entries proxy other
|
||||
// providers vs. talk straight to one. Backend doesn't dispatch on Kind
|
||||
// today; it's purely a presentation hint.
|
||||
type ProviderKind string
|
||||
|
||||
const (
|
||||
KindProvider ProviderKind = "provider"
|
||||
KindGateway ProviderKind = "gateway"
|
||||
KindCustom ProviderKind = "custom"
|
||||
)
|
||||
|
||||
// Provider is the in-memory representation of a catalog provider.
|
||||
type Provider struct {
|
||||
ID string
|
||||
Name string
|
||||
Description string
|
||||
DefaultHost string
|
||||
// Kind groups this entry for UI presentation; see ProviderKind.
|
||||
Kind ProviderKind
|
||||
// AuthHeaderName is the HTTP header the provider's API expects
|
||||
// the credential under (e.g. "Authorization" for OpenAI,
|
||||
// "x-api-key" for Anthropic). Combined with AuthHeaderTemplate
|
||||
// at synthesis time to inject the auth header on every upstream
|
||||
// request.
|
||||
AuthHeaderName string
|
||||
AuthHeaderTemplate string
|
||||
DefaultContentType string
|
||||
BrandColor string
|
||||
// ParserID names the proxy LLM parser surface this provider
|
||||
// speaks (matches llm.Parser.ProviderName: "openai",
|
||||
// "anthropic"). Multiple catalog ids may share a parser surface
|
||||
// (e.g. azure_openai_api and mistral_api both speak the OpenAI
|
||||
// shape). Empty when no parser is yet implemented for the
|
||||
// surface — the proxy middleware then falls back to URL sniffing
|
||||
// or skips request-side enrichment.
|
||||
ParserID string
|
||||
// IdentityInjection, when non-nil, instructs the proxy to stamp
|
||||
// the caller's NetBird identity onto upstream requests under the
|
||||
// configured header names. Used for gateways like LiteLLM that
|
||||
// key budgets and attribution off request headers (the gateway
|
||||
// otherwise has no way to learn which user / group made the call).
|
||||
// The proxy strips the same header names from the inbound request
|
||||
// before stamping ours, so an app can't spoof identity by setting
|
||||
// these headers itself.
|
||||
IdentityInjection *IdentityInjection
|
||||
// ExtraHeaders is a catalog-declared list of additional per-
|
||||
// provider routing/config headers the proxy stamps on every
|
||||
// upstream request. Distinct from AuthHeaderName/Template (which
|
||||
// always carries the API_KEY) and from IdentityInjection (caller
|
||||
// identity). Each entry surfaces an optional input on the
|
||||
// dashboard's provider modal whose value lives on the provider
|
||||
// record's ExtraValues map (keyed by ExtraHeader.Name). Empty
|
||||
// list = no extra inputs rendered. Used today by Portkey for
|
||||
// "x-portkey-config: pc-..." (a saved-config id that resolves
|
||||
// upstream provider + credentials on Portkey's hosted side).
|
||||
ExtraHeaders []ExtraHeader
|
||||
Models []Model
|
||||
}
|
||||
|
||||
// ExtraHeader names a single optional per-provider routing/config
|
||||
// header. Catalog declares N of these per provider type; the operator
|
||||
// fills any subset on the provider record (see Provider.ExtraValues).
|
||||
// At synth time, only entries with a non-empty operator value are
|
||||
// stamped; the proxy's identity-inject middleware applies anti-spoof
|
||||
// (Remove + Add) so a client can't supply these headers themselves.
|
||||
//
|
||||
// UI copy (label / help text / tooltip) for each known Name lives on
|
||||
// the dashboard, not here — the backend's job is just to declare
|
||||
// which wire headers are accepted. New provider needs an extra
|
||||
// header? Add the Name here AND the matching UI copy on the dashboard.
|
||||
type ExtraHeader struct {
|
||||
// Name is the wire header name, e.g. "x-portkey-config".
|
||||
Name string
|
||||
}
|
||||
|
||||
// IdentityInjection describes how the proxy stamps NetBird identity onto
|
||||
// upstream gateway requests. Exactly one shape must be set — they're
|
||||
// mutually exclusive and dispatched by the inject middleware.
|
||||
//
|
||||
// Shape choice tracks the wire convention the upstream gateway uses,
|
||||
// not the vendor name. New gateways with a known shape become a catalog
|
||||
// entry, not a new code path.
|
||||
type IdentityInjection struct {
|
||||
// HeaderPair emits separate headers per identity dimension
|
||||
// (end-user id, tags as CSV). LiteLLM and OpenAI-compatible
|
||||
// self-hosted gateways that read identity from dedicated headers.
|
||||
HeaderPair *HeaderPairInjection
|
||||
// JSONMetadata emits a single header carrying a JSON object with
|
||||
// reserved keys for user / groups / etc. Portkey, Helicone-style
|
||||
// metadata headers, anything that wants a structured envelope.
|
||||
JSONMetadata *JSONMetadataInjection
|
||||
}
|
||||
|
||||
// HeaderPairInjection is the LiteLLM-style wire convention.
|
||||
type HeaderPairInjection struct {
|
||||
// Customizable, when true, marks the wire header names as
|
||||
// operator-overridable: the dashboard surfaces EndUserIDHeader
|
||||
// and TagsHeader as editable inputs (defaults shown as
|
||||
// placeholders) and the synthesizer pulls the actual values from
|
||||
// the provider record's IdentityHeader* fields rather than from
|
||||
// these defaults. An empty operator value disables stamping for
|
||||
// that dimension. Used today for Bifrost, whose log-metadata /
|
||||
// telemetry header prefix (x-bf-lh-* vs x-bf-dim-*) is a
|
||||
// per-operator choice; LiteLLM and similar gateways with a fixed
|
||||
// wire protocol leave this false so the catalog defaults are
|
||||
// authoritative.
|
||||
Customizable bool
|
||||
// EndUserIDHeader receives the caller's display identity (user
|
||||
// email when the peer is attached to a user, else peer.Name),
|
||||
// e.g. "x-litellm-end-user-id".
|
||||
EndUserIDHeader string
|
||||
// TagsHeader receives the caller's NetBird group display names
|
||||
// as a CSV, e.g. "x-litellm-tags".
|
||||
TagsHeader string
|
||||
// TagsInBody, when true, additionally writes the tag list into
|
||||
// the request body's metadata.tags array (a JSON path the
|
||||
// gateway parses for budget enforcement). LiteLLM only honours
|
||||
// metadata.tags for tag-budget gating — its x-litellm-tags
|
||||
// header path feeds spend tracking but bypasses
|
||||
// _tag_max_budget_check entirely. Body inject is skipped when
|
||||
// the request body is empty, truncated, non-JSON, or when an
|
||||
// existing metadata field is a non-object value (defensive: we
|
||||
// never clobber a client-supplied non-object). The header path
|
||||
// remains a robust fallback for spend tracking in those cases.
|
||||
TagsInBody bool
|
||||
// EndUserIDInBody, when true, additionally writes the display
|
||||
// identity into the request body's top-level "user" field (the
|
||||
// OpenAI-standard end-user identifier). LiteLLM resolves the end
|
||||
// user id from headers first then body, so for LiteLLM this is
|
||||
// belt-and-suspenders. It matters when an OpenAI-compatible
|
||||
// gateway downstream of LiteLLM (or OpenAI direct, bypassing
|
||||
// LiteLLM) only reads the body, and as anti-spoof: client-
|
||||
// supplied "user" values are overwritten with our trusted
|
||||
// identity. Same skip rules as TagsInBody.
|
||||
EndUserIDInBody bool
|
||||
}
|
||||
|
||||
// JSONMetadataInjection is the Portkey-style wire convention: a single
|
||||
// header carrying a JSON object. NetBird identity fields land under the
|
||||
// configured reserved keys; missing keys (empty string) are skipped at
|
||||
// emit time.
|
||||
type JSONMetadataInjection struct {
|
||||
// Customizable, when true, marks the JSON keys as operator-
|
||||
// overridable. The dashboard surfaces UserKey and GroupsKey as
|
||||
// editable inputs (the catalog values shown as placeholders) and
|
||||
// the synthesizer pulls the actual JSON-key names from the
|
||||
// provider record's IdentityHeader* fields. Same field reuse as
|
||||
// HeaderPair's customizable path — the dimensions (user identity,
|
||||
// groups) are the same, only the wire encoding differs (JSON key
|
||||
// vs HTTP header name). An empty operator value disables emission
|
||||
// for that dimension. Used today for Cloudflare AI Gateway, whose
|
||||
// cf-aig-metadata header accepts arbitrary JSON keys; Portkey
|
||||
// leaves this false because its keys are reserved by the Portkey
|
||||
// schema.
|
||||
Customizable bool
|
||||
// Header is the wire header name carrying the JSON payload, e.g.
|
||||
// "x-portkey-metadata".
|
||||
Header string
|
||||
// UserKey is the JSON key for the caller's display identity.
|
||||
// Portkey reserves "_user" for this dimension.
|
||||
UserKey string
|
||||
// GroupsKey is the JSON key for the caller's NetBird groups,
|
||||
// emitted as a CSV string value (Portkey requires string values).
|
||||
GroupsKey string
|
||||
// MaxValueLength caps each emitted JSON value, in bytes. Portkey
|
||||
// enforces a 128-char limit per value; oversized values are
|
||||
// truncated rather than failing the request. 0 disables the cap.
|
||||
MaxValueLength int
|
||||
}
|
||||
|
||||
// providers is the canonical list of supported Agent Network providers.
|
||||
// Update this list together with the dashboard's PROVIDER_CATALOG.
|
||||
var providers = []Provider{
|
||||
{
|
||||
ID: "openai_api",
|
||||
Kind: KindProvider,
|
||||
Name: "OpenAI API",
|
||||
Description: "GPT, Responses API, and Embeddings",
|
||||
DefaultHost: "api.openai.com",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#10A37F",
|
||||
ParserID: "openai",
|
||||
// Pricing + context windows cross-checked against LiteLLM's
|
||||
// model_prices_and_context_window.json. Notable corrections from
|
||||
// earlier values: o4-mini repriced from $4/$16 to $1.10/$4.40
|
||||
// per MTok, gpt-4o from $5/$15 to $2.50/$10, and the GPT-5
|
||||
// family context windows split between 1.05M for full-size
|
||||
// models and 272K for mini/nano/codex variants.
|
||||
Models: []Model{
|
||||
{ID: "gpt-5.5", Label: "GPT-5.5", InputPer1k: 0.005, OutputPer1k: 0.030, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.5-pro", Label: "GPT-5.5 Pro", InputPer1k: 0.030, OutputPer1k: 0.180, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4", Label: "GPT-5.4", InputPer1k: 0.0025, OutputPer1k: 0.015, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4-pro", Label: "GPT-5.4 Pro", InputPer1k: 0.030, OutputPer1k: 0.180, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4-mini", Label: "GPT-5.4 Mini", InputPer1k: 0.00075, OutputPer1k: 0.0045, ContextWindow: 272000},
|
||||
{ID: "gpt-5.4-nano", Label: "GPT-5.4 Nano", InputPer1k: 0.0002, OutputPer1k: 0.00125, ContextWindow: 272000},
|
||||
{ID: "gpt-5.3-codex", Label: "GPT-5.3 Codex", InputPer1k: 0.00175, OutputPer1k: 0.014, ContextWindow: 272000},
|
||||
{ID: "gpt-5.3-chat-latest", Label: "GPT-5.3 Chat", InputPer1k: 0.00175, OutputPer1k: 0.014, ContextWindow: 128000},
|
||||
{ID: "o4-mini", Label: "o4-mini", InputPer1k: 0.0011, OutputPer1k: 0.0044, ContextWindow: 200000},
|
||||
{ID: "gpt-4.1", Label: "GPT-4.1", InputPer1k: 0.002, OutputPer1k: 0.008, ContextWindow: 1047576},
|
||||
{ID: "gpt-4.1-mini", Label: "GPT-4.1 mini", InputPer1k: 0.0004, OutputPer1k: 0.0016, ContextWindow: 1047576},
|
||||
{ID: "gpt-4.1-nano", Label: "GPT-4.1 nano", InputPer1k: 0.0001, OutputPer1k: 0.0004, ContextWindow: 1047576},
|
||||
{ID: "gpt-4o", Label: "GPT-4o", InputPer1k: 0.0025, OutputPer1k: 0.010, ContextWindow: 128000},
|
||||
{ID: "gpt-4o-mini", Label: "GPT-4o mini", InputPer1k: 0.00015, OutputPer1k: 0.0006, ContextWindow: 128000},
|
||||
{ID: "gpt-4-turbo", Label: "GPT-4 Turbo", InputPer1k: 0.01, OutputPer1k: 0.03, ContextWindow: 128000},
|
||||
{ID: "gpt-3.5-turbo", Label: "GPT-3.5 Turbo", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 16385},
|
||||
{ID: "text-embedding-3-large", Label: "text-embedding-3-large", InputPer1k: 0.00013, OutputPer1k: 0, ContextWindow: 8191},
|
||||
{ID: "text-embedding-3-small", Label: "text-embedding-3-small", InputPer1k: 0.00002, OutputPer1k: 0, ContextWindow: 8191},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "anthropic_api",
|
||||
Kind: KindProvider,
|
||||
Name: "Anthropic API",
|
||||
Description: "Claude Messages API",
|
||||
DefaultHost: "api.anthropic.com",
|
||||
AuthHeaderName: "x-api-key",
|
||||
AuthHeaderTemplate: "${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#D97757",
|
||||
ParserID: "anthropic",
|
||||
// Per Anthropic's current model lineup. Pricing in USD per 1k
|
||||
// tokens. Context windows: 4.6+ family is 1M; Haiku 4.5 stays at
|
||||
// 200K. claude-3-7-sonnet and claude-3-5-haiku retired
|
||||
// 2026-02-19 — dropped from the catalog. claude-opus-4-1
|
||||
// deprecated, retires 2026-08-05 — kept until the cutover.
|
||||
// claude-mythos-5 omitted: Project Glasswing access only, not a
|
||||
// general-availability target. claude-fable-5 requires the
|
||||
// account to be on >= 30-day data retention or all requests
|
||||
// 400.
|
||||
Models: []Model{
|
||||
{ID: "claude-fable-5", Label: "Claude Fable 5", InputPer1k: 0.010, OutputPer1k: 0.050, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-8", Label: "Claude Opus 4.8", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-7", Label: "Claude Opus 4.7", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-6", Label: "Claude Opus 4.6", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-1", Label: "Claude Opus 4.1 (deprecated, retires 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000},
|
||||
{ID: "claude-sonnet-4-6", Label: "Claude Sonnet 4.6", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000},
|
||||
{ID: "claude-sonnet-4-5", Label: "Claude Sonnet 4.5", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000},
|
||||
{ID: "claude-haiku-4-5", Label: "Claude Haiku 4.5", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "azure_openai_api",
|
||||
Kind: KindProvider,
|
||||
Name: "Azure OpenAI API",
|
||||
Description: "Azure-hosted OpenAI deployments",
|
||||
DefaultHost: "<resource>.openai.azure.com",
|
||||
AuthHeaderName: "api-key",
|
||||
AuthHeaderTemplate: "${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#0078D4",
|
||||
ParserID: "openai",
|
||||
// Mirrors openai_api pricing — Azure resells OpenAI models at the
|
||||
// same per-token rates, just under different deployment names.
|
||||
Models: []Model{
|
||||
{ID: "gpt-5.5", Label: "GPT-5.5 (Azure)", InputPer1k: 0.005, OutputPer1k: 0.030, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4", Label: "GPT-5.4 (Azure)", InputPer1k: 0.0025, OutputPer1k: 0.015, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4-mini", Label: "GPT-5.4 Mini (Azure)", InputPer1k: 0.00075, OutputPer1k: 0.0045, ContextWindow: 272000},
|
||||
{ID: "gpt-5.4-nano", Label: "GPT-5.4 Nano (Azure)", InputPer1k: 0.0002, OutputPer1k: 0.00125, ContextWindow: 272000},
|
||||
{ID: "o4-mini", Label: "o4-mini (Azure)", InputPer1k: 0.0011, OutputPer1k: 0.0044, ContextWindow: 200000},
|
||||
{ID: "gpt-4.1", Label: "GPT-4.1 (Azure)", InputPer1k: 0.002, OutputPer1k: 0.008, ContextWindow: 1047576},
|
||||
{ID: "gpt-4.1-mini", Label: "GPT-4.1 mini (Azure)", InputPer1k: 0.0004, OutputPer1k: 0.0016, ContextWindow: 1047576},
|
||||
{ID: "gpt-4o", Label: "GPT-4o (Azure)", InputPer1k: 0.0025, OutputPer1k: 0.010, ContextWindow: 128000},
|
||||
{ID: "gpt-4o-mini", Label: "GPT-4o mini (Azure)", InputPer1k: 0.00015, OutputPer1k: 0.0006, ContextWindow: 128000},
|
||||
{ID: "gpt-35-turbo", Label: "GPT-3.5 Turbo (Azure)", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 16385},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "bedrock_api",
|
||||
Kind: KindProvider,
|
||||
Name: "AWS Bedrock API",
|
||||
Description: "Anthropic, Meta, Cohere via Bedrock",
|
||||
DefaultHost: "bedrock-runtime.<region>.amazonaws.com",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#FF9900",
|
||||
// Anthropic models on Bedrock take the anthropic.* prefix and
|
||||
// follow the same lineup / pricing as the first-party Anthropic
|
||||
// catalog entry above. claude-3-7-sonnet and claude-3-5-haiku
|
||||
// were retired upstream on 2026-02-19 — dropped from the
|
||||
// Bedrock list too. Amazon Nova entries cross-checked against
|
||||
// LiteLLM (added Nova Micro + the new Nova 2 Lite preview).
|
||||
// Llama 3.3 70B entry kept unchanged — LiteLLM tracks only
|
||||
// per-region Llama 3 entries; standalone 3.3 not yet listed.
|
||||
Models: []Model{
|
||||
{ID: "anthropic.claude-opus-4-8", Label: "Claude Opus 4.8 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "anthropic.claude-opus-4-7", Label: "Claude Opus 4.7 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "anthropic.claude-opus-4-6", Label: "Claude Opus 4.6 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "anthropic.claude-opus-4-1", Label: "Claude Opus 4.1 (Bedrock, deprecated 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000},
|
||||
{ID: "anthropic.claude-sonnet-4-6", Label: "Claude Sonnet 4.6 (Bedrock)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000},
|
||||
{ID: "anthropic.claude-sonnet-4-5", Label: "Claude Sonnet 4.5 (Bedrock)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000},
|
||||
{ID: "anthropic.claude-haiku-4-5", Label: "Claude Haiku 4.5 (Bedrock)", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000},
|
||||
{ID: "meta.llama3-3-70b-instruct", Label: "Llama 3.3 70B (Bedrock)", InputPer1k: 0.00072, OutputPer1k: 0.00072, ContextWindow: 128000},
|
||||
{ID: "amazon.nova-2-lite", Label: "Amazon Nova 2 Lite (Bedrock, preview)", InputPer1k: 0.0003, OutputPer1k: 0.0025, ContextWindow: 1000000},
|
||||
{ID: "amazon.nova-pro", Label: "Amazon Nova Pro (Bedrock)", InputPer1k: 0.0008, OutputPer1k: 0.0032, ContextWindow: 300000},
|
||||
{ID: "amazon.nova-lite", Label: "Amazon Nova Lite (Bedrock)", InputPer1k: 0.00006, OutputPer1k: 0.00024, ContextWindow: 300000},
|
||||
{ID: "amazon.nova-micro", Label: "Amazon Nova Micro (Bedrock)", InputPer1k: 0.000035, OutputPer1k: 0.00014, ContextWindow: 128000},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "vertex_ai_api",
|
||||
Kind: KindProvider,
|
||||
Name: "Google Vertex AI API",
|
||||
Description: "Anthropic Claude models hosted on Vertex AI",
|
||||
DefaultHost: "<region>-aiplatform.googleapis.com",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#4285F4",
|
||||
// Vertex carries the model in the URL path and authenticates with a
|
||||
// service-account-minted OAuth token (api_key = "keyfile::<base64 SA>").
|
||||
// Only Anthropic-on-Vertex is metered today: the request parser maps the
|
||||
// anthropic publisher to the Anthropic parser, so the lineup + prices
|
||||
// mirror the first-party Anthropic catalog (LiteLLM vertex_ai/claude-*
|
||||
// confirms the same per-token rates; cross-region profiles in eu/apac
|
||||
// carry a ~10% premium that base pricing does not model). Gemini (the
|
||||
// google publisher) is intentionally omitted until a Gemini parser
|
||||
// exists — the router denies unmeterable publishers rather than forward
|
||||
// them uncounted.
|
||||
Models: []Model{
|
||||
{ID: "claude-fable-5", Label: "Claude Fable 5 (Vertex)", InputPer1k: 0.010, OutputPer1k: 0.050, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-8", Label: "Claude Opus 4.8 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-7", Label: "Claude Opus 4.7 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-6", Label: "Claude Opus 4.6 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-1", Label: "Claude Opus 4.1 (Vertex, deprecated 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000},
|
||||
{ID: "claude-sonnet-4-6", Label: "Claude Sonnet 4.6 (Vertex)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000},
|
||||
{ID: "claude-sonnet-4-5", Label: "Claude Sonnet 4.5 (Vertex)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000},
|
||||
{ID: "claude-haiku-4-5", Label: "Claude Haiku 4.5 (Vertex)", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "mistral_api",
|
||||
Kind: KindProvider,
|
||||
Name: "Mistral API",
|
||||
Description: "Mistral cloud API",
|
||||
DefaultHost: "api.mistral.ai",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#FF7000",
|
||||
ParserID: "openai",
|
||||
// Pricing + context windows cross-checked against LiteLLM. Key
|
||||
// gotchas the marketing page hides:
|
||||
// - `mistral-medium-latest` aliases to Medium 3.1 ($0.40/$2),
|
||||
// NOT Medium 3.5 ($1.50/$7.50). Catalog exposes both.
|
||||
// - `mistral-large-latest` aliases to Large 3 — 262K context,
|
||||
// cheaper than Medium 3.5.
|
||||
// - Magistral models are tuned for reasoning but cap context
|
||||
// at only 40K (vs 128K-262K elsewhere).
|
||||
// - `codestral-latest` still routes to the old 2405 build
|
||||
// ($1/$3) per LiteLLM; the newer codestral-2508 is both
|
||||
// cheaper and longer-context. Both exposed.
|
||||
// - Pixtral was folded into the main Large/Medium series; no
|
||||
// standalone vision entry.
|
||||
Models: []Model{
|
||||
{ID: "mistral-large-latest", Label: "Mistral Large 3", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 262144},
|
||||
{ID: "mistral-medium-latest", Label: "Mistral Medium 3.1", InputPer1k: 0.0004, OutputPer1k: 0.002, ContextWindow: 131072},
|
||||
{ID: "mistral-medium-3-5", Label: "Mistral Medium 3.5", InputPer1k: 0.0015, OutputPer1k: 0.0075, ContextWindow: 262144},
|
||||
{ID: "mistral-small-latest", Label: "Mistral Small 3.2", InputPer1k: 0.00006, OutputPer1k: 0.00018, ContextWindow: 131072},
|
||||
{ID: "magistral-medium-latest", Label: "Magistral Medium (reasoning)", InputPer1k: 0.002, OutputPer1k: 0.005, ContextWindow: 40000},
|
||||
{ID: "magistral-small-latest", Label: "Magistral Small (reasoning)", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 40000},
|
||||
{ID: "devstral-medium-latest", Label: "Devstral Medium 2 (coding)", InputPer1k: 0.0004, OutputPer1k: 0.002, ContextWindow: 256000},
|
||||
{ID: "devstral-small-latest", Label: "Devstral Small 2 (coding)", InputPer1k: 0.0001, OutputPer1k: 0.0003, ContextWindow: 256000},
|
||||
{ID: "codestral-2508", Label: "Codestral 2508", InputPer1k: 0.0003, OutputPer1k: 0.0009, ContextWindow: 256000},
|
||||
{ID: "codestral-latest", Label: "Codestral (legacy 2405)", InputPer1k: 0.001, OutputPer1k: 0.003, ContextWindow: 32000},
|
||||
{ID: "ministral-3-14b-2512", Label: "Ministral 3 14B", InputPer1k: 0.0002, OutputPer1k: 0.0002, ContextWindow: 262144},
|
||||
{ID: "ministral-8b-latest", Label: "Ministral 8B", InputPer1k: 0.00015, OutputPer1k: 0.00015, ContextWindow: 262144},
|
||||
{ID: "ministral-3-3b-2512", Label: "Ministral 3 3B", InputPer1k: 0.0001, OutputPer1k: 0.0001, ContextWindow: 131072},
|
||||
{ID: "mistral-embed", Label: "Mistral Embed", InputPer1k: 0.0001, OutputPer1k: 0, ContextWindow: 8192},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "litellm_proxy",
|
||||
Kind: KindGateway,
|
||||
Name: "LiteLLM Proxy",
|
||||
Description: "Bring your own LiteLLM proxy with NetBird identity stamped on every request",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#0EA5E9",
|
||||
ParserID: "openai",
|
||||
// IdentityInjection requires a LiteLLM virtual key minted with
|
||||
// metadata.allow_client_tags=true; the master key silently drops
|
||||
// caller tags. Tags go out via both the x-litellm-tags header and
|
||||
// body metadata.tags: LiteLLM enforces budgets from the body only,
|
||||
// so the header is the spend-tracking fallback when body injection
|
||||
// can't run. See the Agent Network provider docs for key setup.
|
||||
IdentityInjection: &IdentityInjection{
|
||||
HeaderPair: &HeaderPairInjection{
|
||||
EndUserIDHeader: "x-litellm-end-user-id",
|
||||
TagsHeader: "x-litellm-tags",
|
||||
TagsInBody: true,
|
||||
EndUserIDInBody: true,
|
||||
},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "portkey",
|
||||
Kind: KindGateway,
|
||||
Name: "Portkey AI Gateway",
|
||||
Description: "Portkey AI Gateway with NetBird identity stamped via x-portkey-metadata",
|
||||
DefaultHost: "api.portkey.ai",
|
||||
// Portkey hosted requires x-portkey-api-key (account key)
|
||||
// plus a routing decision per request. The simplest routing
|
||||
// path is a saved Portkey config id stamped via
|
||||
// x-portkey-config — operators paste the pc-... id once and
|
||||
// Portkey resolves the upstream provider + virtual key from
|
||||
// it. ExtraHeaders below surfaces the input. Alternative:
|
||||
// callers author "@org/model" in the body; both flows
|
||||
// coexist (per-request authoring still works without a
|
||||
// configured value).
|
||||
AuthHeaderName: "x-portkey-api-key",
|
||||
AuthHeaderTemplate: "${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#FF5C00",
|
||||
ParserID: "openai",
|
||||
IdentityInjection: &IdentityInjection{
|
||||
JSONMetadata: &JSONMetadataInjection{
|
||||
Header: "x-portkey-metadata",
|
||||
UserKey: "_user",
|
||||
GroupsKey: "groups",
|
||||
MaxValueLength: 128,
|
||||
},
|
||||
},
|
||||
ExtraHeaders: []ExtraHeader{
|
||||
{Name: "x-portkey-config"},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "bifrost",
|
||||
Kind: KindGateway,
|
||||
Name: "Bifrost",
|
||||
Description: "Maxim AI's Bifrost gateway. Point upstream URL at /openai/v1 or /anthropic/v1 on your Bifrost host depending on which body shape your apps use.",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#7C3AED",
|
||||
// ParserID empty: the proxy's request parser sniffs the URL
|
||||
// path. Bifrost's /openai/v1/... contains "/v1/chat/completions"
|
||||
// (matches OpenAIParser.DetectFromURL); /anthropic/v1/messages
|
||||
// contains "/v1/messages" (matches AnthropicParser). Operators
|
||||
// who paste a different prefix get no usage parsing and the
|
||||
// cost meter skips with skipMissingProvider — degraded but
|
||||
// non-fatal.
|
||||
ParserID: "",
|
||||
// Identity-injection headers are operator-customisable. The
|
||||
// HeaderPair values below are PLACEHOLDERS surfaced by the
|
||||
// dashboard; the actual values stamped on the wire come from
|
||||
// the provider record's IdentityHeaderUserID /
|
||||
// IdentityHeaderGroups fields. An empty operator value
|
||||
// disables stamping for that dimension (the inject middleware
|
||||
// already no-ops on empty header names). Defaulting to the
|
||||
// x-bf-dim- family so the values land in Bifrost's
|
||||
// Prometheus/OTEL pipelines when the operator declares the
|
||||
// label names in their client.prometheus_labels config — see
|
||||
// docs.getbifrost.ai/features/telemetry. Operators who use
|
||||
// the always-on x-bf-lh- log-metadata family (no Bifrost-side
|
||||
// declaration required) just edit the inputs.
|
||||
//
|
||||
// Bifrost virtual keys (sk-bf-*) ride Authorization: Bearer.
|
||||
// Operators provision the VK on their Bifrost (UI /
|
||||
// config.json / POST /api/governance/virtual-keys) and paste
|
||||
// the returned sk-bf-... as ${API_KEY}. Pin v1.4+ to avoid
|
||||
// the v1.3.0 x-bf-vk regression (maximhq/bifrost#632).
|
||||
IdentityInjection: &IdentityInjection{
|
||||
HeaderPair: &HeaderPairInjection{
|
||||
EndUserIDHeader: "x-bf-dim-netbird_user_id",
|
||||
TagsHeader: "x-bf-dim-netbird_groups",
|
||||
Customizable: true,
|
||||
},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "cloudflare_ai_gateway",
|
||||
Kind: KindGateway,
|
||||
Name: "Cloudflare AI Gateway",
|
||||
Description: "Cloudflare AI Gateway. Operator pastes the gateway URL (with the upstream provider slug like /openai or /anthropic so the URL sniffer dispatches to the right parser) and a per-gateway authentication token. Recommended setup is BYOK / Stored Keys: Cloudflare manages the upstream provider credential and the gateway token is the only secret NetBird needs.",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "cf-aig-authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#F38020",
|
||||
// ParserID empty: like Bifrost, the proxy's parser-detect
|
||||
// sniffs the URL path. /openai/... contains the OpenAI hint
|
||||
// substrings; /anthropic/v1/messages contains /v1/messages
|
||||
// (matches AnthropicParser). The /compat universal endpoint
|
||||
// also speaks OpenAI shape so OpenAIParser handles it.
|
||||
// Operators who paste a different prefix degrade to no-cost
|
||||
// (skipMissingProvider) but the request still flows.
|
||||
ParserID: "",
|
||||
// cf-aig-metadata is a single header carrying a JSON object;
|
||||
// up to five string/number/boolean values per request. NetBird
|
||||
// occupies two slots (user id + groups CSV) and leaves three
|
||||
// for operator-added context. JSON keys are operator-
|
||||
// customisable so Cloudflare-side log filters can use the
|
||||
// operator's existing label conventions instead of NetBird's
|
||||
// defaults — hence Customizable=true. The dashboard surfaces
|
||||
// the catalog values as placeholders; only the values stored
|
||||
// on the provider record's IdentityHeader* fields land on the
|
||||
// wire (empty operator value = key is omitted from the JSON,
|
||||
// since applyJSONMetadata already skips empty keys).
|
||||
IdentityInjection: &IdentityInjection{
|
||||
JSONMetadata: &JSONMetadataInjection{
|
||||
Header: "cf-aig-metadata",
|
||||
UserKey: "netbird_user_id",
|
||||
GroupsKey: "netbird_groups",
|
||||
Customizable: true,
|
||||
// Cloudflare's docs don't specify a per-value cap;
|
||||
// leaving 0 disables the truncate path. Header-level
|
||||
// constraint is "5 entries max" rather than length.
|
||||
MaxValueLength: 0,
|
||||
},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "vercel_ai_gateway",
|
||||
Kind: KindGateway,
|
||||
Name: "Vercel AI Gateway",
|
||||
Description: "Vercel's unified API for hundreds of models. Single endpoint, OpenAI-compatible body, model dispatch via prefix (openai/..., anthropic/..., google/..., xai/...). Per-user / per-tag attribution lands in Vercel's Custom Reporting API and observability dashboard.",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#000000",
|
||||
// Vercel always speaks OpenAI shape on /v1/chat/completions —
|
||||
// the model prefix in the body picks the upstream provider.
|
||||
// No URL sniffing needed; pin the parser directly.
|
||||
ParserID: "openai",
|
||||
// HeaderPair shape with fixed wire names dictated by Vercel's
|
||||
// Custom Reporting API contract. Customizable=false because
|
||||
// renaming the headers makes Vercel silently stop attributing
|
||||
// — the gateway's reporting endpoint only matches its own
|
||||
// header names. Same fixed-protocol position as LiteLLM.
|
||||
//
|
||||
// Caveats operators should know:
|
||||
// - up to 10 tags total per request (deduped); 11+ → HTTP 400
|
||||
// - each tag must be 1-64 chars
|
||||
// - user up to 256 chars (NetBird user emails fit)
|
||||
// - $0.075 per 1k unique user/tag values written
|
||||
// We don't enforce the caps in the inject middleware today;
|
||||
// operators in groups beyond the 10-tag limit will see Vercel
|
||||
// 400s and need to re-scope their group memberships.
|
||||
IdentityInjection: &IdentityInjection{
|
||||
HeaderPair: &HeaderPairInjection{
|
||||
EndUserIDHeader: "ai-reporting-user",
|
||||
TagsHeader: "ai-reporting-tags",
|
||||
},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "openrouter",
|
||||
Kind: KindGateway,
|
||||
Name: "OpenRouter",
|
||||
Description: "OpenRouter's unified API for hundreds of models. Single endpoint at openrouter.ai/api/v1, OpenAI-compatible body, model dispatch via prefix (anthropic/claude-..., openai/gpt-..., google/gemini-..., etc.). Per-user attribution lands in OpenRouter's analytics via the OpenAI-standard `user` body field; OpenRouter has no groups / tags dimension at request time.",
|
||||
DefaultHost: "openrouter.ai/api/v1",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#6F4FF2",
|
||||
// OpenRouter is single-endpoint OpenAI-shape on /api/v1/chat/completions —
|
||||
// model prefix in the body picks the upstream provider.
|
||||
// Pinning the parser saves URL sniffing.
|
||||
ParserID: "openai",
|
||||
// HeaderPair shape with EndUserIDInBody as the only active
|
||||
// dimension. OpenRouter's per-user attribution is the
|
||||
// OpenAI-standard `user` body field, not a header — and
|
||||
// OpenRouter offers no per-request groups / tags dimension at
|
||||
// all. Customizable=false because the field name is locked by
|
||||
// OpenAI's spec; renaming would just defeat the inject.
|
||||
IdentityInjection: &IdentityInjection{
|
||||
HeaderPair: &HeaderPairInjection{
|
||||
EndUserIDInBody: true,
|
||||
},
|
||||
},
|
||||
// HTTP-Referer + X-OpenRouter-Title surface in OpenRouter's
|
||||
// app rankings and per-app analytics. Operators paste their
|
||||
// own app URL + display name on the provider record so their
|
||||
// requests show under their brand instead of "no app". Both
|
||||
// are static per-deployment, not per-request, hence the
|
||||
// ExtraHeaders mechanism (operator-typed value, stamped on
|
||||
// every request to this provider). Skip X-OpenRouter-Categories
|
||||
// for now — the marketplace-categories dimension is
|
||||
// niche-enough that we'd add it on demand.
|
||||
ExtraHeaders: []ExtraHeader{
|
||||
{Name: "HTTP-Referer"},
|
||||
{Name: "X-OpenRouter-Title"},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "custom",
|
||||
Kind: KindCustom,
|
||||
Name: "Custom / Self-hosted",
|
||||
Description: "OpenAI-compatible endpoint (vLLM, Ollama, …)",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#9CA3AF",
|
||||
Models: []Model{},
|
||||
},
|
||||
}
|
||||
|
||||
// All returns a copy of the full catalog.
|
||||
func All() []Provider {
|
||||
out := make([]Provider, len(providers))
|
||||
copy(out, providers)
|
||||
return out
|
||||
}
|
||||
|
||||
// Lookup returns the catalog entry with the given id, if any.
|
||||
func Lookup(id string) (Provider, bool) {
|
||||
for _, p := range providers {
|
||||
if p.ID == id {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
return Provider{}, false
|
||||
}
|
||||
|
||||
// IsKnown reports whether the given id refers to a catalog entry.
|
||||
func IsKnown(id string) bool {
|
||||
_, ok := Lookup(id)
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsVertexPathStyle reports whether a provider uses the Google Vertex AI
|
||||
// request shape — the model is carried in the URL path
|
||||
// (/v1/projects/{p}/locations/{r}/publishers/{pub}/models/{model}:{action})
|
||||
// rather than the body, so the proxy routes it by path instead of by model.
|
||||
func IsVertexPathStyle(providerID string) bool {
|
||||
return providerID == "vertex_ai_api"
|
||||
}
|
||||
|
||||
// IsBedrockPathStyle reports whether a provider uses the AWS Bedrock request
|
||||
// shape — the model is carried in the URL path (/model/{modelId}/{action},
|
||||
// action being invoke, invoke-with-response-stream, converse, or
|
||||
// converse-stream) rather than the body, so the proxy routes it by path.
|
||||
func IsBedrockPathStyle(providerID string) bool {
|
||||
return providerID == "bedrock_api"
|
||||
}
|
||||
|
||||
// ToAPIResponse renders a catalog provider as the API representation.
|
||||
func (p Provider) ToAPIResponse() api.AgentNetworkCatalogProvider {
|
||||
models := make([]api.AgentNetworkCatalogModel, 0, len(p.Models))
|
||||
for _, m := range p.Models {
|
||||
models = append(models, api.AgentNetworkCatalogModel{
|
||||
Id: m.ID,
|
||||
Label: m.Label,
|
||||
InputPer1k: m.InputPer1k,
|
||||
OutputPer1k: m.OutputPer1k,
|
||||
ContextWindow: m.ContextWindow,
|
||||
})
|
||||
}
|
||||
kind := api.AgentNetworkCatalogProviderKindProvider
|
||||
switch p.Kind {
|
||||
case KindGateway:
|
||||
kind = api.AgentNetworkCatalogProviderKindGateway
|
||||
case KindCustom:
|
||||
kind = api.AgentNetworkCatalogProviderKindCustom
|
||||
}
|
||||
resp := api.AgentNetworkCatalogProvider{
|
||||
Id: p.ID,
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
DefaultHost: p.DefaultHost,
|
||||
Kind: kind,
|
||||
AuthHeaderTemplate: p.AuthHeaderTemplate,
|
||||
DefaultContentType: p.DefaultContentType,
|
||||
BrandColor: p.BrandColor,
|
||||
Models: models,
|
||||
}
|
||||
if len(p.ExtraHeaders) > 0 {
|
||||
extras := make([]api.AgentNetworkCatalogExtraHeader, 0, len(p.ExtraHeaders))
|
||||
for _, h := range p.ExtraHeaders {
|
||||
extras = append(extras, api.AgentNetworkCatalogExtraHeader{
|
||||
Name: h.Name,
|
||||
})
|
||||
}
|
||||
resp.ExtraHeaders = &extras
|
||||
}
|
||||
// Surface IdentityInjection so the dashboard can decide whether
|
||||
// to render editable inputs vs. a read-only mappings strip per
|
||||
// shape's customizable flag. HeaderPair (Bifrost) and
|
||||
// JSONMetadata (Cloudflare, Portkey) are mutually exclusive on a
|
||||
// given catalog entry; emit whichever shape is set.
|
||||
if p.IdentityInjection != nil {
|
||||
injection := &api.AgentNetworkCatalogIdentityInjection{}
|
||||
if hp := p.IdentityInjection.HeaderPair; hp != nil {
|
||||
injection.HeaderPair = &api.AgentNetworkCatalogHeaderPairInjection{
|
||||
Customizable: hp.Customizable,
|
||||
EndUserIdHeader: hp.EndUserIDHeader,
|
||||
TagsHeader: hp.TagsHeader,
|
||||
}
|
||||
}
|
||||
if jm := p.IdentityInjection.JSONMetadata; jm != nil {
|
||||
injection.JsonMetadata = &api.AgentNetworkCatalogJSONMetadataInjection{
|
||||
Customizable: jm.Customizable,
|
||||
Header: jm.Header,
|
||||
UserKey: jm.UserKey,
|
||||
GroupsKey: jm.GroupsKey,
|
||||
}
|
||||
}
|
||||
if injection.HeaderPair != nil || injection.JsonMetadata != nil {
|
||||
resp.IdentityInjection = injection
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
// addAccessLogEndpoints registers the read-only, server-side-filtered
|
||||
// agent-network access-log listing and the aggregated usage overview.
|
||||
func (h *handler) addAccessLogEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/access-logs", h.listAccessLogs).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/usage/overview", h.getUsageOverview).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getUsageOverview(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
// Reuse the access-log filter for the shared date/user/group/provider/model
|
||||
// params; pagination/sort/search are irrelevant for an aggregate.
|
||||
var filter types.AgentNetworkAccessLogFilter
|
||||
if err := filter.ParseFromRequest(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
// Bound the aggregation window so an unbounded or over-wide query can't load
|
||||
// an account's entire usage history into memory.
|
||||
filter.ApplyUsageOverviewBounds(time.Now())
|
||||
granularity := types.ParseUsageGranularity(r.URL.Query().Get("granularity"))
|
||||
|
||||
buckets, err := h.manager.GetUsageOverview(r.Context(), userAuth.AccountId, userAuth.UserId, filter, granularity)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]api.AgentNetworkUsageBucket, 0, len(buckets))
|
||||
for _, b := range buckets {
|
||||
out = append(out, b.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) listAccessLogs(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var filter types.AgentNetworkAccessLogFilter
|
||||
if err := filter.ParseFromRequest(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rows, total, err := h.manager.ListAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId, filter)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
data := make([]api.AgentNetworkAccessLog, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
data = append(data, row.ToAPIResponse())
|
||||
}
|
||||
|
||||
pageSize := filter.GetLimit()
|
||||
totalPages := 0
|
||||
if pageSize > 0 {
|
||||
totalPages = int((total + int64(pageSize) - 1) / int64(pageSize))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, api.AgentNetworkAccessLogsResponse{
|
||||
Data: data,
|
||||
Page: filter.Page,
|
||||
PageSize: pageSize,
|
||||
TotalRecords: int(total),
|
||||
TotalPages: totalPages,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// addBudgetRuleEndpoints registers the account-level budget rule routes.
|
||||
func (h *handler) addBudgetRuleEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/budget-rules", h.getAllBudgetRules).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/budget-rules", h.createBudgetRule).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.getBudgetRule).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.updateBudgetRule).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.deleteBudgetRule).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllBudgetRules(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rules, err := h.manager.GetAllBudgetRules(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*api.AgentNetworkBudgetRule, 0, len(rules))
|
||||
for _, rule := range rules {
|
||||
out = append(out, rule.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getBudgetRule(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ruleID := mux.Vars(r)["ruleId"]
|
||||
if ruleID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
rule, err := h.manager.GetBudgetRule(r.Context(), userAuth.AccountId, userAuth.UserId, ruleID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, rule.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) createBudgetRule(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkBudgetRuleRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateBudgetRule(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rule := types.NewAccountBudgetRule(userAuth.AccountId)
|
||||
rule.FromAPIRequest(&req)
|
||||
|
||||
created, err := h.manager.CreateBudgetRule(r.Context(), userAuth.UserId, rule)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateBudgetRule(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ruleID := mux.Vars(r)["ruleId"]
|
||||
if ruleID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkBudgetRuleRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateBudgetRule(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rule := &types.AccountBudgetRule{ID: ruleID, AccountID: userAuth.AccountId}
|
||||
rule.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdateBudgetRule(r.Context(), userAuth.UserId, rule)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteBudgetRule(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ruleID := mux.Vars(r)["ruleId"]
|
||||
if ruleID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteBudgetRule(r.Context(), userAuth.AccountId, userAuth.UserId, ruleID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// validateBudgetRule rejects malformed budget rules. It reuses the policy limit
|
||||
// validation since the cap shape is identical, and rejects empty target entries.
|
||||
func validateBudgetRule(req *api.AgentNetworkBudgetRuleRequest) error {
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
if req.TargetGroups != nil {
|
||||
for _, id := range *req.TargetGroups {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "target_groups must not contain empty entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.TargetUsers != nil {
|
||||
for _, id := range *req.TargetUsers {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "target_users must not contain empty entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
return validatePolicyLimits(req.Limits)
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// TestBudgetRuleHandler_RoundTrip seeds a budget rule via the store and asserts
|
||||
// the GET wire shape carries targets and the reused PolicyLimits cap shape. The
|
||||
// create/update/delete success paths go through accountManager.StoreEvent which
|
||||
// this fixture doesn't wire — they are covered by the manager-level no-mock
|
||||
// test (TestAgentNetwork_BudgetRuleCRUD_RealManager).
|
||||
func TestBudgetRuleHandler_RoundTrip(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
rule := &agentNetworkTypes.AccountBudgetRule{
|
||||
ID: "ainbud_test",
|
||||
AccountID: testAccountID,
|
||||
Name: "org-monthly",
|
||||
Enabled: true,
|
||||
TargetGroups: []string{"grp-eng"},
|
||||
TargetUsers: []string{"user-alice"},
|
||||
Limits: agentNetworkTypes.PolicyLimits{
|
||||
TokenLimit: agentNetworkTypes.PolicyTokenLimit{Enabled: true, GroupCap: 100000, UserCap: 10000, WindowSeconds: 2_592_000},
|
||||
BudgetLimit: agentNetworkTypes.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 500, WindowSeconds: 2_592_000},
|
||||
},
|
||||
}
|
||||
require.NoError(t, f.store.SaveAgentNetworkBudgetRule(context.Background(), rule))
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/budget-rules/"+rule.ID, "")
|
||||
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
|
||||
|
||||
var got api.AgentNetworkBudgetRule
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
assert.Equal(t, "org-monthly", got.Name, "name must round-trip")
|
||||
assert.Equal(t, []string{"grp-eng"}, got.TargetGroups, "target groups must round-trip")
|
||||
assert.Equal(t, []string{"user-alice"}, got.TargetUsers, "target users must round-trip")
|
||||
assert.Equal(t, int64(100000), got.Limits.TokenLimit.GroupCap, "token group cap must round-trip")
|
||||
assert.Equal(t, int64(2_592_000), got.Limits.BudgetLimit.WindowSeconds, "budget window must round-trip")
|
||||
}
|
||||
|
||||
// TestBudgetRuleHandler_ListReturnsArray asserts the list endpoint returns a
|
||||
// JSON array (never null) for an account with no rules.
|
||||
func TestBudgetRuleHandler_ListReturnsArray(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/budget-rules", "")
|
||||
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
|
||||
assert.Equal(t, "[]", trimSpace(rec.Body.String()), "empty account must return an empty array, not null")
|
||||
}
|
||||
|
||||
// TestBudgetRuleHandler_RejectsMissingName covers the validation path (which
|
||||
// runs before the manager call, so it works without a wired accountManager).
|
||||
func TestBudgetRuleHandler_RejectsMissingName(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
body := `{
|
||||
"name": "",
|
||||
"limits": {
|
||||
"token_limit": {"enabled": false, "group_cap": 0, "user_cap": 0, "window_seconds": 0},
|
||||
"budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0}
|
||||
}
|
||||
}`
|
||||
rec := f.do(t, http.MethodPost, "/agent-network/budget-rules", body)
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code,
|
||||
"missing name must be rejected as a validation error (not a route/auth 4xx): got %d body=%s", rec.Code, rec.Body.String())
|
||||
assert.Contains(t, rec.Body.String(), "name",
|
||||
"rejection body must name the offending field, proving the validation path: %s", rec.Body.String())
|
||||
}
|
||||
|
||||
// TestBudgetRuleHandler_RejectsSubMinuteWindow proves budget rules reuse the
|
||||
// policy-limit validation (enabled limit needs window >= 60s).
|
||||
func TestBudgetRuleHandler_RejectsSubMinuteWindow(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
body := `{
|
||||
"name": "bad-window",
|
||||
"limits": {
|
||||
"token_limit": {"enabled": true, "group_cap": 1000, "user_cap": 0, "window_seconds": 30},
|
||||
"budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0}
|
||||
}
|
||||
}`
|
||||
rec := f.do(t, http.MethodPost, "/agent-network/budget-rules", body)
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code,
|
||||
"sub-minute window must be rejected as a validation error (not a route/auth 4xx): got %d body=%s", rec.Code, rec.Body.String())
|
||||
assert.Contains(t, rec.Body.String(), "window_seconds",
|
||||
"rejection body must name the offending window_seconds field, proving the validation path: %s", rec.Body.String())
|
||||
}
|
||||
|
||||
// TestSettingsHandler_GetExposesCollectionToggles asserts the GET settings wire
|
||||
// shape carries the account-level collection toggles after a store seed.
|
||||
func TestSettingsHandler_GetExposesCollectionToggles(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
require.NoError(t, f.store.SaveAgentNetworkSettings(context.Background(), &agentNetworkTypes.Settings{
|
||||
AccountID: testAccountID,
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
Subdomain: "violet",
|
||||
EnableLogCollection: true,
|
||||
EnablePromptCollection: true,
|
||||
RedactPii: false,
|
||||
}))
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/settings", "")
|
||||
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
|
||||
|
||||
var got api.AgentNetworkSettings
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
assert.True(t, got.EnableLogCollection, "log collection toggle must surface on the wire")
|
||||
assert.True(t, got.EnablePromptCollection, "prompt collection toggle must surface on the wire")
|
||||
assert.False(t, got.RedactPii, "redact toggle must surface its false value")
|
||||
assert.Equal(t, "violet.eu.proxy.netbird.io", got.Endpoint, "endpoint stays computed from immutable cluster+subdomain")
|
||||
}
|
||||
|
||||
func trimSpace(s string) string {
|
||||
for len(s) > 0 && (s[len(s)-1] == '\n' || s[len(s)-1] == ' ' || s[len(s)-1] == '\t' || s[len(s)-1] == '\r') {
|
||||
s = s[:len(s)-1]
|
||||
}
|
||||
for len(s) > 0 && (s[0] == '\n' || s[0] == ' ' || s[0] == '\t' || s[0] == '\r') {
|
||||
s = s[1:]
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
// addConsumptionEndpoints registers the read-only Agent Network
|
||||
// consumption listing — backs the dashboard's basic counter view.
|
||||
func (h *handler) addConsumptionEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/consumption", h.listConsumption).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) listConsumption(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.manager.ListConsumption(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]api.AgentNetworkConsumption, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, consumptionToAPI(row))
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func consumptionToAPI(c *types.Consumption) api.AgentNetworkConsumption {
|
||||
windowStart := c.WindowStartUTC
|
||||
updatedAt := c.UpdatedAt
|
||||
return api.AgentNetworkConsumption{
|
||||
DimensionKind: api.AgentNetworkConsumptionDimensionKind(c.DimensionKind),
|
||||
DimensionId: c.DimensionID,
|
||||
WindowSeconds: c.WindowSeconds,
|
||||
WindowStartUtc: windowStart,
|
||||
TokensInput: c.TokensInput,
|
||||
TokensOutput: c.TokensOutput,
|
||||
CostUsd: c.CostUSD,
|
||||
UpdatedAt: &updatedAt,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// addGuardrailEndpoints registers all Agent Network guardrail routes.
|
||||
func (h *handler) addGuardrailEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/guardrails", h.getAllGuardrails).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/guardrails", h.createGuardrail).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.getGuardrail).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.updateGuardrail).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.deleteGuardrail).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllGuardrails(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrails, err := h.manager.GetAllGuardrails(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*api.AgentNetworkGuardrail, 0, len(guardrails))
|
||||
for _, g := range guardrails {
|
||||
out = append(out, g.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getGuardrail(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrailID := mux.Vars(r)["guardrailId"]
|
||||
if guardrailID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrail, err := h.manager.GetGuardrail(r.Context(), userAuth.AccountId, userAuth.UserId, guardrailID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, guardrail.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) createGuardrail(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkGuardrailRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateGuardrail(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrail := types.NewGuardrail(userAuth.AccountId)
|
||||
guardrail.FromAPIRequest(&req)
|
||||
|
||||
created, err := h.manager.CreateGuardrail(r.Context(), userAuth.UserId, guardrail)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateGuardrail(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrailID := mux.Vars(r)["guardrailId"]
|
||||
if guardrailID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkGuardrailRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateGuardrail(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrail := &types.Guardrail{
|
||||
ID: guardrailID,
|
||||
AccountID: userAuth.AccountId,
|
||||
}
|
||||
guardrail.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdateGuardrail(r.Context(), userAuth.UserId, guardrail)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteGuardrail(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrailID := mux.Vars(r)["guardrailId"]
|
||||
if guardrailID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteGuardrail(r.Context(), userAuth.AccountId, userAuth.UserId, guardrailID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func validateGuardrail(req *api.AgentNetworkGuardrailRequest) error {
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
|
||||
c := req.Checks
|
||||
if c.ModelAllowlist.Enabled {
|
||||
for _, id := range c.ModelAllowlist.Models {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "model_allowlist.models must not contain empty entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,256 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountID = "acc-1"
|
||||
testUserID = "user-bob"
|
||||
)
|
||||
|
||||
// agentNetworkHandlerFixture builds a real agentnetwork.Manager with
|
||||
// a sqlite store and an always-allow permissions mock, then exposes
|
||||
// the HTTP handlers via a gorilla router. Tests issue requests
|
||||
// through httptest and assert on the wire shape — the same path the
|
||||
// dashboard exercises.
|
||||
type agentNetworkHandlerFixture struct {
|
||||
store store.Store
|
||||
manager agentnetwork.Manager
|
||||
router *mux.Router
|
||||
}
|
||||
|
||||
func newAgentNetworkHandlerFixture(t *testing.T) *agentNetworkHandlerFixture {
|
||||
t.Helper()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("sqlite store not properly supported on Windows yet")
|
||||
}
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(nbtypes.SqliteStoreEngine))
|
||||
|
||||
st, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
perms := permissions.NewMockManager(ctrl)
|
||||
// Always-allow: the handler tests are about wire shape, not
|
||||
// authz. Authz is covered by the manager's own tests.
|
||||
perms.EXPECT().
|
||||
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(true, context.Background(), nil).
|
||||
AnyTimes()
|
||||
|
||||
manager := agentnetwork.NewManager(st, perms, nil, nil)
|
||||
h := &handler{manager: manager}
|
||||
|
||||
router := mux.NewRouter()
|
||||
h.addPolicyEndpoints(router)
|
||||
h.addConsumptionEndpoints(router)
|
||||
h.addBudgetRuleEndpoints(router)
|
||||
h.addSettingsEndpoints(router)
|
||||
|
||||
return &agentNetworkHandlerFixture{
|
||||
store: st,
|
||||
manager: manager,
|
||||
router: router,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *agentNetworkHandlerFixture) do(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
var reader io.Reader
|
||||
if body != "" {
|
||||
reader = strings.NewReader(body)
|
||||
}
|
||||
req := httptest.NewRequest(method, path, reader)
|
||||
if body != "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
f.router.ServeHTTP(rec, req)
|
||||
return rec
|
||||
}
|
||||
|
||||
// seedProvider persists a minimal provider record so policy create
|
||||
// passes the manager's destination_provider_ids existence check.
|
||||
func (f *agentNetworkHandlerFixture) seedProvider(t *testing.T, id string) {
|
||||
t.Helper()
|
||||
require.NoError(t, f.store.SaveAgentNetworkProvider(context.Background(), &agentNetworkTypes.Provider{
|
||||
ID: id,
|
||||
AccountID: testAccountID,
|
||||
ProviderID: "openai_api",
|
||||
Name: "test-" + id,
|
||||
UpstreamURL: "https://api.openai.com",
|
||||
APIKey: "sk-test",
|
||||
Enabled: true,
|
||||
SessionPrivateKey: "test-priv-key",
|
||||
SessionPublicKey: "test-pub-key",
|
||||
}))
|
||||
}
|
||||
|
||||
// TestPolicyHandler_WindowSecondsRoundTrip ports bash 10 to Go:
|
||||
// assert that a policy with window_seconds on both Token + Budget
|
||||
// halves round-trips through GET unchanged AND that legacy
|
||||
// window_hours / window_days are absent from the JSON response. We
|
||||
// seed the policy directly via the store rather than POST-ing
|
||||
// because the create path goes through the manager's
|
||||
// accountManager.StoreEvent which we don't wire in this fixture; the
|
||||
// on-wire shape is what matters here, and the POST validation path
|
||||
// is covered separately by the RejectsSubMinuteWindow test.
|
||||
func TestPolicyHandler_WindowSecondsRoundTrip(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
policy := &agentNetworkTypes.Policy{
|
||||
ID: "ainpol_test",
|
||||
AccountID: testAccountID,
|
||||
Name: "round-trip",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-engineers"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: agentNetworkTypes.PolicyLimits{
|
||||
TokenLimit: agentNetworkTypes.PolicyTokenLimit{Enabled: true, GroupCap: 10000, UserCap: 5000, WindowSeconds: 86_400},
|
||||
BudgetLimit: agentNetworkTypes.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 10.0, UserCapUsd: 2.5, WindowSeconds: 2_592_000},
|
||||
},
|
||||
}
|
||||
require.NoError(t, f.store.SaveAgentNetworkPolicy(context.Background(), policy))
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/policies/"+policy.ID, "")
|
||||
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
|
||||
|
||||
var got api.AgentNetworkPolicy
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
assert.Equal(t, int64(86_400), got.Limits.TokenLimit.WindowSeconds, "token_limit.window_seconds must round-trip")
|
||||
assert.Equal(t, int64(2_592_000), got.Limits.BudgetLimit.WindowSeconds, "budget_limit.window_seconds must round-trip")
|
||||
|
||||
// Legacy field names must NOT appear in the response — would
|
||||
// signal that the management server is still emitting the old
|
||||
// shape and would fool a v1 dashboard into rendering days/hours.
|
||||
assert.NotContains(t, rec.Body.String(), "window_hours",
|
||||
"legacy window_hours field must be absent from the on-wire response")
|
||||
assert.NotContains(t, rec.Body.String(), "window_days",
|
||||
"legacy window_days field must be absent from the on-wire response")
|
||||
}
|
||||
|
||||
// TestPolicyHandler_RejectsSubMinuteWindow ports bash 20 to Go: an
|
||||
// enabled limit with window_seconds < 60 must surface as a 4xx
|
||||
// because anything finer than per-minute produces an untenable
|
||||
// volume of consumption rows for a feature whose value comes from
|
||||
// per-window cap enforcement.
|
||||
func TestPolicyHandler_RejectsSubMinuteWindow(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
f.seedProvider(t, "prov-1")
|
||||
|
||||
body := `{
|
||||
"name": "sub-minute-window",
|
||||
"enabled": true,
|
||||
"source_groups": ["grp-engineers"],
|
||||
"destination_provider_ids": ["prov-1"],
|
||||
"guardrail_ids": [],
|
||||
"limits": {
|
||||
"token_limit": {"enabled": true, "group_cap": 10000, "user_cap": 5000, "window_seconds": 30},
|
||||
"budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0}
|
||||
}
|
||||
}`
|
||||
rec := f.do(t, http.MethodPost, "/agent-network/policies", body)
|
||||
// 422 specifically (InvalidArgument) proves the window-validation path —
|
||||
// a route miss would be 404 and an auth failure 403, so a generic 4xx
|
||||
// would let those false-pass.
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code,
|
||||
"enabled token_limit with window_seconds<60 must be rejected as a validation error: got %d body=%s", rec.Code, rec.Body.String())
|
||||
assert.Contains(t, rec.Body.String(), "window_seconds",
|
||||
"rejection body must name the offending window_seconds field, proving it's the validation path: %s", rec.Body.String())
|
||||
}
|
||||
|
||||
// TestConsumptionHandler_EmptyAccountReturnsArray ports bash 30 to
|
||||
// Go: GET /agent-network/consumption on a clean account always
|
||||
// returns a JSON array (possibly empty), never a 404 / 500. The
|
||||
// dashboard depends on this shape to render its empty state.
|
||||
func TestConsumptionHandler_EmptyAccountReturnsArray(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/consumption", "")
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var rows []api.AgentNetworkConsumption
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rows),
|
||||
"response must always be a JSON array — even when empty: %s", rec.Body.String())
|
||||
assert.Empty(t, rows)
|
||||
}
|
||||
|
||||
// TestConsumptionHandler_PopulatedAccountListsRows mirrors the
|
||||
// /consumption read after a few RecordConsumption calls. Validates
|
||||
// the wire shape carries every field the dashboard reads (dim_kind,
|
||||
// dim_id, window_seconds, window_start_utc, tokens, cost_usd) and
|
||||
// rows are ordered window-newest-first.
|
||||
func TestConsumptionHandler_PopulatedAccountListsRows(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
require.NoError(t, f.manager.RecordConsumption(
|
||||
context.Background(), testAccountID,
|
||||
agentNetworkTypes.DimensionGroup, "grp-engineers",
|
||||
86_400, 100, 50, 0.0125,
|
||||
))
|
||||
require.NoError(t, f.manager.RecordConsumption(
|
||||
context.Background(), testAccountID,
|
||||
agentNetworkTypes.DimensionUser, testUserID,
|
||||
86_400, 100, 50, 0.0125,
|
||||
))
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/consumption", "")
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var rows []api.AgentNetworkConsumption
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rows))
|
||||
require.Len(t, rows, 2, "two RecordConsumption calls must yield two rows")
|
||||
|
||||
// Index by dim_kind so we can assert the full wire shape of each row,
|
||||
// including the dimension id and the aligned window start the dashboard
|
||||
// keys on. Both rows share totals and window.
|
||||
byKind := make(map[string]api.AgentNetworkConsumption, len(rows))
|
||||
for _, row := range rows {
|
||||
assert.Equal(t, int64(100), row.TokensInput)
|
||||
assert.Equal(t, int64(50), row.TokensOutput)
|
||||
assert.InDelta(t, 0.0125, row.CostUsd, 1e-9)
|
||||
assert.Equal(t, int64(86_400), row.WindowSeconds)
|
||||
assert.False(t, row.WindowStartUtc.IsZero(), "window_start_utc must be set on every row")
|
||||
byKind[string(row.DimensionKind)] = row
|
||||
}
|
||||
|
||||
groupRow, ok := byKind["group"]
|
||||
require.True(t, ok, "group dimension must surface")
|
||||
assert.Equal(t, "grp-engineers", groupRow.DimensionId, "group row must carry the source group id as dimension_id")
|
||||
|
||||
userRow, ok := byKind["user"]
|
||||
require.True(t, ok, "user dimension must surface")
|
||||
assert.Equal(t, testUserID, userRow.DimensionId, "user row must carry the user id as dimension_id")
|
||||
|
||||
// Both rows fall in the same aligned window (same length, recorded
|
||||
// together), so window_start_utc must match across them.
|
||||
assert.Equal(t, groupRow.WindowStartUtc, userRow.WindowStartUtc,
|
||||
"rows recorded in the same window must share the aligned window_start_utc")
|
||||
}
|
||||
@@ -0,0 +1,228 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// minWindowSeconds is the floor enforced on enabled token / budget
|
||||
// limit windows. One minute is short enough for fine-grained burst
|
||||
// control without producing untenable consumption-row volume at scale.
|
||||
const minWindowSeconds int64 = 60
|
||||
|
||||
// addPolicyEndpoints registers all Agent Network policy routes on the
|
||||
// shared handler.
|
||||
func (h *handler) addPolicyEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/policies", h.getAllPolicies).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/policies", h.createPolicy).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/policies/{policyId}", h.getPolicy).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/policies/{policyId}", h.updatePolicy).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/policies/{policyId}", h.deletePolicy).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policies, err := h.manager.GetAllPolicies(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*api.AgentNetworkPolicy, 0, len(policies))
|
||||
for _, p := range policies {
|
||||
out = append(out, p.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policyID := mux.Vars(r)["policyId"]
|
||||
if policyID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
policy, err := h.manager.GetPolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policyID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, policy.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkPolicyRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validatePolicy(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policy := types.NewPolicy(userAuth.AccountId)
|
||||
policy.FromAPIRequest(&req)
|
||||
|
||||
created, err := h.manager.CreatePolicy(r.Context(), userAuth.UserId, policy)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policyID := mux.Vars(r)["policyId"]
|
||||
if policyID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkPolicyRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validatePolicy(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policy := &types.Policy{
|
||||
ID: policyID,
|
||||
AccountID: userAuth.AccountId,
|
||||
}
|
||||
policy.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdatePolicy(r.Context(), userAuth.UserId, policy)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policyID := mux.Vars(r)["policyId"]
|
||||
if policyID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeletePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policyID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func validatePolicy(req *api.AgentNetworkPolicyRequest) error {
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
if len(req.SourceGroups) == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "source_groups must contain at least one group id")
|
||||
}
|
||||
for _, id := range req.SourceGroups {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "source_groups must not contain empty entries")
|
||||
}
|
||||
}
|
||||
if len(req.DestinationProviderIds) == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "destination_provider_ids must contain at least one provider id")
|
||||
}
|
||||
for _, id := range req.DestinationProviderIds {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "destination_provider_ids must not contain empty entries")
|
||||
}
|
||||
}
|
||||
if req.GuardrailIds != nil {
|
||||
for _, id := range *req.GuardrailIds {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "guardrail_ids must not contain empty entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.Limits != nil {
|
||||
if err := validatePolicyLimits(*req.Limits); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePolicyLimits(l api.AgentNetworkPolicyLimits) error {
|
||||
if l.TokenLimit.Enabled {
|
||||
if l.TokenLimit.WindowSeconds < minWindowSeconds {
|
||||
return status.Errorf(status.InvalidArgument, "limits.token_limit.window_seconds must be at least %d (one minute) when enabled", minWindowSeconds)
|
||||
}
|
||||
if l.TokenLimit.GroupCap < 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.token_limit.group_cap must not be negative")
|
||||
}
|
||||
if l.TokenLimit.UserCap < 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.token_limit.user_cap must not be negative")
|
||||
}
|
||||
if l.TokenLimit.GroupCap == 0 && l.TokenLimit.UserCap == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.token_limit requires group_cap or user_cap to be greater than zero when enabled")
|
||||
}
|
||||
}
|
||||
if l.BudgetLimit.Enabled {
|
||||
if l.BudgetLimit.WindowSeconds < minWindowSeconds {
|
||||
return status.Errorf(status.InvalidArgument, "limits.budget_limit.window_seconds must be at least %d (one minute) when enabled", minWindowSeconds)
|
||||
}
|
||||
if l.BudgetLimit.GroupCapUsd < 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.budget_limit.group_cap_usd must not be negative")
|
||||
}
|
||||
if l.BudgetLimit.UserCapUsd < 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.budget_limit.user_cap_usd must not be negative")
|
||||
}
|
||||
if l.BudgetLimit.GroupCapUsd == 0 && l.BudgetLimit.UserCapUsd == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.budget_limit requires group_cap_usd or user_cap_usd to be greater than zero when enabled")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
// Package handlers serves the Agent Network HTTP API.
|
||||
//
|
||||
// All persistence is delegated to agentnetwork.Manager so this layer only
|
||||
// translates between the wire format (api.AgentNetworkProvider*) and the
|
||||
// domain types.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/catalog"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager agentnetwork.Manager
|
||||
}
|
||||
|
||||
// RegisterEndpoints registers all Agent Network routes.
|
||||
func RegisterEndpoints(manager agentnetwork.Manager, router *mux.Router) {
|
||||
h := &handler{manager: manager}
|
||||
router.HandleFunc("/agent-network/catalog/providers", h.getCatalogProviders).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers", h.getAllProviders).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers", h.createProvider).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers/{providerId}", h.getProvider).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers/{providerId}", h.updateProvider).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers/{providerId}", h.deleteProvider).Methods("DELETE", "OPTIONS")
|
||||
h.addPolicyEndpoints(router)
|
||||
h.addGuardrailEndpoints(router)
|
||||
h.addSettingsEndpoints(router)
|
||||
h.addConsumptionEndpoints(router)
|
||||
h.addAccessLogEndpoints(router)
|
||||
h.addBudgetRuleEndpoints(router)
|
||||
}
|
||||
|
||||
func (h *handler) getCatalogProviders(w http.ResponseWriter, r *http.Request) {
|
||||
if _, err := nbcontext.GetUserAuthFromContext(r.Context()); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
entries := catalog.All()
|
||||
out := make([]api.AgentNetworkCatalogProvider, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
out = append(out, e.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getAllProviders(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.manager.GetAllProviders(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*api.AgentNetworkProvider, 0, len(providers))
|
||||
for _, p := range providers {
|
||||
out = append(out, p.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
providerID := mux.Vars(r)["providerId"]
|
||||
if providerID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.manager.GetProvider(r.Context(), userAuth.AccountId, userAuth.UserId, providerID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, provider.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) createProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkProviderRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validate(&req, true); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
provider := types.NewProvider(userAuth.AccountId)
|
||||
provider.FromAPIRequest(&req)
|
||||
|
||||
bootstrapCluster := ""
|
||||
if req.BootstrapCluster != nil {
|
||||
bootstrapCluster = *req.BootstrapCluster
|
||||
}
|
||||
|
||||
created, err := h.manager.CreateProvider(r.Context(), userAuth.UserId, provider, bootstrapCluster)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
providerID := mux.Vars(r)["providerId"]
|
||||
if providerID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkProviderRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validate(&req, false); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
provider := &types.Provider{
|
||||
ID: providerID,
|
||||
AccountID: userAuth.AccountId,
|
||||
}
|
||||
provider.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdateProvider(r.Context(), userAuth.UserId, provider)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
providerID := mux.Vars(r)["providerId"]
|
||||
if providerID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteProvider(r.Context(), userAuth.AccountId, userAuth.UserId, providerID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func validate(req *api.AgentNetworkProviderRequest, requireAPIKey bool) error {
|
||||
if strings.TrimSpace(req.ProviderId) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "provider_id is required")
|
||||
}
|
||||
if !catalog.IsKnown(req.ProviderId) {
|
||||
return status.Errorf(status.InvalidArgument, "provider_id %q is not a known catalog provider", req.ProviderId)
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
if strings.TrimSpace(req.UpstreamUrl) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "upstream_url is required")
|
||||
}
|
||||
u, err := url.Parse(strings.TrimSpace(req.UpstreamUrl))
|
||||
if err != nil || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
|
||||
return status.Errorf(status.InvalidArgument, "upstream_url must be a full http(s) URL")
|
||||
}
|
||||
if requireAPIKey && (req.ApiKey == nil || strings.TrimSpace(*req.ApiKey) == "") {
|
||||
return status.Errorf(status.InvalidArgument, "api_key is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// addSettingsEndpoints registers the Agent Network settings routes. The
|
||||
// settings row is bootstrapped server-side on first provider create; GET reads
|
||||
// it and PUT updates the mutable collection toggles (cluster/subdomain stay
|
||||
// immutable).
|
||||
func (h *handler) addSettingsEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/settings", h.getSettings).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/settings", h.updateSettings).Methods("PUT", "OPTIONS")
|
||||
}
|
||||
|
||||
// updateSettings applies the collection toggles to the account's settings row.
|
||||
func (h *handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkSettingsRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings := &types.Settings{AccountID: userAuth.AccountId}
|
||||
settings.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdateSettings(r.Context(), userAuth.UserId, settings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
// getSettings returns the account's agent-network settings. The settings
|
||||
// row is bootstrapped on first provider create, so freshly-onboarded
|
||||
// accounts have nothing to read. Rather than 404-ing in that case (which
|
||||
// the dashboard would have to special-case), return a JSON null with 200
|
||||
// so consumers can branch on the body alone.
|
||||
func (h *handler) getSettings(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.manager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
var sErr *status.Error
|
||||
if errors.As(err, &sErr) && sErr.Type() == status.NotFound {
|
||||
util.WriteJSONObject(r.Context(), w, nil)
|
||||
return
|
||||
}
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, settings.ToAPIResponse())
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
// Package labelgen produces DNS-safe Agent Network subdomain labels.
|
||||
package labelgen
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// pickAttempts caps the random retries before falling back to the
|
||||
// suffixed form. Eight is a soft compromise: with a near-empty taken
|
||||
// set the very first pick almost always succeeds; when the wordlist is
|
||||
// densely populated the fallback eventually fires anyway.
|
||||
const pickAttempts = 8
|
||||
|
||||
var (
|
||||
dedupOnce sync.Once
|
||||
uniqWords []string
|
||||
)
|
||||
|
||||
// uniqueWords returns the wordlist deduplicated and sorted for
|
||||
// deterministic exhaustion behaviour. Lazy-built once per process.
|
||||
func uniqueWords() []string {
|
||||
dedupOnce.Do(func() {
|
||||
seen := make(map[string]struct{}, len(words))
|
||||
uniqWords = make([]string, 0, len(words))
|
||||
for _, w := range words {
|
||||
if _, ok := seen[w]; ok {
|
||||
continue
|
||||
}
|
||||
seen[w] = struct{}{}
|
||||
uniqWords = append(uniqWords, w)
|
||||
}
|
||||
sort.Strings(uniqWords)
|
||||
})
|
||||
return uniqWords
|
||||
}
|
||||
|
||||
// PickUnique selects a label not already in `taken`. It tries up to
|
||||
// pickAttempts random picks; on exhaustion it scans the deduplicated
|
||||
// wordlist for any remaining free entry, and if none is left appends
|
||||
// `-<fallbackSuffix>` to a deterministic word and returns. The caller
|
||||
// is responsible for seeding rng (math/rand).
|
||||
func PickUnique(rng *rand.Rand, taken map[string]struct{}, fallbackSuffix string) string {
|
||||
pool := uniqueWords()
|
||||
if len(pool) == 0 {
|
||||
return fallbackSuffix
|
||||
}
|
||||
|
||||
for i := 0; i < pickAttempts; i++ {
|
||||
w := pool[rng.Intn(len(pool))]
|
||||
if _, ok := taken[w]; !ok {
|
||||
return w
|
||||
}
|
||||
}
|
||||
|
||||
for _, w := range pool {
|
||||
if _, ok := taken[w]; !ok {
|
||||
return w
|
||||
}
|
||||
}
|
||||
|
||||
w := pool[rng.Intn(len(pool))]
|
||||
return fmt.Sprintf("%s-%s", w, fallbackSuffix)
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package labelgen
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestPickUnique_DeterministicWithSeededRng locks the property the
|
||||
// caller relies on: same seed + same taken set → same pick. Without
|
||||
// that, the bootstrap flow can't reproduce a label across retries.
|
||||
func TestPickUnique_DeterministicWithSeededRng(t *testing.T) {
|
||||
taken := map[string]struct{}{}
|
||||
|
||||
rngA := rand.New(rand.NewSource(42))
|
||||
rngB := rand.New(rand.NewSource(42))
|
||||
|
||||
a := PickUnique(rngA, taken, "abcd")
|
||||
b := PickUnique(rngB, taken, "abcd")
|
||||
|
||||
assert.Equal(t, a, b, "Same seed and taken set must produce identical pick")
|
||||
}
|
||||
|
||||
// TestPickUnique_AvoidsTakenWordsWhenMostAreReserved seeds taken with
|
||||
// every word in the pool except a handful and confirms PickUnique
|
||||
// finds one of the remaining free entries instead of returning the
|
||||
// fallback form.
|
||||
func TestPickUnique_AvoidsTakenWordsWhenMostAreReserved(t *testing.T) {
|
||||
pool := uniqueWords()
|
||||
require.NotEmpty(t, pool, "wordlist must be populated for the test to mean anything")
|
||||
|
||||
free := map[string]struct{}{
|
||||
pool[0]: {},
|
||||
pool[len(pool)/2]: {},
|
||||
pool[len(pool)-1]: {},
|
||||
}
|
||||
|
||||
taken := make(map[string]struct{}, len(pool))
|
||||
for _, w := range pool {
|
||||
if _, ok := free[w]; ok {
|
||||
continue
|
||||
}
|
||||
taken[w] = struct{}{}
|
||||
}
|
||||
|
||||
rng := rand.New(rand.NewSource(7))
|
||||
got := PickUnique(rng, taken, "abcd")
|
||||
|
||||
_, isFree := free[got]
|
||||
assert.True(t, isFree, "PickUnique must return one of the free words; got %q", got)
|
||||
assert.NotContains(t, got, "-", "Free pick must not be the suffix fallback form")
|
||||
}
|
||||
|
||||
// TestPickUnique_FallsBackWhenAllReserved exhausts the pool and
|
||||
// confirms PickUnique appends the supplied suffix instead of
|
||||
// returning a duplicate.
|
||||
func TestPickUnique_FallsBackWhenAllReserved(t *testing.T) {
|
||||
pool := uniqueWords()
|
||||
|
||||
taken := make(map[string]struct{}, len(pool))
|
||||
for _, w := range pool {
|
||||
taken[w] = struct{}{}
|
||||
}
|
||||
|
||||
rng := rand.New(rand.NewSource(99))
|
||||
got := PickUnique(rng, taken, "abcd")
|
||||
|
||||
assert.True(t, strings.HasSuffix(got, "-abcd"), "Exhausted pool must produce <word>-<suffix>; got %q", got)
|
||||
|
||||
prefix := strings.TrimSuffix(got, "-abcd")
|
||||
found := false
|
||||
for _, w := range pool {
|
||||
if w == prefix {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Fallback prefix must be drawn from the wordlist; got %q", prefix)
|
||||
}
|
||||
|
||||
// TestUniqueWords_DropsDuplicates guards against authoring slips in
|
||||
// words.go: every entry must be unique and DNS-safe.
|
||||
func TestUniqueWords_DropsDuplicates(t *testing.T) {
|
||||
pool := uniqueWords()
|
||||
seen := make(map[string]struct{}, len(pool))
|
||||
for _, w := range pool {
|
||||
_, dup := seen[w]
|
||||
assert.False(t, dup, "Duplicate entry %q in deduplicated pool", w)
|
||||
seen[w] = struct{}{}
|
||||
assert.GreaterOrEqual(t, len(w), 4, "Word %q is shorter than 4 chars", w)
|
||||
assert.LessOrEqual(t, len(w), 12, "Word %q is longer than 12 chars", w)
|
||||
for _, r := range w {
|
||||
ok := r >= 'a' && r <= 'z'
|
||||
assert.True(t, ok, "Word %q contains non-lowercase-ASCII rune %q", w, r)
|
||||
}
|
||||
}
|
||||
assert.GreaterOrEqual(t, len(pool), 500, "Pool must contain at least 500 unique words")
|
||||
}
|
||||
136
management/internals/modules/agentnetwork/labelgen/words.go
Normal file
136
management/internals/modules/agentnetwork/labelgen/words.go
Normal file
@@ -0,0 +1,136 @@
|
||||
// Package labelgen produces DNS-safe Agent Network subdomain labels.
|
||||
//
|
||||
// The wordlist below is a curated subset drawn from public-domain
|
||||
// nature / common-noun pools (e.g. EFF's diceware lists). Every entry
|
||||
// is lowercase ASCII, 4–12 chars, no hyphens, no digits, and was
|
||||
// hand-checked to avoid offensive, brand, or region-specific terms.
|
||||
package labelgen
|
||||
|
||||
// words is the pool PickUnique selects from. The slice is intentionally
|
||||
// not sorted — random picks distribute across the list naturally.
|
||||
var words = []string{
|
||||
"acorn", "adobe", "agate", "alder", "almond", "alpine", "amber", "amethyst",
|
||||
"anchor", "antler", "apple", "apricot", "arcade", "arctic", "arrow", "ashen",
|
||||
"aspen", "atlas", "atom", "aurora", "autumn", "azure",
|
||||
"badger", "bamboo", "banana", "banjo", "barley", "barn", "basalt", "basil",
|
||||
"basin", "bayou", "beach", "beacon", "beaver", "beech", "beetle", "berry",
|
||||
"birch", "bison", "blossom", "blue", "bobcat", "bonsai", "boulder", "branch",
|
||||
"brass", "breeze", "bridge", "bright", "brook", "broom", "brown", "buffalo",
|
||||
"bumble", "burrow", "butter", "button",
|
||||
"cabin", "cactus", "calm", "camel", "campfire", "canary", "candle", "canoe",
|
||||
"canyon", "cardinal", "carrot", "cascade", "castle", "cedar", "celery", "cello",
|
||||
"cement", "cherry", "chestnut", "chime", "cinnamon", "cinder", "citron", "clay",
|
||||
"clear", "cliff", "clock", "cloud", "clover", "coast", "cobalt", "cobble",
|
||||
"cocoa", "coffee", "comet", "compass", "copper", "coral", "corner", "cosmos",
|
||||
"cotton", "cougar", "country", "coyote", "cove", "crane", "crater", "creek",
|
||||
"crescent", "crimson", "crocus", "crystal", "cypress",
|
||||
"daffodil", "dahlia", "daisy", "dawn", "deer", "delta", "denim", "desert",
|
||||
"dewdrop", "diamond", "dolphin", "doodle", "dove", "dragon", "drift", "drop",
|
||||
"dune", "dusk", "dusty",
|
||||
"eagle", "earth", "echo", "elder", "elkhorn", "ember", "emerald", "emperor",
|
||||
"evergreen", "evening",
|
||||
"falcon", "fawn", "feather", "fern", "fiddle", "field", "fiesta", "finch",
|
||||
"firepit", "firefly", "fjord", "flame", "flax", "fleece", "flint", "floral",
|
||||
"flower", "flute", "foal", "foggy", "forest", "fountain", "foxglove", "fresh",
|
||||
"frost", "fuchsia", "fudge",
|
||||
"gable", "galaxy", "garden", "garnet", "gazelle", "geode", "geyser", "ginger",
|
||||
"glacier", "glade", "glass", "glow", "gold", "goose", "gorge", "gourd",
|
||||
"granite", "grape", "grass", "gravel", "grayling", "greenery", "grizzly", "grove",
|
||||
"gull", "gumdrop", "gust",
|
||||
"hammock", "harbor", "harvest", "hawk", "hazel", "heather", "hedge", "heron",
|
||||
"hibiscus", "hickory", "hideaway", "highland", "hill", "hive", "hollow", "honey",
|
||||
"hopper", "horizon", "hummingbird", "husky",
|
||||
"iceberg", "indigo", "iris", "island", "ivory", "ivybush",
|
||||
"jade", "jasmine", "jasper", "jaybird", "jelly", "jewel", "jonquil", "journey",
|
||||
"juniper", "jupiter", "jute",
|
||||
"kale", "kangaroo", "kayak", "kelp", "kestrel", "kettle", "khaki", "kindling",
|
||||
"kingfisher", "kiwi", "knapweed", "koala",
|
||||
"lagoon", "lake", "lantern", "larch", "lark", "laurel", "lava", "lavender",
|
||||
"leaf", "lemon", "lichen", "light", "lilac", "lily", "lime", "limestone",
|
||||
"linden", "linen", "lion", "lobster", "locust", "loon", "lotus", "lumber",
|
||||
"lunar", "lupine", "lynx",
|
||||
"madrone", "magenta", "magnolia", "mahogany", "mallow", "mango", "manor", "maple",
|
||||
"marble", "marigold", "marina", "marlin", "marsh", "mauve", "meadow", "melody",
|
||||
"melon", "merlin", "metal", "midnight", "milk", "millet", "mineral", "mint",
|
||||
"mirror", "mist", "mitten", "molasses", "moon", "moose", "morning", "moss",
|
||||
"mountain", "mulberry", "muscat", "mustard",
|
||||
"narwhal", "navy", "nectar", "needle", "nest", "nettle", "newt", "nightfall",
|
||||
"noon", "nook", "north", "nova", "nutmeg",
|
||||
"oaken", "oasis", "oatmeal", "ocean", "ochre", "octagon", "olive", "onyx",
|
||||
"opal", "orange", "orbit", "orchard", "orchid", "oregano", "orion", "osprey",
|
||||
"otter", "outpost", "owlet", "oyster",
|
||||
"painter", "palace", "palm", "pansy", "panther", "papaya", "paprika", "parsley",
|
||||
"partridge", "passage", "pastel", "patio", "peach", "peacock", "pear", "pearl",
|
||||
"pebble", "pecan", "pelican", "penguin", "peony", "pepper", "perch", "peridot",
|
||||
"pewter", "phoenix", "pier", "pillar", "pine", "pineapple", "pinto", "piper",
|
||||
"pistachio", "plain", "planet", "plateau", "platinum", "plum", "plume", "polar",
|
||||
"pollen", "pond", "poplar", "poppy", "porcelain", "portal", "portrait", "potato",
|
||||
"prairie", "primrose", "prism", "puffin", "pumpkin",
|
||||
"quail", "quartz", "quaver", "quill", "quince", "quinoa",
|
||||
"rabbit", "raccoon", "radish", "rain", "rainbow", "raindrop", "rapids", "raspberry",
|
||||
"raven", "ravine", "redwood", "reed", "reef", "ridge", "river", "robin",
|
||||
"rocket", "rubyred", "rose", "rosemary", "rosewood", "ruffle", "rugby", "russet",
|
||||
"rustic", "ryefield",
|
||||
"saffron", "sage", "salmon", "sand", "sandstone", "sapphire", "savanna", "scarlet",
|
||||
"scout", "seal", "season", "seaweed", "sequoia", "shadow", "shamrock", "shell",
|
||||
"sherbet", "shore", "silver", "siskin", "skybloom", "skyline", "sleet", "smoke",
|
||||
"snail", "snapdragon", "snow", "snowflake", "snowy", "solar", "song", "sonic",
|
||||
"sorrel", "south", "sparkle", "sparrow", "spice", "spider", "spinach", "spire",
|
||||
"spring", "sprout", "spruce", "squirrel", "starfish", "starlight", "stoat", "stone",
|
||||
"stork", "storm", "stream", "studio", "summer", "sunbeam", "sundew", "sunny",
|
||||
"sunrise", "sunset", "swallow", "swan", "sweet", "sycamore",
|
||||
"tangelo", "tangerine", "tansy", "taupe", "teak", "teal", "thicket", "thistle",
|
||||
"thrush", "thunder", "tide", "tiger", "tinder", "topaz", "torch", "tortoise",
|
||||
"tower", "trail", "tranquil", "tundra", "tulip", "turquoise", "turtle", "twig",
|
||||
"twilight",
|
||||
"umber", "uplands",
|
||||
"valley", "vanilla", "velvet", "venus", "verdant", "verdigris", "vermilion", "violet",
|
||||
"vista", "vivid", "volcano", "vortex",
|
||||
"walnut", "warbler", "watercress", "waterfall", "wave", "waxwing", "weasel", "westwind",
|
||||
"whale", "whisker", "whisper", "wicker", "wildwood", "willow", "winter", "wisp",
|
||||
"wisteria", "wolf", "wombat", "woodland", "woolly", "wren", "wreath",
|
||||
"yarrow", "yellow", "yewtree", "yodel",
|
||||
"zebra", "zenith", "zephyr", "zinnia",
|
||||
"alabaster", "alfalfa", "almanac", "anise", "antelope", "arbor", "arena", "armadillo",
|
||||
"avocet", "azalea", "balsam", "bayou", "beacon", "blizzard", "bluebell", "bluebird",
|
||||
"bluejay", "bobolink", "borage", "boreal", "buckeye", "buckthorn", "buttercup",
|
||||
"cabana", "calico", "canopy", "caraway", "cardamom", "cattail", "celadon", "centaur",
|
||||
"chambray", "chamois", "champlain", "chestnuts", "chickadee", "chinook", "chipmunk", "cinnabar",
|
||||
"cirrus", "citrine", "clematis", "copperhead",
|
||||
"crocodile", "currant", "cuttlebone", "daffy", "dapple", "delphinium", "dervish", "diamondback",
|
||||
"dogwood", "dolphins", "dragonfly", "driftwood", "dusk", "dustpan", "ebony", "edelweiss",
|
||||
"emperor", "endive", "estuary", "everglade", "fairway", "feldspar", "fennel", "fieldstone",
|
||||
"firebrand", "firefly", "fireweed", "firework", "flagstone", "fossil", "frostbite", "galleon",
|
||||
"gardener", "geranium", "gingko", "ginseng", "goldfish", "goldfinch", "goldenrod", "graphite",
|
||||
"greenfinch", "guppy", "haiku", "halibut", "hammerhead", "harbinger", "harvest", "hatchling",
|
||||
"havana", "hawthorn", "hazelnut", "heartwood", "henna", "heron", "highrise", "homestead",
|
||||
"honeycomb", "honeydew", "horseshoe", "hyacinth", "iceland", "icicle", "indigobird", "ironwood",
|
||||
"jacaranda", "jamboree", "javelina", "jellyfish", "junebug", "kaleido", "kayaker", "kerchief",
|
||||
"keystone", "kingdom", "labrador", "lacewing", "ladybug", "lakeside", "lamplight", "leopard",
|
||||
"lighthouse", "lilypad", "lullaby", "magnet", "mahonia", "mandolin", "manzanita", "maraschino",
|
||||
"mariner", "marsupial", "mastodon", "matterhorn", "mayflower", "mayfly", "meadowlark", "merlot",
|
||||
"meteor", "midshipman", "millpond", "mimosa", "minnow", "mockingbird", "molten", "monarch",
|
||||
"monsoon", "moondust", "moonlight", "moorland", "morning", "mossland", "mountain", "mulch",
|
||||
"narcissus", "nautilus", "nettlebush", "northstar", "nuthatch", "obsidian", "okra", "olivine",
|
||||
"opalescent", "orchidea", "orchard", "ornament", "outrigger", "oxalis", "paddler", "paintbrush",
|
||||
"papyrus", "paradise", "pasture", "patchwork", "pathway", "peridot", "periwinkle", "petalbloom",
|
||||
"petrel", "petunia", "phlox", "pikeperch", "pinecone", "pioneer", "pipevine", "platypus",
|
||||
"pomelo", "pondweed", "porpoise", "powder", "promise", "puddle", "pumice", "puzzle",
|
||||
"quetzal", "quicksilver", "raccoon", "ragwort", "rainforest", "ramble", "rapid", "rascal",
|
||||
"raspberry", "redbud", "redfern", "redpoll", "reedling", "ringtail", "riverbed", "riverbird",
|
||||
"riverstone", "rockcress", "roebuck", "rosebay", "rosehip", "rosemary", "rowan", "rumble",
|
||||
"runaway", "rustler", "sagebrush", "sailcloth", "salamander", "salsify", "samphire", "sandbar",
|
||||
"sanddollar", "sandpiper", "santolina", "sapodilla", "sassafras", "scallion", "schooner", "seafoam",
|
||||
"seafrost", "seagrass", "seahorse", "seaport", "seashell", "seaspray", "shamble", "shimmer",
|
||||
"shoreline", "silkmoth", "silverfox", "skylark", "snapdragon", "snowberry", "snowdrop", "snowfall",
|
||||
"snowmelt", "softwood", "songbird", "sorghum", "southwind", "speedwell", "spinnaker", "spruce",
|
||||
"starlight", "starling", "stormcloud", "summit", "sundance", "sundew", "sundial", "sunflower",
|
||||
"surface", "swallowtail", "sweetcorn", "sycamore", "tabletop", "tamarack", "tamarind", "tangerine",
|
||||
"tarragon", "telescope", "thicket", "thrasher", "thunder", "thyme", "tideline", "timberland",
|
||||
"tinderbox", "topiary", "torchwood", "totem", "tradewind", "treasure", "tremolo", "trinket",
|
||||
"trumpetvine", "tugboat", "tundra", "turnstone", "underbrush", "vagabond", "valerian", "vanilla",
|
||||
"velveteen", "vermilion", "vinca", "vineyard", "violet", "voyager", "wagonwheel", "walnutwood",
|
||||
"watermark", "watershed", "waterway", "wavefront", "westerly", "whaleback", "whetstone", "wicker",
|
||||
"wildbloom", "wildflower", "wilderness", "windsong", "windward", "winterberry", "woodbine", "woodfern",
|
||||
"woodland", "woodthrush", "woolgrass", "yellowfin", "zenithal", "zucchini",
|
||||
}
|
||||
896
management/internals/modules/agentnetwork/manager.go
Normal file
896
management/internals/modules/agentnetwork/manager.go
Normal file
@@ -0,0 +1,896 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/labelgen"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// ensureSessionKeys mints an ed25519 session keypair on the provider
|
||||
// when one is missing. Idempotent: skips when both fields are already
|
||||
// populated (e.g. update or migrated rows). The keys are used by the
|
||||
// synthesised reverse-proxy service to sign / verify session JWTs
|
||||
// after a successful OIDC handshake.
|
||||
func ensureSessionKeys(p *types.Provider) error {
|
||||
if p.SessionPrivateKey != "" && p.SessionPublicKey != "" {
|
||||
return nil
|
||||
}
|
||||
pair, err := sessionkey.GenerateKeyPair()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate provider session keys: %w", err)
|
||||
}
|
||||
p.SessionPrivateKey = pair.PrivateKey
|
||||
p.SessionPublicKey = pair.PublicKey
|
||||
return nil
|
||||
}
|
||||
|
||||
// Manager governs the lifecycle of Agent Network providers and policies.
|
||||
type Manager interface {
|
||||
GetAllProviders(ctx context.Context, accountID, userID string) ([]*types.Provider, error)
|
||||
GetProvider(ctx context.Context, accountID, userID, providerID string) (*types.Provider, error)
|
||||
CreateProvider(ctx context.Context, userID string, provider *types.Provider, bootstrapCluster string) (*types.Provider, error)
|
||||
UpdateProvider(ctx context.Context, userID string, provider *types.Provider) (*types.Provider, error)
|
||||
DeleteProvider(ctx context.Context, accountID, userID, providerID string) error
|
||||
|
||||
GetAllPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||
GetPolicy(ctx context.Context, accountID, userID, policyID string) (*types.Policy, error)
|
||||
CreatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error)
|
||||
UpdatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error)
|
||||
DeletePolicy(ctx context.Context, accountID, userID, policyID string) error
|
||||
|
||||
GetAllGuardrails(ctx context.Context, accountID, userID string) ([]*types.Guardrail, error)
|
||||
GetGuardrail(ctx context.Context, accountID, userID, guardrailID string) (*types.Guardrail, error)
|
||||
CreateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error)
|
||||
UpdateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error)
|
||||
DeleteGuardrail(ctx context.Context, accountID, userID, guardrailID string) error
|
||||
|
||||
GetAllBudgetRules(ctx context.Context, accountID, userID string) ([]*types.AccountBudgetRule, error)
|
||||
GetBudgetRule(ctx context.Context, accountID, userID, ruleID string) (*types.AccountBudgetRule, error)
|
||||
CreateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error)
|
||||
UpdateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error)
|
||||
DeleteBudgetRule(ctx context.Context, accountID, userID, ruleID string) error
|
||||
|
||||
GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error)
|
||||
UpdateSettings(ctx context.Context, userID string, settings *types.Settings) (*types.Settings, error)
|
||||
|
||||
ListConsumption(ctx context.Context, accountID, userID string) ([]*types.Consumption, error)
|
||||
ListAccessLogs(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error)
|
||||
GetUsageOverview(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter, granularity types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error)
|
||||
StartAccessLogCleanup(ctx context.Context, cleanupIntervalHours int)
|
||||
RecordConsumption(ctx context.Context, accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds, tokensIn, tokensOut int64, costUSD float64) error
|
||||
RecordAccountBudgetUsage(ctx context.Context, accountID, userID string, groupIDs []string, tokensIn, tokensOut int64, costUSD float64) error
|
||||
RecordUsage(ctx context.Context, in RecordUsageInput) error
|
||||
SelectPolicyForRequest(ctx context.Context, in PolicySelectionInput) (*PolicySelectionResult, error)
|
||||
}
|
||||
|
||||
// PolicySelectionInput is the per-request selection envelope. The
|
||||
// proxy populates it from CapturedData (account, user, groups) plus
|
||||
// the provider llm_router resolved.
|
||||
type PolicySelectionInput struct {
|
||||
AccountID string
|
||||
UserID string
|
||||
GroupIDs []string
|
||||
ProviderID string
|
||||
}
|
||||
|
||||
// PolicySelectionResult names the policy that "pays" for this request
|
||||
// plus the deny envelope when every applicable policy has exhausted
|
||||
// every cap. AttributionGroupID is the lowest group id (string sort)
|
||||
// of caller_groups ∩ selected_policy.source_groups; empty when no
|
||||
// group dimension applies. WindowSeconds is the chosen policy's
|
||||
// effective window length in seconds (token_limit's wins when both
|
||||
// halves are enabled with mismatched windows; budget_limit's
|
||||
// otherwise; 0 when no caps are configured at all).
|
||||
type PolicySelectionResult struct {
|
||||
Allow bool
|
||||
SelectedPolicyID string
|
||||
AttributionGroupID string
|
||||
WindowSeconds int64
|
||||
DenyCode string
|
||||
DenyReason string
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
proxyController proxy.Controller
|
||||
|
||||
// reconcileCache holds the last set of synthesised proxy mappings
|
||||
// per account so reconcile can emit precise Create/Update/Delete
|
||||
// updates instead of a full re-push on every mutation. Keyed by
|
||||
// accountID, then by synthesised service ID.
|
||||
reconcileMu sync.Mutex
|
||||
reconcileCache map[string]map[string]*proto.ProxyMapping
|
||||
|
||||
// labelRngMu guards labelRng. PickUnique consumes math/rand.Source
|
||||
// state; concurrent provider creates would otherwise race.
|
||||
labelRngMu sync.Mutex
|
||||
labelRng *rand.Rand
|
||||
}
|
||||
|
||||
// NewManager constructs the persistent Agent Network manager. The
|
||||
// manager persists provider/policy/guardrail configuration and, on
|
||||
// every mutation, reconciles the in-memory synthesised reverse-proxy
|
||||
// services with the proxy cluster via proxyController. Pass nil for
|
||||
// proxyController to disable the reconcile push (useful in tests).
|
||||
func NewManager(
|
||||
store store.Store,
|
||||
permissionsManager permissions.Manager,
|
||||
accountManager account.Manager,
|
||||
proxyController proxy.Controller,
|
||||
) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
proxyController: proxyController,
|
||||
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
|
||||
labelRng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllProviders(ctx context.Context, accountID, userID string) ([]*types.Provider, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetProvider(ctx context.Context, accountID, userID, providerID string) (*types.Provider, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthNone, accountID, providerID)
|
||||
}
|
||||
|
||||
// CreateProvider persists a new provider for the account. bootstrapCluster
|
||||
// is used only when the per-account agent-network Settings row hasn't
|
||||
// been created yet; otherwise it is ignored (the cluster is pinned on
|
||||
// Settings and every provider in the account routes through it).
|
||||
func (m *managerImpl) CreateProvider(ctx context.Context, userID string, provider *types.Provider, bootstrapCluster string) (*types.Provider, error) {
|
||||
if err := m.requirePermission(ctx, provider.AccountID, userID, operations.Create); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// An empty api_key would silently produce a synthesised service
|
||||
// that 401s on every upstream request. Surface the misconfiguration
|
||||
// at create time instead.
|
||||
if strings.TrimSpace(provider.APIKey) == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "api_key is required when creating an agent network provider")
|
||||
}
|
||||
|
||||
if provider.ID == "" {
|
||||
fresh := types.NewProvider(provider.AccountID)
|
||||
provider.ID = fresh.ID
|
||||
provider.CreatedAt = fresh.CreatedAt
|
||||
provider.UpdatedAt = fresh.UpdatedAt
|
||||
}
|
||||
|
||||
if err := ensureSessionKeys(provider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.store.SaveAgentNetworkProvider(ctx, provider); err != nil {
|
||||
return nil, fmt.Errorf("save agent network provider: %w", err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(bootstrapCluster) != "" {
|
||||
if _, err := m.bootstrapSettingsIfNeeded(ctx, provider.AccountID, bootstrapCluster); err != nil {
|
||||
// The provider create has already succeeded; logging the
|
||||
// bootstrap miss matches the plan's PoC behaviour. The synth
|
||||
// path treats a missing settings row as a no-op, and the next
|
||||
// provider create retries the bootstrap.
|
||||
log.WithContext(ctx).Debugf("agent-network bootstrap settings for account %s on cluster %s: %v", provider.AccountID, bootstrapCluster, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, provider.ID, provider.AccountID, activity.AgentNetworkProviderCreated, provider.EventMeta())
|
||||
m.reconcile(ctx, provider.AccountID)
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateProvider(ctx context.Context, userID string, provider *types.Provider) (*types.Provider, error) {
|
||||
if err := m.requirePermission(ctx, provider.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthUpdate, provider.AccountID, provider.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get agent network provider: %w", err)
|
||||
}
|
||||
|
||||
// Preserve the API key if the caller didn't rotate it. A
|
||||
// whitespace-only value is treated as "not rotated" rather than a
|
||||
// real key, but it must not silently overwrite a valid stored key.
|
||||
if provider.APIKey == "" {
|
||||
provider.APIKey = existing.APIKey
|
||||
} else if strings.TrimSpace(provider.APIKey) == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "api_key must be non-blank when rotating an agent network provider")
|
||||
}
|
||||
// Always preserve the session keypair across updates so existing
|
||||
// session cookies stay valid. The keys are server-managed and
|
||||
// never surfaced through the API.
|
||||
provider.SessionPrivateKey = existing.SessionPrivateKey
|
||||
provider.SessionPublicKey = existing.SessionPublicKey
|
||||
if err := ensureSessionKeys(provider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
provider.CreatedAt = existing.CreatedAt
|
||||
provider.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkProvider(ctx, provider); err != nil {
|
||||
return nil, fmt.Errorf("save agent network provider: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, provider.ID, provider.AccountID, activity.AgentNetworkProviderUpdated, provider.EventMeta())
|
||||
m.reconcile(ctx, provider.AccountID)
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteProvider(ctx context.Context, accountID, userID, providerID string) error {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
provider, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthUpdate, accountID, providerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent network provider: %w", err)
|
||||
}
|
||||
|
||||
// Refuse to delete while any policy still references this provider.
|
||||
// The operator must detach it first.
|
||||
policies, err := m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent network policies: %w", err)
|
||||
}
|
||||
var blocking []string
|
||||
for _, p := range policies {
|
||||
if slices.Contains(p.DestinationProviderIDs, providerID) {
|
||||
blocking = append(blocking, p.Name)
|
||||
}
|
||||
}
|
||||
if len(blocking) > 0 {
|
||||
return status.Errorf(
|
||||
status.InvalidArgument,
|
||||
"provider is in use by %d %s (%s); detach it before deleting",
|
||||
len(blocking),
|
||||
pluralize(len(blocking), "policy", "policies"),
|
||||
strings.Join(blocking, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
if err := m.store.DeleteAgentNetworkProvider(ctx, accountID, providerID); err != nil {
|
||||
return fmt.Errorf("failed to delete agent network provider: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, providerID, accountID, activity.AgentNetworkProviderDeleted, provider.EventMeta())
|
||||
m.reconcile(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func pluralize(n int, singular, plural string) string {
|
||||
if n == 1 {
|
||||
return singular
|
||||
}
|
||||
return plural
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPolicy(ctx context.Context, accountID, userID, policyID string) (*types.Policy, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error) {
|
||||
if err := m.requirePermission(ctx, policy.AccountID, userID, operations.Create); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if policy.ID == "" {
|
||||
fresh := types.NewPolicy(policy.AccountID)
|
||||
policy.ID = fresh.ID
|
||||
policy.CreatedAt = fresh.CreatedAt
|
||||
policy.UpdatedAt = fresh.UpdatedAt
|
||||
}
|
||||
|
||||
if err := m.validateProviderRefs(ctx, policy.AccountID, policy.DestinationProviderIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.store.SaveAgentNetworkPolicy(ctx, policy); err != nil {
|
||||
return nil, fmt.Errorf("failed to save agent network policy: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, policy.ID, policy.AccountID, activity.AgentNetworkPolicyCreated, policy.EventMeta())
|
||||
m.reconcile(ctx, policy.AccountID)
|
||||
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error) {
|
||||
if err := m.requirePermission(ctx, policy.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthUpdate, policy.AccountID, policy.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get agent network policy: %w", err)
|
||||
}
|
||||
|
||||
if err := m.validateProviderRefs(ctx, policy.AccountID, policy.DestinationProviderIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
policy.CreatedAt = existing.CreatedAt
|
||||
policy.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkPolicy(ctx, policy); err != nil {
|
||||
return nil, fmt.Errorf("failed to save agent network policy: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, policy.ID, policy.AccountID, activity.AgentNetworkPolicyUpdated, policy.EventMeta())
|
||||
m.reconcile(ctx, policy.AccountID)
|
||||
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeletePolicy(ctx context.Context, accountID, userID, policyID string) error {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
policy, err := m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent network policy: %w", err)
|
||||
}
|
||||
|
||||
if err := m.store.DeleteAgentNetworkPolicy(ctx, accountID, policyID); err != nil {
|
||||
return fmt.Errorf("failed to delete agent network policy: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, policyID, accountID, activity.AgentNetworkPolicyDeleted, policy.EventMeta())
|
||||
m.reconcile(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllGuardrails(ctx context.Context, accountID, userID string) ([]*types.Guardrail, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetGuardrail(ctx context.Context, accountID, userID, guardrailID string) (*types.Guardrail, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthNone, accountID, guardrailID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error) {
|
||||
if err := m.requirePermission(ctx, guardrail.AccountID, userID, operations.Create); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if guardrail.ID == "" {
|
||||
fresh := types.NewGuardrail(guardrail.AccountID)
|
||||
guardrail.ID = fresh.ID
|
||||
guardrail.CreatedAt = fresh.CreatedAt
|
||||
guardrail.UpdatedAt = fresh.UpdatedAt
|
||||
}
|
||||
|
||||
if err := m.store.SaveAgentNetworkGuardrail(ctx, guardrail); err != nil {
|
||||
return nil, fmt.Errorf("failed to save agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, guardrail.ID, guardrail.AccountID, activity.AgentNetworkGuardrailCreated, guardrail.EventMeta())
|
||||
m.reconcile(ctx, guardrail.AccountID)
|
||||
|
||||
return guardrail, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error) {
|
||||
if err := m.requirePermission(ctx, guardrail.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthUpdate, guardrail.AccountID, guardrail.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
guardrail.CreatedAt = existing.CreatedAt
|
||||
guardrail.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkGuardrail(ctx, guardrail); err != nil {
|
||||
return nil, fmt.Errorf("failed to save agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, guardrail.ID, guardrail.AccountID, activity.AgentNetworkGuardrailUpdated, guardrail.EventMeta())
|
||||
m.reconcile(ctx, guardrail.AccountID)
|
||||
|
||||
return guardrail, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteGuardrail(ctx context.Context, accountID, userID, guardrailID string) error {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
guardrail, err := m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthUpdate, accountID, guardrailID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
if err := m.store.DeleteAgentNetworkGuardrail(ctx, accountID, guardrailID); err != nil {
|
||||
return fmt.Errorf("failed to delete agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, guardrailID, accountID, activity.AgentNetworkGuardrailDeleted, guardrail.EventMeta())
|
||||
m.reconcile(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllBudgetRules returns every account-level budget rule for the account.
|
||||
func (m *managerImpl) GetAllBudgetRules(ctx context.Context, accountID, userID string) ([]*types.AccountBudgetRule, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// GetBudgetRule returns a single account-level budget rule.
|
||||
func (m *managerImpl) GetBudgetRule(ctx context.Context, accountID, userID, ruleID string) (*types.AccountBudgetRule, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthNone, accountID, ruleID)
|
||||
}
|
||||
|
||||
// CreateBudgetRule persists a new account-level budget rule. Budget rules are
|
||||
// enforced at request time (CheckLLMPolicyLimits), not baked into the synth
|
||||
// proxy config, so no reconcile is needed.
|
||||
func (m *managerImpl) CreateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
|
||||
if err := m.requirePermission(ctx, rule.AccountID, userID, operations.Create); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rule.ID == "" {
|
||||
fresh := types.NewAccountBudgetRule(rule.AccountID)
|
||||
rule.ID = fresh.ID
|
||||
rule.CreatedAt = fresh.CreatedAt
|
||||
rule.UpdatedAt = fresh.UpdatedAt
|
||||
}
|
||||
|
||||
if err := m.store.SaveAgentNetworkBudgetRule(ctx, rule); err != nil {
|
||||
return nil, fmt.Errorf("save agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, rule.ID, rule.AccountID, activity.AgentNetworkBudgetRuleCreated, rule.EventMeta())
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// UpdateBudgetRule updates an existing account-level budget rule.
|
||||
func (m *managerImpl) UpdateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
|
||||
if err := m.requirePermission(ctx, rule.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthUpdate, rule.AccountID, rule.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
rule.CreatedAt = existing.CreatedAt
|
||||
rule.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkBudgetRule(ctx, rule); err != nil {
|
||||
return nil, fmt.Errorf("save agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, rule.ID, rule.AccountID, activity.AgentNetworkBudgetRuleUpdated, rule.EventMeta())
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// DeleteBudgetRule removes an account-level budget rule.
|
||||
func (m *managerImpl) DeleteBudgetRule(ctx context.Context, accountID, userID, ruleID string) error {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rule, err := m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthUpdate, accountID, ruleID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.store.DeleteAgentNetworkBudgetRule(ctx, accountID, ruleID); err != nil {
|
||||
return fmt.Errorf("delete agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, ruleID, accountID, activity.AgentNetworkBudgetRuleDeleted, rule.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateSettings applies the mutable account-level settings — the collection
|
||||
// toggles — onto the existing row. Cluster and Subdomain are immutable and are
|
||||
// preserved from the persisted row regardless of the input. Because the
|
||||
// collection toggles change the synthesised service config (prompt-capture
|
||||
// gating, access-log emission), a reconcile is triggered so the proxy and peer
|
||||
// network maps converge on the new state.
|
||||
func (m *managerImpl) UpdateSettings(ctx context.Context, userID string, settings *types.Settings) (*types.Settings, error) {
|
||||
if err := m.requirePermission(ctx, settings.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthUpdate, settings.AccountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get agent network settings: %w", err)
|
||||
}
|
||||
|
||||
existing.EnableLogCollection = settings.EnableLogCollection
|
||||
existing.EnablePromptCollection = settings.EnablePromptCollection
|
||||
existing.RedactPii = settings.RedactPii
|
||||
existing.AccessLogRetentionDays = settings.AccessLogRetentionDays
|
||||
existing.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkSettings(ctx, existing); err != nil {
|
||||
return nil, fmt.Errorf("save agent network settings: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, settings.AccountID, settings.AccountID, activity.AgentNetworkSettingsUpdated, map[string]any{
|
||||
"log_collection": existing.EnableLogCollection,
|
||||
"prompt_collection": existing.EnablePromptCollection,
|
||||
"redact_pii": existing.RedactPii,
|
||||
})
|
||||
m.reconcile(ctx, settings.AccountID)
|
||||
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
// validateProviderRefs ensures every destination provider id refers to a
|
||||
// provider that exists in the same account.
|
||||
func (m *managerImpl) validateProviderRefs(ctx context.Context, accountID string, providerIDs []string) error {
|
||||
if len(providerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, id := range providerIDs {
|
||||
if _, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthNone, accountID, id); err != nil {
|
||||
// Only a genuine not-found means the reference is invalid; a
|
||||
// store/runtime error must propagate as-is rather than be
|
||||
// masked as a client validation error.
|
||||
var sErr *status.Error
|
||||
if errors.As(err, &sErr) && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "destination_provider_ids: provider %s does not exist", id)
|
||||
}
|
||||
return fmt.Errorf("get destination provider %s: %w", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSettings returns the agent-network settings row for the account.
|
||||
// Returns the underlying status.NotFound when no row has been
|
||||
// bootstrapped yet (i.e. the account has no providers).
|
||||
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// bootstrapSettingsIfNeeded creates the per-account agent-network
|
||||
// settings row when missing. The cluster comes from the create-time
|
||||
// hint the dashboard sends (auto-picked from the active cluster list);
|
||||
// the subdomain is picked from the curated wordlist avoiding
|
||||
// collisions on the same cluster. Idempotent: if a row already exists
|
||||
// it is returned untouched and the hint is ignored.
|
||||
func (m *managerImpl) bootstrapSettingsIfNeeded(ctx context.Context, accountID, providerCluster string) (*types.Settings, error) {
|
||||
if accountID == "" {
|
||||
return nil, fmt.Errorf("bootstrap settings: account id is required")
|
||||
}
|
||||
if strings.TrimSpace(providerCluster) == "" {
|
||||
return nil, fmt.Errorf("bootstrap settings: provider cluster is required")
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err == nil {
|
||||
return existing, nil
|
||||
}
|
||||
var sErr *status.Error
|
||||
if !errors.As(err, &sErr) || sErr.Type() != status.NotFound {
|
||||
return nil, fmt.Errorf("get agent network settings: %w", err)
|
||||
}
|
||||
|
||||
siblings, err := m.store.GetAgentNetworkSettingsByCluster(ctx, store.LockingStrengthNone, providerCluster)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list agent network settings on cluster: %w", err)
|
||||
}
|
||||
taken := make(map[string]struct{}, len(siblings))
|
||||
for _, s := range siblings {
|
||||
taken[s.Subdomain] = struct{}{}
|
||||
}
|
||||
|
||||
suffix := accountID
|
||||
if len(suffix) > 4 {
|
||||
suffix = suffix[:4]
|
||||
}
|
||||
|
||||
m.labelRngMu.Lock()
|
||||
subdomain := labelgen.PickUnique(m.labelRng, taken, suffix)
|
||||
m.labelRngMu.Unlock()
|
||||
|
||||
now := time.Now().UTC()
|
||||
settings := &types.Settings{
|
||||
AccountID: accountID,
|
||||
Cluster: providerCluster,
|
||||
Subdomain: subdomain,
|
||||
// Logs on by default; usage is collected regardless. Retention bounds
|
||||
// how long full log rows are kept.
|
||||
EnableLogCollection: true,
|
||||
AccessLogRetentionDays: types.DefaultAccessLogRetentionDays,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := m.store.SaveAgentNetworkSettings(ctx, settings); err != nil {
|
||||
return nil, fmt.Errorf("save agent network settings: %w", err)
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// ListConsumption returns every consumption row recorded for the
|
||||
// account, ordered window-newest-first. Backs the dashboard's basic
|
||||
// counter view; permission gate is the same Read role that gates
|
||||
// every other agent-network surface.
|
||||
func (m *managerImpl) ListConsumption(ctx context.Context, accountID, userID string) ([]*types.Consumption, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.ListAgentNetworkConsumption(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// ListAccessLogs returns a paginated, server-side-filtered page of
|
||||
// agent-network access logs plus the total count matching the filter.
|
||||
func (m *managerImpl) ListAccessLogs(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return m.store.GetAgentNetworkAccessLogs(ctx, store.LockingStrengthNone, accountID, filter)
|
||||
}
|
||||
|
||||
// GetUsageOverview returns the filtered usage rows aggregated into time buckets
|
||||
// at the requested granularity, oldest-first.
|
||||
func (m *managerImpl) GetUsageOverview(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter, granularity types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := m.store.GetAgentNetworkUsageRows(ctx, store.LockingStrengthNone, accountID, filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return types.AggregateUsageByGranularity(rows, granularity), nil
|
||||
}
|
||||
|
||||
// StartAccessLogCleanup launches a background sweep that periodically deletes
|
||||
// each account's agent-network access-log rows older than that account's
|
||||
// AccessLogRetentionDays. Usage records are never swept. A non-positive
|
||||
// interval defaults to 24h.
|
||||
func (m *managerImpl) StartAccessLogCleanup(ctx context.Context, cleanupIntervalHours int) {
|
||||
if cleanupIntervalHours <= 0 {
|
||||
cleanupIntervalHours = 24
|
||||
}
|
||||
interval := time.Duration(cleanupIntervalHours) * time.Hour
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
m.cleanupAccessLogsOnce(ctx) // run once on startup
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.cleanupAccessLogsOnce(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// cleanupAccessLogsOnce sweeps every account's expired access-log rows against
|
||||
// its configured retention. Best-effort: a per-account failure is logged and
|
||||
// the sweep continues.
|
||||
func (m *managerImpl) cleanupAccessLogsOnce(ctx context.Context) {
|
||||
settings, err := m.store.GetAllAgentNetworkSettings(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("agent-network access-log cleanup: list settings: %v", err)
|
||||
return
|
||||
}
|
||||
for _, s := range settings {
|
||||
if s.AccessLogRetentionDays <= 0 {
|
||||
continue // keep indefinitely
|
||||
}
|
||||
cutoff := time.Now().UTC().AddDate(0, 0, -s.AccessLogRetentionDays)
|
||||
deleted, err := m.store.DeleteOldAgentNetworkAccessLogs(ctx, s.AccountID, cutoff)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("agent-network access-log cleanup for account %s: %v", s.AccountID, err)
|
||||
continue
|
||||
}
|
||||
if deleted > 0 {
|
||||
log.WithContext(ctx).Infof("agent-network access-log cleanup: deleted %d rows for account %s (retention %d days)", deleted, s.AccountID, s.AccessLogRetentionDays)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecordConsumption increments the (dim, window) counter by the
|
||||
// supplied deltas. The window_start is computed from time.Now under
|
||||
// the supplied window_seconds so callers don't have to pre-align —
|
||||
// the proxy's post-flight path simply hands us tokens + cost and
|
||||
// which dimension we're booking against.
|
||||
func (m *managerImpl) RecordConsumption(ctx context.Context, accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds, tokensIn, tokensOut int64, costUSD float64) error {
|
||||
if accountID == "" || dimID == "" || windowSeconds <= 0 {
|
||||
return status.Errorf(status.InvalidArgument, "account_id, dim_id and window_seconds must be set")
|
||||
}
|
||||
windowStart := types.WindowStart(time.Now(), windowSeconds)
|
||||
return m.store.IncrementAgentNetworkConsumption(ctx, accountID, kind, dimID, windowSeconds, windowStart, tokensIn, tokensOut, costUSD)
|
||||
}
|
||||
|
||||
func (m *managerImpl) requirePermission(ctx context.Context, accountID, userID string, op operations.Operation) error {
|
||||
ok, _, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.AgentNetwork, op)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockManager struct{}
|
||||
|
||||
// NewManagerMock returns a no-op manager useful for tests.
|
||||
func NewManagerMock() Manager {
|
||||
return &mockManager{}
|
||||
}
|
||||
|
||||
func (*mockManager) GetAllProviders(_ context.Context, _, _ string) ([]*types.Provider, error) {
|
||||
return []*types.Provider{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetProvider(_ context.Context, _, _, _ string) (*types.Provider, error) {
|
||||
return &types.Provider{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) CreateProvider(_ context.Context, _ string, p *types.Provider, _ string) (*types.Provider, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (*mockManager) UpdateProvider(_ context.Context, _ string, p *types.Provider) (*types.Provider, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (*mockManager) DeleteProvider(_ context.Context, _, _, _ string) error { return nil }
|
||||
|
||||
func (*mockManager) GetAllPolicies(_ context.Context, _, _ string) ([]*types.Policy, error) {
|
||||
return []*types.Policy{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetPolicy(_ context.Context, _, _, _ string) (*types.Policy, error) {
|
||||
return &types.Policy{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) CreatePolicy(_ context.Context, _ string, p *types.Policy) (*types.Policy, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (*mockManager) UpdatePolicy(_ context.Context, _ string, p *types.Policy) (*types.Policy, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (*mockManager) DeletePolicy(_ context.Context, _, _, _ string) error { return nil }
|
||||
|
||||
func (*mockManager) GetAllGuardrails(_ context.Context, _, _ string) ([]*types.Guardrail, error) {
|
||||
return []*types.Guardrail{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetGuardrail(_ context.Context, _, _, _ string) (*types.Guardrail, error) {
|
||||
return &types.Guardrail{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) CreateGuardrail(_ context.Context, _ string, g *types.Guardrail) (*types.Guardrail, error) {
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (*mockManager) UpdateGuardrail(_ context.Context, _ string, g *types.Guardrail) (*types.Guardrail, error) {
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (*mockManager) DeleteGuardrail(_ context.Context, _, _, _ string) error { return nil }
|
||||
|
||||
func (*mockManager) GetAllBudgetRules(_ context.Context, _, _ string) ([]*types.AccountBudgetRule, error) {
|
||||
return []*types.AccountBudgetRule{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetBudgetRule(_ context.Context, _, _, _ string) (*types.AccountBudgetRule, error) {
|
||||
return &types.AccountBudgetRule{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) CreateBudgetRule(_ context.Context, _ string, r *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (*mockManager) UpdateBudgetRule(_ context.Context, _ string, r *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (*mockManager) DeleteBudgetRule(_ context.Context, _, _, _ string) error { return nil }
|
||||
|
||||
func (*mockManager) GetSettings(_ context.Context, _, _ string) (*types.Settings, error) {
|
||||
return nil, status.Errorf(status.NotFound, "agent network settings not found")
|
||||
}
|
||||
|
||||
func (*mockManager) UpdateSettings(_ context.Context, _ string, s *types.Settings) (*types.Settings, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (*mockManager) ListConsumption(_ context.Context, _, _ string) ([]*types.Consumption, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*mockManager) ListAccessLogs(_ context.Context, _, _ string, _ types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetUsageOverview(_ context.Context, _, _ string, _ types.AgentNetworkAccessLogFilter, _ types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*mockManager) StartAccessLogCleanup(_ context.Context, _ int) {}
|
||||
|
||||
func (*mockManager) RecordConsumption(_ context.Context, _ string, _ types.ConsumptionDimension, _ string, _, _, _ int64, _ float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*mockManager) RecordAccountBudgetUsage(_ context.Context, _, _ string, _ []string, _, _ int64, _ float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*mockManager) RecordUsage(_ context.Context, _ RecordUsageInput) error {
|
||||
return nil
|
||||
}
|
||||
660
management/internals/modules/agentnetwork/policyselect.go
Normal file
660
management/internals/modules/agentnetwork/policyselect.go
Normal file
@@ -0,0 +1,660 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// validateUsageDeltas rejects negative or non-finite usage counters before they
|
||||
// reach the consumption store, so a bad delta can't decrement or poison totals.
|
||||
// The store batch method enforces the same invariant; this is the manager-level
|
||||
// guard so direct callers fail fast with a clear error.
|
||||
func validateUsageDeltas(tokensIn, tokensOut int64, costUSD float64) error {
|
||||
if tokensIn < 0 || tokensOut < 0 || costUSD < 0 || math.IsNaN(costUSD) || math.IsInf(costUSD, 0) {
|
||||
return status.Errorf(status.InvalidArgument, "usage deltas must be non-negative and finite")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deny codes the proxy surfaces back to the caller when every
|
||||
// applicable policy is exhausted. The proxy converts these into
|
||||
// upstream-shaped error responses.
|
||||
const (
|
||||
//nolint:gosec // policy deny code label, not a credential
|
||||
denyCodeTokenCapExceeded = "llm_policy.token_cap_exceeded"
|
||||
//nolint:gosec // policy deny code label, not a credential
|
||||
denyCodeBudgetCapExceeded = "llm_policy.budget_cap_exceeded"
|
||||
//nolint:gosec // account deny code label, not a credential
|
||||
denyCodeAccountTokenCapExceeded = "llm_account.token_cap_exceeded"
|
||||
//nolint:gosec // account deny code label, not a credential
|
||||
denyCodeAccountBudgetCapExceeded = "llm_account.budget_cap_exceeded"
|
||||
)
|
||||
|
||||
// consumptionCache holds the consumption counters prefetched for one
|
||||
// policy-selection request, keyed by ConsumptionKey. A miss returns a zero
|
||||
// counter — the same contract the store's single-row getter uses for absent
|
||||
// rows — so the eval logic is identical whether a counter exists yet or not.
|
||||
type consumptionCache map[types.ConsumptionKey]*types.Consumption
|
||||
|
||||
func (c consumptionCache) get(accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds int64, windowStart time.Time) *types.Consumption {
|
||||
key := types.ConsumptionKey{Kind: kind, DimID: dimID, WindowSeconds: windowSeconds, WindowStartUTC: windowStart.UTC()}
|
||||
if row, ok := c[key]; ok && row != nil {
|
||||
return row
|
||||
}
|
||||
return &types.Consumption{
|
||||
AccountID: accountID,
|
||||
DimensionKind: kind,
|
||||
DimensionID: dimID,
|
||||
WindowSeconds: windowSeconds,
|
||||
WindowStartUTC: windowStart.UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// addLimitKeys records the user/group consumption keys a single enabled (token
|
||||
// or budget) limit window reads for the given attribution group, into a dedup
|
||||
// set. attrGroup may be empty (no group dimension applies).
|
||||
func addLimitKeys(set map[types.ConsumptionKey]struct{}, userID, attrGroup string, windowSeconds int64, now time.Time) {
|
||||
if windowSeconds <= 0 {
|
||||
return
|
||||
}
|
||||
ws := types.WindowStart(now, windowSeconds)
|
||||
if userID != "" {
|
||||
set[types.ConsumptionKey{Kind: types.DimensionUser, DimID: userID, WindowSeconds: windowSeconds, WindowStartUTC: ws}] = struct{}{}
|
||||
}
|
||||
if attrGroup != "" {
|
||||
set[types.ConsumptionKey{Kind: types.DimensionGroup, DimID: attrGroup, WindowSeconds: windowSeconds, WindowStartUTC: ws}] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// prefetchConsumption loads, in one store round-trip, every consumption counter
|
||||
// that the account-budget ceiling and the candidate policies will read while
|
||||
// scoring this request. This replaces the per-cap point reads the selector
|
||||
// previously issued one at a time (the N+1 on the hot path).
|
||||
func (m *managerImpl) prefetchConsumption(ctx context.Context, in PolicySelectionInput, rules []*types.AccountBudgetRule, candidates []*types.Policy, now time.Time) (consumptionCache, error) {
|
||||
set := make(map[types.ConsumptionKey]struct{})
|
||||
for _, p := range candidates {
|
||||
attr := lowestIntersect(p.SourceGroups, in.GroupIDs)
|
||||
if p.Limits.TokenLimit.Enabled {
|
||||
addLimitKeys(set, in.UserID, attr, p.Limits.TokenLimit.WindowSeconds, now)
|
||||
}
|
||||
if p.Limits.BudgetLimit.Enabled {
|
||||
addLimitKeys(set, in.UserID, attr, p.Limits.BudgetLimit.WindowSeconds, now)
|
||||
}
|
||||
}
|
||||
for _, r := range rules {
|
||||
if r == nil || !r.Enabled || !budgetRuleApplies(r, in) {
|
||||
continue
|
||||
}
|
||||
attr := lowestIntersect(r.TargetGroups, in.GroupIDs)
|
||||
if r.Limits.TokenLimit.Enabled {
|
||||
addLimitKeys(set, in.UserID, attr, r.Limits.TokenLimit.WindowSeconds, now)
|
||||
}
|
||||
if r.Limits.BudgetLimit.Enabled {
|
||||
addLimitKeys(set, in.UserID, attr, r.Limits.BudgetLimit.WindowSeconds, now)
|
||||
}
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return consumptionCache{}, nil
|
||||
}
|
||||
keys := make([]types.ConsumptionKey, 0, len(set))
|
||||
for k := range set {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
rows, err := m.store.GetAgentNetworkConsumptionBatch(ctx, store.LockingStrengthNone, in.AccountID, keys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch read consumption: %w", err)
|
||||
}
|
||||
return consumptionCache(rows), nil
|
||||
}
|
||||
|
||||
// SelectPolicyForRequest picks the policy that "pays" for the
|
||||
// incoming request. The chosen policy is the one with the largest
|
||||
// pool that still has headroom — drain the bigger bucket first,
|
||||
// fall through to the next-biggest only when the current one's
|
||||
// group cap or shared per-user cap is exhausted. This matches
|
||||
// operator intuition for layered tiers ("privileged group has the
|
||||
// 10k budget, regular group has 1k as the safety net") and avoids
|
||||
// the load-balancer flapping that fraction-based scoring produces
|
||||
// once any cap has been touched.
|
||||
//
|
||||
// Ordering across non-exhausted candidates:
|
||||
// 1. Policies with NO enabled caps (catch-all-allow) win over any
|
||||
// capped policy — operators who configure unlimited access
|
||||
// expect requests to attribute there until they explicitly add
|
||||
// caps.
|
||||
// 2. Larger group token cap wins.
|
||||
// 3. Larger group budget USD cap wins.
|
||||
// 4. Larger user token cap wins.
|
||||
// 5. Larger user budget USD cap wins.
|
||||
// 6. Older created_at wins (deterministic final tiebreak so
|
||||
// multi-node selection converges).
|
||||
//
|
||||
// Returns Allow=true with empty SelectedPolicyID when no policy in
|
||||
// the account targets the (provider, caller-groups) combination —
|
||||
// llm_router is the gate that owns "no policy authorises this
|
||||
// request" semantics; this function trusts that authorisation has
|
||||
// already happened upstream and only does the limit-aware
|
||||
// attribution.
|
||||
func (m *managerImpl) SelectPolicyForRequest(ctx context.Context, in PolicySelectionInput) (*PolicySelectionResult, error) {
|
||||
if in.AccountID == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "account_id is required")
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, in.AccountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list account budget rules: %w", err)
|
||||
}
|
||||
policies, err := m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, in.AccountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list account policies: %w", err)
|
||||
}
|
||||
candidates := filterApplicablePolicies(policies, in)
|
||||
|
||||
// Prefetch every consumption counter the ceiling + candidate policies will
|
||||
// read, in a single store round-trip, then score against the cache.
|
||||
cache, err := m.prefetchConsumption(ctx, in, rules, candidates, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Account-level budget rules are an always-on ceiling, evaluated
|
||||
// independently of policy selection (they bind even for catch-all-allow
|
||||
// policies or requests that match no policy). All applicable rules must
|
||||
// pass — this is where min-wins lives.
|
||||
if deny, code, reason := checkAccountBudget(in, rules, cache, now); deny {
|
||||
return &PolicySelectionResult{Allow: false, DenyCode: code, DenyReason: reason}, nil
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return &PolicySelectionResult{Allow: true}, nil
|
||||
}
|
||||
scored, lastDenyCode, lastDenyReason := scoreCandidates(in, candidates, cache, now)
|
||||
if len(scored) == 0 {
|
||||
return &PolicySelectionResult{
|
||||
Allow: false,
|
||||
DenyCode: lastDenyCode,
|
||||
DenyReason: lastDenyReason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sort.SliceStable(scored, func(i, j int) bool {
|
||||
// Catch-all-allow (no caps configured) wins outright over
|
||||
// any capped policy.
|
||||
iNoCap := isUncapped(scored[i].policy)
|
||||
jNoCap := isUncapped(scored[j].policy)
|
||||
if iNoCap != jNoCap {
|
||||
return iNoCap
|
||||
}
|
||||
// Bigger pool drains first. Group caps dominate (shared
|
||||
// across the group) before individual caps.
|
||||
if a, b := groupCapTokens(scored[i].policy), groupCapTokens(scored[j].policy); a != b {
|
||||
return a > b
|
||||
}
|
||||
if a, b := groupCapBudgetUsd(scored[i].policy), groupCapBudgetUsd(scored[j].policy); a != b {
|
||||
return a > b
|
||||
}
|
||||
if a, b := userCapTokens(scored[i].policy), userCapTokens(scored[j].policy); a != b {
|
||||
return a > b
|
||||
}
|
||||
if a, b := userCapBudgetUsd(scored[i].policy), userCapBudgetUsd(scored[j].policy); a != b {
|
||||
return a > b
|
||||
}
|
||||
return scored[i].policy.CreatedAt.Before(scored[j].policy.CreatedAt)
|
||||
})
|
||||
|
||||
winner := scored[0]
|
||||
return &PolicySelectionResult{
|
||||
Allow: true,
|
||||
SelectedPolicyID: winner.policy.ID,
|
||||
AttributionGroupID: winner.attributionGroup,
|
||||
WindowSeconds: winner.windowSeconds,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// filterApplicablePolicies returns the enabled policies that target
|
||||
// the requested provider and have at least one of the caller's groups
|
||||
// in their source_groups. Caller's group set is matched
|
||||
// case-sensitively against policy.SourceGroups.
|
||||
func filterApplicablePolicies(policies []*types.Policy, in PolicySelectionInput) []*types.Policy {
|
||||
if len(policies) == 0 {
|
||||
return nil
|
||||
}
|
||||
groupSet := make(map[string]struct{}, len(in.GroupIDs))
|
||||
for _, g := range in.GroupIDs {
|
||||
if g != "" {
|
||||
groupSet[g] = struct{}{}
|
||||
}
|
||||
}
|
||||
out := make([]*types.Policy, 0, len(policies))
|
||||
for _, p := range policies {
|
||||
if p == nil || !p.Enabled {
|
||||
continue
|
||||
}
|
||||
if !sliceContains(p.DestinationProviderIDs, in.ProviderID) {
|
||||
continue
|
||||
}
|
||||
if !anyGroupMatches(p.SourceGroups, groupSet) {
|
||||
continue
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// candidate is the per-policy intermediate the selector ranks. A
|
||||
// policy that's been exhausted on any enabled cap never makes it
|
||||
// into this slice; the selector's deny envelope carries the latest
|
||||
// exhaustion's reason out separately.
|
||||
type candidate struct {
|
||||
policy *types.Policy
|
||||
attributionGroup string
|
||||
windowSeconds int64
|
||||
}
|
||||
|
||||
// scoreCandidates evaluates every applicable policy against the
|
||||
// caller's current consumption. Exhausted policies are filtered out
|
||||
// of the returned slice; the most recent exhaustion's deny code +
|
||||
// human reason is returned alongside so the caller can surface it
|
||||
// when no candidate survives.
|
||||
func scoreCandidates(
|
||||
in PolicySelectionInput,
|
||||
candidates []*types.Policy,
|
||||
cache consumptionCache,
|
||||
now time.Time,
|
||||
) ([]candidate, string, string) {
|
||||
out := make([]candidate, 0, len(candidates))
|
||||
var lastDenyCode, lastDenyReason string
|
||||
|
||||
for _, p := range candidates {
|
||||
c, exhausted, denyCode, denyReason := scoreOne(in, p, cache, now)
|
||||
if exhausted {
|
||||
lastDenyCode = denyCode
|
||||
lastDenyReason = denyReason
|
||||
continue
|
||||
}
|
||||
out = append(out, c)
|
||||
}
|
||||
return out, lastDenyCode, lastDenyReason
|
||||
}
|
||||
|
||||
// scoreOne checks a single policy for cap exhaustion. Returns the
|
||||
// candidate envelope when the policy still has headroom on every
|
||||
// enabled cap; reports exhausted=true with a deny code naming the
|
||||
// offending cap kind otherwise.
|
||||
func scoreOne(
|
||||
in PolicySelectionInput,
|
||||
p *types.Policy,
|
||||
cache consumptionCache,
|
||||
now time.Time,
|
||||
) (candidate, bool, string, string) {
|
||||
attrGroup := lowestIntersect(p.SourceGroups, in.GroupIDs)
|
||||
c := candidate{
|
||||
policy: p,
|
||||
attributionGroup: attrGroup,
|
||||
windowSeconds: effectiveWindowSeconds(p),
|
||||
}
|
||||
|
||||
if p.Limits.TokenLimit.Enabled && p.Limits.TokenLimit.WindowSeconds > 0 {
|
||||
if exhausted, reason := evalTokenCap(cache, in.AccountID, in.UserID, attrGroup, p.Limits.TokenLimit, now, "policy "+p.ID); exhausted {
|
||||
return candidate{}, true, denyCodeTokenCapExceeded, reason
|
||||
}
|
||||
}
|
||||
|
||||
if p.Limits.BudgetLimit.Enabled && p.Limits.BudgetLimit.WindowSeconds > 0 {
|
||||
if exhausted, reason := evalBudgetCap(cache, in.AccountID, in.UserID, attrGroup, p.Limits.BudgetLimit, now, "policy "+p.ID); exhausted {
|
||||
return candidate{}, true, denyCodeBudgetCapExceeded, reason
|
||||
}
|
||||
}
|
||||
|
||||
return c, false, "", ""
|
||||
}
|
||||
|
||||
// evalTokenCap reports whether the token limit is already exhausted for the
|
||||
// caller in its own window. attrGroup may be empty (no group dimension applies).
|
||||
// label identifies the cap source ("policy <id>" or "account rule <id>") for the
|
||||
// deny reason. It is the shared primitive behind both policy and account-rule
|
||||
// enforcement.
|
||||
func evalTokenCap(
|
||||
cache consumptionCache,
|
||||
accountID, userID, attrGroup string,
|
||||
tl types.PolicyTokenLimit,
|
||||
now time.Time,
|
||||
label string,
|
||||
) (bool, string) {
|
||||
windowStart := types.WindowStart(now, tl.WindowSeconds)
|
||||
|
||||
if tl.UserCap > 0 && userID != "" {
|
||||
row := cache.get(accountID, types.DimensionUser, userID, tl.WindowSeconds, windowStart)
|
||||
used := row.TokensInput + row.TokensOutput
|
||||
if used >= tl.UserCap {
|
||||
return true, fmt.Sprintf("user token cap exhausted on %s (used %d of %d)", label, used, tl.UserCap)
|
||||
}
|
||||
}
|
||||
|
||||
if tl.GroupCap > 0 && attrGroup != "" {
|
||||
row := cache.get(accountID, types.DimensionGroup, attrGroup, tl.WindowSeconds, windowStart)
|
||||
used := row.TokensInput + row.TokensOutput
|
||||
if used >= tl.GroupCap {
|
||||
return true, fmt.Sprintf("group token cap exhausted on %s (used %d of %d)", label, used, tl.GroupCap)
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// evalBudgetCap is the budget (USD) counterpart of evalTokenCap.
|
||||
func evalBudgetCap(
|
||||
cache consumptionCache,
|
||||
accountID, userID, attrGroup string,
|
||||
bl types.PolicyBudgetLimit,
|
||||
now time.Time,
|
||||
label string,
|
||||
) (bool, string) {
|
||||
windowStart := types.WindowStart(now, bl.WindowSeconds)
|
||||
|
||||
if bl.UserCapUsd > 0 && userID != "" {
|
||||
row := cache.get(accountID, types.DimensionUser, userID, bl.WindowSeconds, windowStart)
|
||||
if row.CostUSD >= bl.UserCapUsd {
|
||||
return true, fmt.Sprintf("user budget cap exhausted on %s (used $%.4f of $%.4f)", label, row.CostUSD, bl.UserCapUsd)
|
||||
}
|
||||
}
|
||||
|
||||
if bl.GroupCapUsd > 0 && attrGroup != "" {
|
||||
row := cache.get(accountID, types.DimensionGroup, attrGroup, bl.WindowSeconds, windowStart)
|
||||
if row.CostUSD >= bl.GroupCapUsd {
|
||||
return true, fmt.Sprintf("group budget cap exhausted on %s (used $%.4f of $%.4f)", label, row.CostUSD, bl.GroupCapUsd)
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// checkAccountBudget evaluates every applicable account-level budget rule as an
|
||||
// all-must-pass ceiling. A rule applies when the caller is in its TargetUsers,
|
||||
// one of its TargetGroups, or it has no targets at all (account-wide). Returns
|
||||
// deny=true with an llm_account.* code on the first exhausted rule. Group caps
|
||||
// attribute to the lowest intersecting group (the same model policies use), so
|
||||
// multi-group behavior is unchanged.
|
||||
func checkAccountBudget(in PolicySelectionInput, rules []*types.AccountBudgetRule, cache consumptionCache, now time.Time) (bool, string, string) {
|
||||
for _, r := range rules {
|
||||
if r == nil || !r.Enabled || !budgetRuleApplies(r, in) {
|
||||
continue
|
||||
}
|
||||
attrGroup := lowestIntersect(r.TargetGroups, in.GroupIDs)
|
||||
label := "account rule " + r.ID
|
||||
|
||||
if r.Limits.TokenLimit.Enabled && r.Limits.TokenLimit.WindowSeconds > 0 {
|
||||
if exhausted, reason := evalTokenCap(cache, in.AccountID, in.UserID, attrGroup, r.Limits.TokenLimit, now, label); exhausted {
|
||||
return true, denyCodeAccountTokenCapExceeded, reason
|
||||
}
|
||||
}
|
||||
|
||||
if r.Limits.BudgetLimit.Enabled && r.Limits.BudgetLimit.WindowSeconds > 0 {
|
||||
if exhausted, reason := evalBudgetCap(cache, in.AccountID, in.UserID, attrGroup, r.Limits.BudgetLimit, now, label); exhausted {
|
||||
return true, denyCodeAccountBudgetCapExceeded, reason
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
// budgetRuleApplies reports whether an account budget rule binds the caller:
|
||||
// a direct user match, a group intersection, or an untargeted (account-wide)
|
||||
// rule.
|
||||
func budgetRuleApplies(r *types.AccountBudgetRule, in PolicySelectionInput) bool {
|
||||
if len(r.TargetUsers) == 0 && len(r.TargetGroups) == 0 {
|
||||
return true
|
||||
}
|
||||
if in.UserID != "" && sliceContains(r.TargetUsers, in.UserID) {
|
||||
return true
|
||||
}
|
||||
groupSet := make(map[string]struct{}, len(in.GroupIDs))
|
||||
for _, g := range in.GroupIDs {
|
||||
if g != "" {
|
||||
groupSet[g] = struct{}{}
|
||||
}
|
||||
}
|
||||
return anyGroupMatches(r.TargetGroups, groupSet)
|
||||
}
|
||||
|
||||
// RecordAccountBudgetUsage fans the served request's usage out to every
|
||||
// applicable account budget rule's own (dimension, window) counter. The user
|
||||
// dimension is always booked when a rule has a user-applicable cap; the group
|
||||
// dimension books against the rule's lowest intersecting group. This runs
|
||||
// alongside the policy-window record so account ceilings accumulate in their own
|
||||
// windows (commonly monthly) independently of the per-policy window.
|
||||
func (m *managerImpl) RecordAccountBudgetUsage(ctx context.Context, accountID, userID string, groupIDs []string, tokensIn, tokensOut int64, costUSD float64) error {
|
||||
if accountID == "" {
|
||||
return status.Errorf(status.InvalidArgument, "account_id is required")
|
||||
}
|
||||
if err := validateUsageDeltas(tokensIn, tokensOut, costUSD); err != nil {
|
||||
return err
|
||||
}
|
||||
rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list account budget rules: %w", err)
|
||||
}
|
||||
set := make(map[types.ConsumptionKey]struct{})
|
||||
addAccountBudgetKeys(set, PolicySelectionInput{AccountID: accountID, UserID: userID, GroupIDs: groupIDs}, rules, time.Now().UTC())
|
||||
if len(set) == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.store.IncrementAgentNetworkConsumptionBatch(ctx, accountID, keysSlice(set), tokensIn, tokensOut, costUSD)
|
||||
}
|
||||
|
||||
// RecordUsageInput carries everything RecordUsage books for one served request.
|
||||
type RecordUsageInput struct {
|
||||
AccountID string
|
||||
UserID string
|
||||
AttributionGroupID string // selected policy's attribution group (policy window)
|
||||
GroupIDs []string
|
||||
WindowSeconds int64 // selected policy's window; 0 means no policy cap
|
||||
TokensIn int64
|
||||
TokensOut int64
|
||||
CostUSD float64
|
||||
}
|
||||
|
||||
// RecordUsage books a served request's usage against every counter it touches —
|
||||
// the selected policy's per-(user, group) window plus every applicable account
|
||||
// budget rule's own window — deduplicated and written in a single transaction.
|
||||
// Two counters that collapse to the same (dimension, window) tuple are booked
|
||||
// once, so a single request can never double-count against one cap.
|
||||
func (m *managerImpl) RecordUsage(ctx context.Context, in RecordUsageInput) error {
|
||||
if in.AccountID == "" {
|
||||
return status.Errorf(status.InvalidArgument, "account_id is required")
|
||||
}
|
||||
if err := validateUsageDeltas(in.TokensIn, in.TokensOut, in.CostUSD); err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
set := make(map[types.ConsumptionKey]struct{})
|
||||
|
||||
// Policy-window dimensions are booked only when a policy cap bound this
|
||||
// request (window > 0). A zero window means catch-all-allow / no policy cap;
|
||||
// the account fan-out below still books against the budget rules' windows.
|
||||
if in.WindowSeconds > 0 {
|
||||
addLimitKeys(set, in.UserID, in.AttributionGroupID, in.WindowSeconds, now)
|
||||
}
|
||||
|
||||
rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, in.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list account budget rules: %w", err)
|
||||
}
|
||||
addAccountBudgetKeys(set, PolicySelectionInput{AccountID: in.AccountID, UserID: in.UserID, GroupIDs: in.GroupIDs}, rules, now)
|
||||
|
||||
if len(set) == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.store.IncrementAgentNetworkConsumptionBatch(ctx, in.AccountID, keysSlice(set), in.TokensIn, in.TokensOut, in.CostUSD)
|
||||
}
|
||||
|
||||
// addAccountBudgetKeys adds the (dimension, window) keys a served request books
|
||||
// against every applicable account budget rule into the dedup set.
|
||||
func addAccountBudgetKeys(set map[types.ConsumptionKey]struct{}, in PolicySelectionInput, rules []*types.AccountBudgetRule, now time.Time) {
|
||||
for _, r := range rules {
|
||||
if r == nil || !r.Enabled || !budgetRuleApplies(r, in) {
|
||||
continue
|
||||
}
|
||||
attrGroup := lowestIntersect(r.TargetGroups, in.GroupIDs)
|
||||
for _, window := range ruleWindows(r) {
|
||||
addLimitKeys(set, in.UserID, attrGroup, window, now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// keysSlice flattens a ConsumptionKey set into a slice.
|
||||
func keysSlice(set map[types.ConsumptionKey]struct{}) []types.ConsumptionKey {
|
||||
keys := make([]types.ConsumptionKey, 0, len(set))
|
||||
for k := range set {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// ruleWindows returns the distinct enabled window lengths a budget rule books
|
||||
// against (token window and/or budget window, deduplicated).
|
||||
func ruleWindows(r *types.AccountBudgetRule) []int64 {
|
||||
var windows []int64
|
||||
if r.Limits.TokenLimit.Enabled && r.Limits.TokenLimit.WindowSeconds > 0 {
|
||||
windows = append(windows, r.Limits.TokenLimit.WindowSeconds)
|
||||
}
|
||||
if r.Limits.BudgetLimit.Enabled && r.Limits.BudgetLimit.WindowSeconds > 0 {
|
||||
bw := r.Limits.BudgetLimit.WindowSeconds
|
||||
if len(windows) == 0 || windows[0] != bw {
|
||||
windows = append(windows, bw)
|
||||
}
|
||||
}
|
||||
return windows
|
||||
}
|
||||
|
||||
// effectiveWindowSeconds returns the window length the proxy should
|
||||
// hand back to RecordLLMUsage. When both halves are enabled with
|
||||
// different windows, token_limit wins (the more common config); when
|
||||
// only one is enabled that one wins; when neither is enabled the
|
||||
// returned value is 0 — RecordLLMUsage treats 0 as "no limit
|
||||
// tracking" and skips the increment, which is the right pass-through
|
||||
// for catch-all-allow policies with no caps configured.
|
||||
func effectiveWindowSeconds(p *types.Policy) int64 {
|
||||
if p.Limits.TokenLimit.Enabled && p.Limits.TokenLimit.WindowSeconds > 0 {
|
||||
return p.Limits.TokenLimit.WindowSeconds
|
||||
}
|
||||
if p.Limits.BudgetLimit.Enabled && p.Limits.BudgetLimit.WindowSeconds > 0 {
|
||||
return p.Limits.BudgetLimit.WindowSeconds
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// lowestIntersect returns the lowest-by-string-sort element of
|
||||
// callerGroups ∩ sourceGroups. Empty when the intersection is empty.
|
||||
// Lowest is deterministic so multi-node selection converges.
|
||||
func lowestIntersect(sourceGroups, callerGroups []string) string {
|
||||
if len(sourceGroups) == 0 || len(callerGroups) == 0 {
|
||||
return ""
|
||||
}
|
||||
srcSet := make(map[string]struct{}, len(sourceGroups))
|
||||
for _, g := range sourceGroups {
|
||||
srcSet[g] = struct{}{}
|
||||
}
|
||||
var best string
|
||||
for _, g := range callerGroups {
|
||||
if _, ok := srcSet[g]; !ok {
|
||||
continue
|
||||
}
|
||||
if best == "" || g < best {
|
||||
best = g
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func anyGroupMatches(sourceGroups []string, callerSet map[string]struct{}) bool {
|
||||
for _, g := range sourceGroups {
|
||||
if _, ok := callerSet[g]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isUncapped reports whether a policy has any enabled cap with a
|
||||
// positive limit value. Mirrors the eval functions' guards: a policy
|
||||
// with token_limit.enabled=true but every cap value at 0 still
|
||||
// counts as uncapped because the eval would query nothing and bind
|
||||
// nothing.
|
||||
func isUncapped(p *types.Policy) bool {
|
||||
tl := p.Limits.TokenLimit
|
||||
if tl.Enabled && tl.WindowSeconds > 0 && (tl.GroupCap > 0 || tl.UserCap > 0) {
|
||||
return false
|
||||
}
|
||||
bl := p.Limits.BudgetLimit
|
||||
if bl.Enabled && bl.WindowSeconds > 0 && (bl.GroupCapUsd > 0 || bl.UserCapUsd > 0) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// groupCapTokens returns the policy's group-token cap when the token
|
||||
// limit is enabled, zero otherwise. Drives the primary "bigger pool
|
||||
// first" sort.
|
||||
func groupCapTokens(p *types.Policy) int64 {
|
||||
if p.Limits.TokenLimit.Enabled {
|
||||
return p.Limits.TokenLimit.GroupCap
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// groupCapBudgetUsd returns the policy's group-budget cap in USD
|
||||
// when the budget limit is enabled, zero otherwise. Secondary sort
|
||||
// key after token group cap so budget-only policies still order
|
||||
// predictably.
|
||||
func groupCapBudgetUsd(p *types.Policy) float64 {
|
||||
if p.Limits.BudgetLimit.Enabled {
|
||||
return p.Limits.BudgetLimit.GroupCapUsd
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// userCapTokens returns the policy's per-user token cap when the
|
||||
// token limit is enabled, zero otherwise. Tertiary sort key, used
|
||||
// when group caps tie or are absent.
|
||||
func userCapTokens(p *types.Policy) int64 {
|
||||
if p.Limits.TokenLimit.Enabled {
|
||||
return p.Limits.TokenLimit.UserCap
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// userCapBudgetUsd returns the policy's per-user budget cap in USD
|
||||
// when the budget limit is enabled, zero otherwise. Quaternary sort
|
||||
// key for budget-only policies whose group caps tie or are absent.
|
||||
func userCapBudgetUsd(p *types.Policy) float64 {
|
||||
if p.Limits.BudgetLimit.Enabled {
|
||||
return p.Limits.BudgetLimit.UserCapUsd
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func sliceContains(haystack []string, needle string) bool {
|
||||
for _, v := range haystack {
|
||||
if v == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// mockManager fallback so tests that don't care about selection still
|
||||
// compile.
|
||||
func (*mockManager) SelectPolicyForRequest(_ context.Context, _ PolicySelectionInput) (*PolicySelectionResult, error) {
|
||||
return &PolicySelectionResult{Allow: true}, nil
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// GC-2 no-mock enforcement tests for the account-budget ceiling. They drive the
|
||||
// real store + real consumption accounting through SelectPolicyForRequest and
|
||||
// RecordAccountBudgetUsage, asserting min-wins (account binds independently of
|
||||
// policy), targeting (groups + direct users), and the record fan-out.
|
||||
|
||||
func accountWideUserTokenRule(id string, userCap, window int64) *types.AccountBudgetRule {
|
||||
r := types.NewAccountBudgetRule(realSelectAccount)
|
||||
r.ID = id
|
||||
r.Limits.TokenLimit = types.PolicyTokenLimit{Enabled: true, UserCap: userCap, WindowSeconds: window}
|
||||
return r
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AccountCeilingBindsEvenWithUncappedPolicy proves
|
||||
// min-wins: the account user ceiling denies once exhausted even though a
|
||||
// catch-all-allow (uncapped) policy would otherwise pass the request. The
|
||||
// account gate runs independently of and ahead of policy selection.
|
||||
func TestSelectPolicy_RealStore_AccountCeilingBindsEvenWithUncappedPolicy(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// An uncapped (catch-all-allow) policy: enabled token limit, zero caps.
|
||||
uncapped := capPolicy("pol-open", realSelectAccount, []string{"grp-eng"}, "prov-1", 0, 86_400)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, uncapped))
|
||||
|
||||
// Account-wide user ceiling of 100 tokens in an hourly window.
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-1", 100, 3_600)))
|
||||
|
||||
in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"}
|
||||
|
||||
// Fresh: account ceiling has headroom, uncapped policy wins.
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "fresh account ceiling must allow")
|
||||
|
||||
// Drain the account user ceiling via the fan-out path.
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", []string{"grp-eng"}, 100, 0, 0))
|
||||
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "account ceiling must deny even though the policy is uncapped (min-wins)")
|
||||
assert.Equal(t, denyCodeAccountTokenCapExceeded, res.DenyCode, "deny must carry the llm_account.* code")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AccountGroupCeiling proves a group-targeted rule
|
||||
// binds the caller's group dimension.
|
||||
func TestSelectPolicy_RealStore_AccountGroupCeiling(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
rule := types.NewAccountBudgetRule(realSelectAccount)
|
||||
rule.ID = "ainbud-grp"
|
||||
rule.TargetGroups = []string{"grp-eng"}
|
||||
rule.Limits.BudgetLimit = types.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 5.0, WindowSeconds: 2_592_000}
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, rule))
|
||||
|
||||
in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"}
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "fresh group ceiling must allow")
|
||||
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", []string{"grp-eng"}, 0, 0, 5.0))
|
||||
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "group budget ceiling must deny once spent")
|
||||
assert.Equal(t, denyCodeAccountBudgetCapExceeded, res.DenyCode, "account budget deny code")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AccountTargetUsersBindsOnlyThatUser proves a
|
||||
// TargetUsers rule tightens only the named user, leaving others unbound.
|
||||
func TestSelectPolicy_RealStore_AccountTargetUsersBindsOnlyThatUser(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
rule := types.NewAccountBudgetRule(realSelectAccount)
|
||||
rule.ID = "ainbud-alice"
|
||||
rule.TargetUsers = []string{"alice"}
|
||||
rule.Limits.TokenLimit = types.PolicyTokenLimit{Enabled: true, UserCap: 100, WindowSeconds: 3_600}
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, rule))
|
||||
|
||||
// Record alice's usage to the rule window.
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "alice", nil, 100, 0, 0))
|
||||
|
||||
aliceIn := PolicySelectionInput{AccountID: realSelectAccount, UserID: "alice", ProviderID: "prov-1"}
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, aliceIn)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "alice is bound by the TargetUsers rule and is exhausted")
|
||||
|
||||
bobIn := PolicySelectionInput{AccountID: realSelectAccount, UserID: "bob", ProviderID: "prov-1"}
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, bobIn)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "bob is not in TargetUsers, so the rule must not bind him")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AccountRuleRecordsToOwnWindow proves the record
|
||||
// fan-out books usage in the rule's own window (distinct from any policy
|
||||
// window), so the account ceiling accumulates independently.
|
||||
func TestSelectPolicy_RealStore_AccountRuleRecordsToOwnWindow(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-w", 100, 3_600)))
|
||||
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", nil, 60, 0, 0))
|
||||
|
||||
// Same user, a policy-style daily window must NOT see the account-window
|
||||
// usage — windows are independent counters.
|
||||
dailyRow, err := s.GetAgentNetworkConsumption(ctx, store.LockingStrengthNone, realSelectAccount, types.DimensionUser, "user-1", 86_400, types.WindowStart(time.Now().UTC(), 86_400))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), dailyRow.TokensInput+dailyRow.TokensOutput, "daily window must be untouched by the hourly account-rule record")
|
||||
|
||||
// A second record pushes the hourly account window to its cap → deny.
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", nil, 40, 0, 0))
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", ProviderID: "prov-1"})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "100 tokens recorded in the rule's hourly window must exhaust the 100-token ceiling")
|
||||
assert.Equal(t, denyCodeAccountTokenCapExceeded, res.DenyCode, "account token deny code")
|
||||
}
|
||||
|
||||
// TestRecordUsage_RealStore_BooksPolicyAndAccountWindows proves the batched
|
||||
// post-flight write books the selected policy's window AND every applicable
|
||||
// account rule's (independent) window in a single call — the #6 batched-write
|
||||
// path the proxy's RecordLLMUsage RPC now uses.
|
||||
func TestRecordUsage_RealStore_BooksPolicyAndAccountWindows(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Policy: 100-token group cap on a daily window. Account rule: 100-token
|
||||
// user ceiling on an hourly window — an independent counter.
|
||||
policy := capPolicy("pol-1", realSelectAccount, []string{"grp-eng"}, "prov-1", 100, 86_400)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, policy))
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-1", 100, 3_600)))
|
||||
|
||||
in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"}
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
require.True(t, res.Allow)
|
||||
require.Equal(t, "pol-1", res.SelectedPolicyID)
|
||||
|
||||
// One batched record books the policy window (group + user @86400) and the
|
||||
// account rule window (user @3600) atomically.
|
||||
require.NoError(t, mgr.RecordUsage(ctx, RecordUsageInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
AttributionGroupID: res.AttributionGroupID,
|
||||
GroupIDs: []string{"grp-eng"},
|
||||
WindowSeconds: res.WindowSeconds,
|
||||
TokensIn: 100,
|
||||
}))
|
||||
|
||||
// The next selection denies — the account hourly ceiling binds first.
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "usage booked by RecordUsage must enforce on the next request")
|
||||
|
||||
// Prove BOTH windows were booked in the one call via a direct batch read.
|
||||
now := time.Now().UTC()
|
||||
userKey := types.ConsumptionKey{Kind: types.DimensionUser, DimID: "user-1", WindowSeconds: 3_600, WindowStartUTC: types.WindowStart(now, 3_600)}
|
||||
groupKey := types.ConsumptionKey{Kind: types.DimensionGroup, DimID: "grp-eng", WindowSeconds: 86_400, WindowStartUTC: types.WindowStart(now, 86_400)}
|
||||
rows, err := s.GetAgentNetworkConsumptionBatch(ctx, store.LockingStrengthNone, realSelectAccount, []types.ConsumptionKey{userKey, groupKey})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, rows, userKey, "account rule user/hourly window booked")
|
||||
require.Contains(t, rows, groupKey, "policy group/daily window booked")
|
||||
assert.Equal(t, int64(100), rows[userKey].TokensInput, "account hourly user counter")
|
||||
assert.Equal(t, int64(100), rows[groupKey].TokensInput, "policy daily group counter")
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// This file is the no-mock regression guard for policy limit enforcement.
|
||||
// policyselect_test.go pins the same behavior through a gomock store with
|
||||
// explicit call-sequence expectations — brittle precisely where the upcoming
|
||||
// account-budget work (GC-2) refactors the cap-eval primitive and adds an
|
||||
// account-level gate. These tests drive the REAL sqlite store + REAL
|
||||
// consumption accounting and assert observable behavior (allow / deny /
|
||||
// selection / attribution), not which store methods get called. They must keep
|
||||
// passing unchanged after GC-2 lands, which is what proves "current behavior is
|
||||
// not changed."
|
||||
|
||||
const realSelectAccount = "acc-realselect-1"
|
||||
|
||||
// newRealSelectorMgr builds a managerImpl backed by a real sqlite test store.
|
||||
func newRealSelectorMgr(t *testing.T) (*managerImpl, store.Store) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
t.Cleanup(cleanup)
|
||||
return &managerImpl{store: s}, s
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_NoApplicablePolicies pins the pass-through:
|
||||
// nothing targets the (provider, groups) combination, so the selector allows
|
||||
// without attribution or consumption tracking.
|
||||
func TestSelectPolicy_RealStore_NoApplicablePolicies(t *testing.T) {
|
||||
mgr, _ := newRealSelectorMgr(t)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-x"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "no applicable policy must pass through as allow")
|
||||
assert.Empty(t, res.SelectedPolicyID, "no selection when nothing applies")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AllowAndLowestGroupAttribution pins the v1
|
||||
// attribution rule (lowest intersecting group by string sort) through the
|
||||
// real store, with a fresh (zero) consumption row.
|
||||
func TestSelectPolicy_RealStore_AllowAndLowestGroupAttribution(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
p := capPolicy("pol-A", realSelectAccount, []string{"grp-zz", "grp-aa", "grp-mm"}, "prov-1", 10_000, 86_400)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p))
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-zz", "grp-aa", "grp-mm"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "fresh state under cap must allow")
|
||||
assert.Equal(t, "pol-A", res.SelectedPolicyID, "only applicable policy must be selected")
|
||||
assert.Equal(t, "grp-aa", res.AttributionGroupID, "lowest-by-sort intersecting group must win")
|
||||
assert.Equal(t, int64(86_400), res.WindowSeconds, "selected policy's window must be returned")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_LargerPoolWins_FallsThroughWhenExhausted pins the
|
||||
// core selection behavior end to end. The two policies bind DISTINCT groups so
|
||||
// they read separate counters — the only shape where fall-through actually
|
||||
// yields headroom (policies on the same group share one counter, as
|
||||
// policyselect_test.go notes). Larger pool wins fresh; after real consumption
|
||||
// drains the larger group, selection falls through to the smaller; once both
|
||||
// counters are exhausted the request is denied.
|
||||
func TestSelectPolicy_RealStore_LargerPoolWins_FallsThroughWhenExhausted(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
tight := capPolicy("pol-tight", realSelectAccount, []string{"grp-tight"}, "prov-1", 100, 86_400)
|
||||
tight.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
wide := capPolicy("pol-wide", realSelectAccount, []string{"grp-wide"}, "prov-1", 10_000, 86_400)
|
||||
wide.CreatedAt = time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, tight))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, wide))
|
||||
|
||||
// Caller is in both groups, so both policies apply with independent counters.
|
||||
in := PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-tight", "grp-wide"},
|
||||
ProviderID: "prov-1",
|
||||
}
|
||||
|
||||
// Fresh: larger pool wins.
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-wide", res.SelectedPolicyID, "larger pool drains first")
|
||||
|
||||
// Drain only the wide group's counter to its cap.
|
||||
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-wide", 86_400, 10_000, 0, 0))
|
||||
|
||||
// Wide exhausted, tight's separate counter is fresh → fall through to tight.
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "tight pool has its own untouched counter")
|
||||
assert.Equal(t, "pol-tight", res.SelectedPolicyID, "selection falls through to the smaller pool once the larger is exhausted")
|
||||
|
||||
// Drain the tight group's counter too → both exhausted → deny.
|
||||
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-tight", 86_400, 100, 0, 0))
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "both group counters exhausted must deny")
|
||||
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode, "deny code names the offending cap kind")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_BudgetCapDenies pins budget (USD) enforcement
|
||||
// through the real store: once recorded cost reaches the cap, deny.
|
||||
func TestSelectPolicy_RealStore_BudgetCapDenies(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
p := &types.Policy{
|
||||
ID: "pol-budget",
|
||||
AccountID: realSelectAccount,
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-eng"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
BudgetLimit: types.PolicyBudgetLimit{
|
||||
Enabled: true,
|
||||
GroupCapUsd: 5.0,
|
||||
WindowSeconds: 86_400,
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p))
|
||||
|
||||
in := PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-eng"},
|
||||
ProviderID: "prov-1",
|
||||
}
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "fresh budget must allow")
|
||||
|
||||
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-eng", 86_400, 0, 0, 5.0))
|
||||
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "cost at the cap must deny")
|
||||
assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode, "budget deny code must be surfaced")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_GroupCounterSharedAcrossPolicies pins that two
|
||||
// policies on the same group+window read one shared consumption counter: usage
|
||||
// recorded once is visible to both, so exhausting the group budget denies
|
||||
// regardless of which policy would attribute.
|
||||
func TestSelectPolicy_RealStore_GroupCounterSharedAcrossPolicies(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
a := capPolicy("pol-a", realSelectAccount, []string{"grp-eng"}, "prov-1", 1_000, 86_400)
|
||||
b := capPolicy("pol-b", realSelectAccount, []string{"grp-eng"}, "prov-1", 1_000, 86_400)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, a))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, b))
|
||||
|
||||
in := PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-eng"},
|
||||
ProviderID: "prov-1",
|
||||
}
|
||||
|
||||
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-eng", 86_400, 1_000, 0, 0))
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "shared group counter at cap denies both equal policies")
|
||||
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode, "token deny code on the shared counter")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_DisabledPolicyIgnored pins that a disabled policy
|
||||
// is invisible to selection even when it otherwise matches.
|
||||
func TestSelectPolicy_RealStore_DisabledPolicyIgnored(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
p := capPolicy("pol-disabled", realSelectAccount, []string{"grp-eng"}, "prov-1", 10_000, 86_400)
|
||||
p.Enabled = false
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p))
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-eng"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "no enabled policy applies → pass-through allow")
|
||||
assert.Empty(t, res.SelectedPolicyID, "disabled policy must not be selected")
|
||||
}
|
||||
641
management/internals/modules/agentnetwork/policyselect_test.go
Normal file
641
management/internals/modules/agentnetwork/policyselect_test.go
Normal file
@@ -0,0 +1,641 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
nbstatus "github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func newSelectorMgr(t *testing.T, ctrl *gomock.Controller) (*managerImpl, *store.MockStore) {
|
||||
t.Helper()
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
// SelectPolicyForRequest evaluates the account-budget ceiling before policy
|
||||
// selection. These policy-selection tests don't exercise account rules, so
|
||||
// default to "no rules" — the no-mock policyselect_realstore_test.go covers
|
||||
// the account gate's behavior end to end.
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkBudgetRules(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nil, nil).
|
||||
AnyTimes()
|
||||
return &managerImpl{store: mockStore}, mockStore
|
||||
}
|
||||
|
||||
type usedKey struct {
|
||||
kind types.ConsumptionDimension
|
||||
dimID string
|
||||
window int64
|
||||
}
|
||||
|
||||
// expectConsumptionBatch stubs the batched consumption read to return the
|
||||
// supplied per-(kind, dim, window) counters, filling each row's window start
|
||||
// from the actual request keys so it always matches what the selector computed.
|
||||
// Keys absent from used resolve to zero counters.
|
||||
func expectConsumptionBatch(mockStore *store.MockStore, used map[usedKey]*types.Consumption) {
|
||||
mockStore.EXPECT().
|
||||
GetAgentNetworkConsumptionBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, _ store.LockingStrength, _ string, keys []types.ConsumptionKey) (map[types.ConsumptionKey]*types.Consumption, error) {
|
||||
out := make(map[types.ConsumptionKey]*types.Consumption)
|
||||
for _, k := range keys {
|
||||
if row, ok := used[usedKey{k.Kind, k.DimID, k.WindowSeconds}]; ok {
|
||||
rc := *row
|
||||
rc.WindowStartUTC = k.WindowStartUTC
|
||||
out[k] = &rc
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}).
|
||||
AnyTimes()
|
||||
}
|
||||
|
||||
func capPolicy(id, account string, sourceGroups []string, providerID string, tokenCap int64, windowSec int64) *types.Policy {
|
||||
return &types.Policy{
|
||||
ID: id,
|
||||
AccountID: account,
|
||||
Enabled: true,
|
||||
SourceGroups: sourceGroups,
|
||||
DestinationProviderIDs: []string{providerID},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{
|
||||
Enabled: true,
|
||||
GroupCap: tokenCap,
|
||||
WindowSeconds: windowSec,
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectPolicy_NoApplicablePolicies covers the pass-through path:
|
||||
// llm_router authorisation is upstream of selection; when the
|
||||
// selector finds no policy targeting the (provider, caller-groups)
|
||||
// combination, it returns Allow with no attribution and lets the
|
||||
// request continue without consumption tracking.
|
||||
func TestSelectPolicy_NoApplicablePolicies(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{}, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-x"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "no applicable policies = pass-through allow")
|
||||
assert.Empty(t, res.SelectedPolicyID, "no selection when nothing applies")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_AllowWithLowestGroupAttribution proves the v1
|
||||
// attribution rule: when the caller's groups intersect a policy's
|
||||
// source_groups in multiple positions, the selector picks the lowest
|
||||
// group id by string sort so multi-node selection converges.
|
||||
func TestSelectPolicy_AllowWithLowestGroupAttribution(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policy := capPolicy("pol-A", "acc-1", []string{"grp-zz", "grp-aa", "grp-mm"}, "prov-1", 10_000, 86_400)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policy}, nil)
|
||||
// Fresh: zero consumption across the board.
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-zz", "grp-aa", "grp-mm"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow)
|
||||
assert.Equal(t, "pol-A", res.SelectedPolicyID)
|
||||
assert.Equal(t, "grp-aa", res.AttributionGroupID,
|
||||
"lowest-by-sort intersection wins so multi-node selection converges")
|
||||
assert.Equal(t, int64(86_400), res.WindowSeconds)
|
||||
}
|
||||
|
||||
// TestSelectPolicy_LargerPoolWinsAcrossUsageLevels proves the core
|
||||
// selection rule: among multiple applicable policies with caps, the
|
||||
// selector picks the one with the larger absolute pool — at every
|
||||
// usage level, not just at fresh state. The smaller-pool policy is
|
||||
// only reached when the larger one is exhausted. This is the
|
||||
// "drain biggest first" semantic operators expect for layered
|
||||
// tiers; a fraction-based score would flap between the two as
|
||||
// soon as one is partially used.
|
||||
func TestSelectPolicy_LargerPoolWinsAcrossUsageLevels(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
tight := capPolicy("pol-tight", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400)
|
||||
tight.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
wide := capPolicy("pol-wide", "acc-1", []string{"grp-engineers"}, "prov-1", 10_000, 86_400)
|
||||
wide.CreatedAt = time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{tight, wide}, nil)
|
||||
|
||||
// Both partially used. tight at 50/100 (50% used); wide at
|
||||
// 50/10000 (0.5% used). Old fraction-based algo would pick wide
|
||||
// here too — but for the wrong reason ("more relative slack").
|
||||
// New algo picks wide because its initial group cap is bigger
|
||||
// (10000 > 100), and that decision is stable as wide drains.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 50},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-wide", res.SelectedPolicyID,
|
||||
"the policy with the bigger initial pool wins — operators expect 'drain the privileged tier first', not load-balance across tiers")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_StaysOnLargerPoolAfterPartialDrain locks the
|
||||
// stickiness contract reported by operators: with two policies
|
||||
// where A has a 200-token group cap and B has 150, the very first
|
||||
// request goes to A AND every subsequent request continues to land
|
||||
// on A until A's group cap is exhausted — at which point B becomes
|
||||
// the only candidate. A fraction-based score would flap to B as
|
||||
// soon as A had any consumption (B's 1.0 fraction beats A's 0.75)
|
||||
// even though A still has more absolute headroom; that produced
|
||||
// confusing per-policy attribution ledger entries and stranded
|
||||
// A's remaining capacity behind B's exhaustion.
|
||||
func TestSelectPolicy_StaysOnLargerPoolAfterPartialDrain(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policyA := capPolicy("pol-A-200", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400)
|
||||
policyB := capPolicy("pol-B-150", "acc-1", []string{"grp-engineers"}, "prov-1", 150, 86_400)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policyA, policyB}, nil)
|
||||
|
||||
// A is partially drained (50/200 used = 25% used; 75% headroom
|
||||
// remaining). B is fresh (0/150). The old fraction-based score
|
||||
// would pick B here (1.0 > 0.75 fraction); the new pool-size
|
||||
// score sticks with A (200 > 150 absolute cap).
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 50},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-A-200", res.SelectedPolicyID,
|
||||
"once attribution lands on the bigger pool it must STAY there until exhausted — operators expect 'drain A then B', not 'flip to B as soon as A is touched'")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_FallsThroughToSmallerPoolWhenLargerExhausted
|
||||
// proves the second half of the stickiness contract: once the
|
||||
// larger-pool policy IS exhausted, the smaller one takes over.
|
||||
// Without this we'd deny on requests the smaller policy is fully
|
||||
// equipped to serve.
|
||||
func TestSelectPolicy_FallsThroughToSmallerPoolWhenLargerExhausted(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policyA := capPolicy("pol-A-200", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400)
|
||||
// B uses a different window length so it has an INDEPENDENT counter — the
|
||||
// realistic shape for fall-through. On the SAME (group, window) tuple the
|
||||
// counter is shared, so A's cap of 200 being reached would also exhaust B's
|
||||
// 150; independent counters are what let A exhaust while B retains headroom.
|
||||
policyB := capPolicy("pol-B-150", "acc-1", []string{"grp-engineers"}, "prov-1", 150, 3_600)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policyA, policyB}, nil)
|
||||
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 200}, // A: 200 >= 200 → exhausted
|
||||
{types.DimensionGroup, "grp-engineers", 3_600}: {TokensInput: 100}, // B: 100 < 150 → headroom
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-B-150", res.SelectedPolicyID,
|
||||
"once the bigger pool is exhausted, the smaller one must take over — denying when capacity remains would strand B's allowance")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_TiebreakByLargerGroupPool covers the user-reported
|
||||
// bug: an admin in two groups (Users + Admins) where Users is bound
|
||||
// by a smaller-group-cap policy (50 group, 100 user) and Admins is
|
||||
// bound by a bigger-group-cap policy (100 group, 20 user) MUST get
|
||||
// attributed to the Admins policy on the first request.
|
||||
//
|
||||
// Without this rule, the fresh-state fraction is 1.0 for both and
|
||||
// the older policy wins by created_at. The first 24-token request
|
||||
// then drains the shared user counter past Admins's tight 20-token
|
||||
// user cap, locking Admins out of selection forever. The 100-token
|
||||
// Admins group pool ends up stranded while requests pile onto the
|
||||
// 50-token Users pool — the opposite of what the operator intended
|
||||
// when they put the bigger pool on the privileged group.
|
||||
func TestSelectPolicy_TiebreakByLargerGroupPool(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
// Policy A: Users group, smaller group pool, looser per-user cap.
|
||||
policyA := &types.Policy{
|
||||
ID: "pol-Users",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-Users"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{
|
||||
Enabled: true, GroupCap: 50, UserCap: 100, WindowSeconds: 86_400,
|
||||
},
|
||||
},
|
||||
// Older — would win the legacy created_at tiebreak.
|
||||
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
// Policy B: Admins group, bigger group pool, tighter per-user cap.
|
||||
policyB := &types.Policy{
|
||||
ID: "pol-Admins",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-Admins"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{
|
||||
Enabled: true, GroupCap: 100, UserCap: 20, WindowSeconds: 86_400,
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policyA, policyB}, nil)
|
||||
// Fresh state: every cap evaluation reads zero usage.
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-Users", "grp-Admins"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-Admins", res.SelectedPolicyID,
|
||||
"the bigger group pool wins the fresh-state tiebreak — picking Users first would burn the shared user counter past Admins's tight user cap on the very first request and strand the bigger Admins pool")
|
||||
assert.Equal(t, "grp-Admins", res.AttributionGroupID)
|
||||
}
|
||||
|
||||
// TestSelectPolicy_TiebreakByCreatedAt proves the deterministic
|
||||
// final tiebreak: when two applicable policies have the same
|
||||
// headroom fraction AND the same group cap (so the larger-pool rule
|
||||
// can't differentiate either), the older policy wins so attribution
|
||||
// is stable across replays.
|
||||
func TestSelectPolicy_TiebreakByCreatedAt(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
older := capPolicy("pol-old", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400)
|
||||
older.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
newer := capPolicy("pol-new", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400)
|
||||
newer.CreatedAt = time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{newer, older}, nil)
|
||||
// Both at zero consumption → identical headroom fraction.
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-old", res.SelectedPolicyID,
|
||||
"older policy wins on equal-headroom tiebreak so attribution is stable across replays")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_DeniesWhenAllExhausted proves the deny envelope:
|
||||
// when every applicable policy has at least one cap fully exhausted,
|
||||
// the selector returns Allow=false with the most-recent exhaustion's
|
||||
// deny code + human reason. The proxy's middleware surfaces this as
|
||||
// a 403 with the canonical llm_policy.* code.
|
||||
func TestSelectPolicy_DeniesWhenAllExhausted(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
a := capPolicy("pol-a", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400)
|
||||
b := capPolicy("pol-b", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{a, b}, nil)
|
||||
|
||||
// Shared group counter at 200: A (cap 100) and B (cap 200) both exhausted.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 200},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "every applicable policy exhausted = deny")
|
||||
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode)
|
||||
assert.Contains(t, res.DenyReason, "token cap exhausted",
|
||||
"deny reason must name the exhausted cap kind for operator debugging")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_UncappedPolicyAlwaysWinsAgainstCapped proves the
|
||||
// catch-all-allow contract: a policy with NO enabled caps wins
|
||||
// against any capped policy regardless of how much headroom the
|
||||
// capped one has, because operators who configure unlimited access
|
||||
// expect requests to attribute there until they explicitly add caps.
|
||||
func TestSelectPolicy_UncappedPolicyAlwaysWinsAgainstCapped(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
uncapped := &types.Policy{
|
||||
ID: "pol-uncapped",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-engineers"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
// All Limits.*.Enabled = false (zero-value).
|
||||
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
wide := capPolicy("pol-wide", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000_000, 86_400)
|
||||
wide.CreatedAt = time.Date(2025, 12, 1, 0, 0, 0, 0, time.UTC) // older than uncapped
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{uncapped, wide}, nil)
|
||||
// Only the wide policy reads consumption; uncapped doesn't query
|
||||
// because it has no enabled caps.
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-uncapped", res.SelectedPolicyID,
|
||||
"a no-caps policy must always win selection — that's how operators express 'unlimited access through this path'")
|
||||
assert.Equal(t, int64(0), res.WindowSeconds, "no caps configured = WindowSeconds=0 so RecordLLMUsage skips counter writes")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_DisabledPolicyIgnored proves disabled policies
|
||||
// don't count toward selection — even when they'd otherwise be the
|
||||
// best match. Operators disable a policy to take it offline; the
|
||||
// selector must respect that and route through whatever's left.
|
||||
func TestSelectPolicy_DisabledPolicyIgnored(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
disabled := capPolicy("pol-disabled", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000_000, 86_400)
|
||||
disabled.Enabled = false
|
||||
enabled := capPolicy("pol-enabled", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{disabled, enabled}, nil)
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-enabled", res.SelectedPolicyID,
|
||||
"disabled policies must be ignored at selection time")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_StoreErrorPropagates locks the no-fail-open
|
||||
// contract: a transient store error must surface to the caller, not
|
||||
// be silently treated as "no policies = allow". A false allow on the
|
||||
// hot path would let a request slip past every cap.
|
||||
func TestSelectPolicy_StoreErrorPropagates(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return(nil, errors.New("boom"))
|
||||
|
||||
_, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
})
|
||||
require.Error(t, err, "store errors must surface — never fail open on the hot path")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RejectsEmptyAccount is the input-validation guard:
|
||||
// empty account_id is a programmer error and must surface as
|
||||
// InvalidArgument, not as a silent zero-result lookup.
|
||||
func TestSelectPolicy_RejectsEmptyAccount(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, _ := newSelectorMgr(t, ctrl)
|
||||
|
||||
_, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{})
|
||||
require.Error(t, err)
|
||||
var sErr *nbstatus.Error
|
||||
require.True(t, errors.As(err, &sErr))
|
||||
assert.Equal(t, nbstatus.InvalidArgument, sErr.Type())
|
||||
}
|
||||
|
||||
// TestSelectPolicy_SharesGroupCounterAcrossPolicies locks the
|
||||
// counter-keying design fork: counters are keyed on (account,
|
||||
// dim_kind, dim_id, window_hours, window_start) — NOT on policy_id.
|
||||
// Two policies that target the same group with the SAME window length
|
||||
// share one bucket: spend booked under policy A is visible to policy
|
||||
// B's headroom calculation and counts toward B's cap.
|
||||
//
|
||||
// This is what makes "operator's per-group enforcement" sane — caps
|
||||
// describe how much a GROUP can use, not how much each policy owes.
|
||||
func TestSelectPolicy_SharesGroupCounterAcrossPolicies(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
// Two policies, both targeting grp-engineers + prov-1, same 24h
|
||||
// window length. Different cap sizes.
|
||||
policyA := capPolicy("pol-A", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400)
|
||||
policyB := capPolicy("pol-B", "acc-1", []string{"grp-engineers"}, "prov-1", 5_000, 86_400)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policyA, policyB}, nil)
|
||||
// Both policies query the SAME consumption row — same dim_id,
|
||||
// same window_hours, same window_start. The mock returns the
|
||||
// same row for both calls, simulating the shared counter.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 800},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// 800 used → policy A has 200 tokens left of 1000 (20% headroom);
|
||||
// policy B has 4200 left of 5000 (84% headroom). B wins.
|
||||
assert.Equal(t, "pol-B", res.SelectedPolicyID,
|
||||
"the SAME 800 tokens count toward both policies — counters share the (group, window) key, caps differ per policy")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_AntiFallThroughOnLowestGroup locks the no-fall-
|
||||
// through behaviour: when a caller is in multiple of a policy's
|
||||
// source_groups and the lowest-by-sort group is exhausted, we DENY
|
||||
// rather than fall through to a less-loaded sibling. Per-group caps
|
||||
// are independent (each group has its own bucket), but attribution
|
||||
// is one-shot — operators wanting fall-through must split into
|
||||
// separate policies.
|
||||
//
|
||||
// This nails down semantics future contributors might "improve" into
|
||||
// fall-through behaviour by accident.
|
||||
func TestSelectPolicy_AntiFallThroughOnLowestGroup(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
// Policy targets two groups; caller is in both.
|
||||
policy := capPolicy("pol-1", "acc-1", []string{"grp-aaa", "grp-bbb"}, "prov-1", 100, 86_400)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policy}, nil)
|
||||
|
||||
// grp-aaa is the lowest by sort → attribution picks it, and the
|
||||
// prefetch only collects the attribution group's key. We exhaust
|
||||
// grp-aaa (100/100); grp-bbb's counter is never requested because the
|
||||
// selector attributes one-shot to the lowest group, so it can't fall
|
||||
// through to a less-loaded sibling.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-aaa", 86_400}: {TokensInput: 100},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-aaa", "grp-bbb"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow,
|
||||
"lowest-group-by-sort attribution does NOT fall through to a less-loaded sibling — operators wanting fall-through must split into separate policies")
|
||||
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode)
|
||||
assert.Contains(t, res.DenyReason, "pol-1",
|
||||
"deny reason names the exhausted policy id so operators can grep it from the access log")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_BudgetOnlyExhaustionDenies covers the symmetric
|
||||
// path to TestSelectPolicy_DeniesWhenAllExhausted but for the budget
|
||||
// cap: a policy with token_limit DISABLED and budget_limit at-cap
|
||||
// must deny with llm_policy.budget_cap_exceeded (not the token code).
|
||||
//
|
||||
// Without this, the budget evaluation path in evalBudgetCap could
|
||||
// silently regress and we'd still pass DeniesWhenAllExhausted (which
|
||||
// only exercises tokens).
|
||||
func TestSelectPolicy_BudgetOnlyExhaustionDenies(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policy := &types.Policy{
|
||||
ID: "pol-budget",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-engineers"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{Enabled: false},
|
||||
BudgetLimit: types.PolicyBudgetLimit{
|
||||
Enabled: true,
|
||||
GroupCapUsd: 10.00,
|
||||
WindowSeconds: 86_400,
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policy}, nil)
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {CostUSD: 10.50}, // over the $10 cap
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "budget cap exhausted must deny independently of any token cap state")
|
||||
assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode,
|
||||
"deny code must be the budget code — token-only deny would silently regress the budget evaluation path")
|
||||
assert.Contains(t, res.DenyReason, "budget", "deny reason names the budget cap kind for operator debugging")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_BudgetTighterThanTokenWins is the dual-cap headroom
|
||||
// fork: when both Token and Budget are enabled on the same policy,
|
||||
// the SMALLER remaining ratio gates the policy. A policy with
|
||||
// abundant token headroom but near-zero budget headroom must deny on
|
||||
// budget, not pass on tokens.
|
||||
func TestSelectPolicy_BudgetTighterThanTokenWins(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policy := &types.Policy{
|
||||
ID: "pol-dual",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-engineers"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{Enabled: true, GroupCap: 10_000_000, WindowSeconds: 86_400},
|
||||
BudgetLimit: types.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 1.00, WindowSeconds: 86_400},
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policy}, nil)
|
||||
// One shared counter carries both token usage (ample headroom) and cost
|
||||
// (at the $1 budget cap); the tighter budget cap gates the policy.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 100, CostUSD: 1.00},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow,
|
||||
"the tighter of (token, budget) wins — abundant token headroom must NOT mask an exhausted budget")
|
||||
assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode)
|
||||
}
|
||||
131
management/internals/modules/agentnetwork/reconcile.go
Normal file
131
management/internals/modules/agentnetwork/reconcile.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// reconcile recomputes the synthesised reverse-proxy services for an
|
||||
// account, diffs them against the previously-synthesised set in the
|
||||
// in-memory cache, and emits Create / Update / Delete proxy mappings
|
||||
// to the affected clusters. Also triggers a peer-side network-map
|
||||
// recompute via accountManager.UpdateAccountPeers so the
|
||||
// private-service ACL injection picks up the new state immediately.
|
||||
//
|
||||
// Reconcile failures are logged and swallowed — the underlying CRUD
|
||||
// has already completed, and the next mutation (or proxy reconnect)
|
||||
// will re-converge the cluster's view.
|
||||
func (m *managerImpl) reconcile(ctx context.Context, accountID string) {
|
||||
if accountID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if m.accountManager != nil {
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{
|
||||
Resource: types.UpdateResourceService,
|
||||
Operation: types.UpdateOperationUpdate,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
if m.proxyController == nil {
|
||||
return
|
||||
}
|
||||
|
||||
services, err := SynthesizeServices(ctx, m.store, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).WithError(err).Warnf("agent-network reconcile: synthesise services for account %s", accountID)
|
||||
return
|
||||
}
|
||||
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
current := make(map[string]*proto.ProxyMapping, len(services))
|
||||
for _, svc := range services {
|
||||
if svc == nil || svc.ID == "" {
|
||||
continue
|
||||
}
|
||||
current[svc.ID] = svc.ToProtoMapping(rpservice.Update, "", oidcCfg)
|
||||
}
|
||||
|
||||
m.reconcileMu.Lock()
|
||||
previous := m.reconcileCache[accountID]
|
||||
if previous == nil {
|
||||
previous = make(map[string]*proto.ProxyMapping)
|
||||
}
|
||||
|
||||
creates, updates, deletes := diffMappings(previous, current)
|
||||
if len(current) == 0 {
|
||||
delete(m.reconcileCache, accountID)
|
||||
} else {
|
||||
m.reconcileCache[accountID] = current
|
||||
}
|
||||
m.reconcileMu.Unlock()
|
||||
|
||||
for _, mapping := range creates {
|
||||
mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping))
|
||||
}
|
||||
for _, mapping := range updates {
|
||||
mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping))
|
||||
}
|
||||
for _, mapping := range deletes {
|
||||
mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping))
|
||||
}
|
||||
}
|
||||
|
||||
// diffMappings classifies the previous→current transition for a
|
||||
// single account into Create / Update / Delete sets.
|
||||
//
|
||||
// Cluster moves (current.cluster != previous.cluster) are surfaced as
|
||||
// a Delete on the old cluster + Create on the new — handled by
|
||||
// emitting both a delete (on previous mapping) and a create (on the
|
||||
// current mapping) for that service ID.
|
||||
func diffMappings(previous, current map[string]*proto.ProxyMapping) (creates, updates, deletes []*proto.ProxyMapping) {
|
||||
for id, cur := range current {
|
||||
prev, existed := previous[id]
|
||||
switch {
|
||||
case !existed:
|
||||
creates = append(creates, cur)
|
||||
case prev.GetDomain() == "" || cur.GetAccountId() == prev.GetAccountId() && currentClusterChanged(prev, cur):
|
||||
deletes = append(deletes, prev)
|
||||
creates = append(creates, cur)
|
||||
default:
|
||||
updates = append(updates, cur)
|
||||
}
|
||||
}
|
||||
for id, prev := range previous {
|
||||
if _, stillThere := current[id]; !stillThere {
|
||||
deletes = append(deletes, prev)
|
||||
}
|
||||
}
|
||||
return creates, updates, deletes
|
||||
}
|
||||
|
||||
func currentClusterChanged(prev, cur *proto.ProxyMapping) bool {
|
||||
return clusterFromMapping(prev) != clusterFromMapping(cur)
|
||||
}
|
||||
|
||||
// clusterFromMapping returns the cluster the mapping should be sent
|
||||
// to. ProxyMapping doesn't carry the cluster directly, so we rely on
|
||||
// the synthesised service's domain (`<slug>.<cluster>`) and split on
|
||||
// the first '.'.
|
||||
func clusterFromMapping(m *proto.ProxyMapping) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
domain := m.GetDomain()
|
||||
for i := 0; i < len(domain); i++ {
|
||||
if domain[i] == '.' {
|
||||
return domain[i+1:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
232
management/internals/modules/agentnetwork/reconcile_test.go
Normal file
232
management/internals/modules/agentnetwork/reconcile_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func newReconcileMgr(t *testing.T, ctrl *gomock.Controller) (*managerImpl, *store.MockStore, *proxy.MockController) {
|
||||
t.Helper()
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockProxy := proxy.NewMockController(ctrl)
|
||||
return &managerImpl{
|
||||
store: mockStore,
|
||||
proxyController: mockProxy,
|
||||
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
|
||||
}, mockStore, mockProxy
|
||||
}
|
||||
|
||||
func newReconcileTestProvider() *types.Provider {
|
||||
return &types.Provider{
|
||||
ID: "prov-1",
|
||||
AccountID: "acct-1",
|
||||
ProviderID: "openai_api",
|
||||
Name: "OpenAI",
|
||||
UpstreamURL: "https://api.openai.com",
|
||||
APIKey: "sk-test-key",
|
||||
Enabled: true,
|
||||
SessionPrivateKey: "test-priv-key",
|
||||
SessionPublicKey: "test-pub-key",
|
||||
}
|
||||
}
|
||||
|
||||
func newReconcileTestPolicy(providerID, sourceGroupID string) *types.Policy {
|
||||
return &types.Policy{
|
||||
ID: "pol-1",
|
||||
AccountID: "acct-1",
|
||||
Name: "engineers",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{sourceGroupID},
|
||||
DestinationProviderIDs: []string{providerID},
|
||||
}
|
||||
}
|
||||
|
||||
func newReconcileTestSettings() *types.Settings {
|
||||
return &types.Settings{
|
||||
AccountID: "acct-1",
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
Subdomain: "violet",
|
||||
}
|
||||
}
|
||||
|
||||
func expectReconcileSynthInputs(mockStore *store.MockStore, ctx context.Context, providers []*types.Provider, policies []*types.Policy, guardrails []*types.Guardrail) {
|
||||
mockStore.EXPECT().
|
||||
GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(newReconcileTestSettings(), nil)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(providers, nil)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(policies, nil)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(guardrails, nil)
|
||||
}
|
||||
|
||||
func TestReconcile_FirstSynth_EmitsCreate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl)
|
||||
provider := newReconcileTestProvider()
|
||||
policy := newReconcileTestPolicy(provider.ID, "grp-eng")
|
||||
|
||||
expectReconcileSynthInputs(mockStore, ctx, []*types.Provider{provider}, []*types.Policy{policy}, []*types.Guardrail{})
|
||||
mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{})
|
||||
|
||||
var sentMappings []*proto.ProxyMapping
|
||||
mockProxy.EXPECT().
|
||||
SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), "eu.proxy.netbird.io").
|
||||
Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) {
|
||||
sentMappings = append(sentMappings, m)
|
||||
})
|
||||
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
|
||||
require.Len(t, sentMappings, 1, "first synth must emit one mapping")
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, sentMappings[0].Type, "first synth is a Create")
|
||||
assert.Equal(t, "agent-net-svc-acct-1", sentMappings[0].Id, "stable account-scoped virtual service id")
|
||||
assert.Equal(t, "violet.eu.proxy.netbird.io", sentMappings[0].Domain, "domain comes from settings (subdomain.cluster)")
|
||||
|
||||
mgr.reconcileMu.Lock()
|
||||
cached := mgr.reconcileCache["acct-1"]
|
||||
mgr.reconcileMu.Unlock()
|
||||
require.Len(t, cached, 1, "cache must hold the synth result for next diff")
|
||||
}
|
||||
|
||||
func TestReconcile_NoChange_EmitsNothingExtra(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl)
|
||||
provider := newReconcileTestProvider()
|
||||
policy := newReconcileTestPolicy(provider.ID, "grp-eng")
|
||||
|
||||
// Two identical synth runs.
|
||||
mockStore.EXPECT().
|
||||
GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(newReconcileTestSettings(), nil).Times(2)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return([]*types.Provider{provider}, nil).Times(2)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return([]*types.Policy{policy}, nil).Times(2)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return([]*types.Guardrail{}, nil).Times(2)
|
||||
mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).Times(2)
|
||||
|
||||
createCalls := 0
|
||||
updateCalls := 0
|
||||
mockProxy.EXPECT().
|
||||
SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), gomock.Any()).
|
||||
Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) {
|
||||
switch m.Type {
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
|
||||
createCalls++
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
|
||||
updateCalls++
|
||||
}
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
|
||||
assert.Equal(t, 1, createCalls, "first reconcile creates")
|
||||
assert.Equal(t, 1, updateCalls, "second reconcile re-pushes as Modified (no semantic change but mapping fields refresh)")
|
||||
}
|
||||
|
||||
func TestReconcile_PolicyRemoved_EmitsDelete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl)
|
||||
provider := newReconcileTestProvider()
|
||||
policy := newReconcileTestPolicy(provider.ID, "grp-eng")
|
||||
|
||||
gomock.InOrder(
|
||||
// First reconcile: provider + policy, synthesised.
|
||||
mockStore.EXPECT().GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").Return(newReconcileTestSettings(), nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Provider{provider}, nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Policy{policy}, nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Guardrail{}, nil),
|
||||
// Second reconcile: policy gone, provider stays but no longer referenced.
|
||||
mockStore.EXPECT().GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").Return(newReconcileTestSettings(), nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Provider{provider}, nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Policy{}, nil),
|
||||
)
|
||||
mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()
|
||||
|
||||
var seenTypes []proto.ProxyMappingUpdateType
|
||||
mockProxy.EXPECT().
|
||||
SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), "eu.proxy.netbird.io").
|
||||
Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) {
|
||||
seenTypes = append(seenTypes, m.Type)
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
|
||||
require.Len(t, seenTypes, 2, "create then delete")
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, seenTypes[0])
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, seenTypes[1])
|
||||
|
||||
mgr.reconcileMu.Lock()
|
||||
_, present := mgr.reconcileCache["acct-1"]
|
||||
mgr.reconcileMu.Unlock()
|
||||
assert.False(t, present, "cache for the account must be cleared once nothing is synthesised")
|
||||
}
|
||||
|
||||
func TestReconcile_NilProxyController_NoOp(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mgr := &managerImpl{
|
||||
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
|
||||
}
|
||||
// Must not panic; must not query the store.
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
}
|
||||
|
||||
func TestReconcile_EmptyAccountID_NoOp(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mgr, _, _ := newReconcileMgr(t, ctrl)
|
||||
// Empty accountID short-circuits before any store call.
|
||||
mgr.reconcile(ctx, "")
|
||||
}
|
||||
|
||||
func TestClusterFromMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
want string
|
||||
}{
|
||||
{"simple", "openai.eu.proxy.netbird.io", "eu.proxy.netbird.io"},
|
||||
{"deeply nested", "a.b.c.d", "b.c.d"},
|
||||
{"no dot", "openai", ""},
|
||||
{"empty", "", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := clusterFromMapping(&proto.ProxyMapping{Domain: tt.domain})
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
1059
management/internals/modules/agentnetwork/synthesizer.go
Normal file
1059
management/internals/modules/agentnetwork/synthesizer.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,178 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// decodeServiceGuardrailConfig pulls the llm_guardrail middleware config off the
|
||||
// synthesised service's single target.
|
||||
func decodeServiceGuardrailConfig(t *testing.T, svc *rpservice.Service) guardrailConfig {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, svc.Targets, "synth service must carry a target")
|
||||
for _, mw := range svc.Targets[0].Options.Middlewares {
|
||||
if mw.ID == middlewareIDLLMGuardrail {
|
||||
var cfg guardrailConfig
|
||||
require.NoError(t, json.Unmarshal(mw.ConfigJSON, &cfg), "guardrail config must decode")
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
t.Fatal("llm_guardrail middleware not present on synthesised service")
|
||||
return guardrailConfig{}
|
||||
}
|
||||
|
||||
// decodeMiddlewareRawConfig returns the raw ConfigJSON bytes for the named
|
||||
// middleware on the synth service's target, or fails the test.
|
||||
func decodeMiddlewareRawConfig(t *testing.T, svc *rpservice.Service, id string) []byte {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, svc.Targets, "synth service must carry a target")
|
||||
for _, mw := range svc.Targets[0].Options.Middlewares {
|
||||
if mw.ID == id {
|
||||
return mw.ConfigJSON
|
||||
}
|
||||
}
|
||||
t.Fatalf("middleware %q not present on synthesised service", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveGuardrailAndPolicy persists a guardrail with prompt capture + redact + a
|
||||
// model allowlist, referenced by one enabled policy. Shared by the GC-3 tests.
|
||||
func saveGuardrailAndPolicy(t *testing.T, ctx context.Context, s store.Store, provider *types.Provider) {
|
||||
t.Helper()
|
||||
guardrail := &types.Guardrail{
|
||||
ID: "ainguard-1",
|
||||
AccountID: testAccountID,
|
||||
Name: "strict",
|
||||
Checks: types.GuardrailChecks{
|
||||
ModelAllowlist: types.GuardrailModelAllowlist{Enabled: true, Models: []string{"gpt-5.4"}},
|
||||
PromptCapture: types.GuardrailPromptCapture{Enabled: true, RedactPii: true},
|
||||
},
|
||||
}
|
||||
require.NoError(t, s.SaveAgentNetworkGuardrail(ctx, guardrail))
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", guardrail.ID)))
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_PromptCaptureAccountIsSoleControl is the
|
||||
// GC-3 contract: the account master switch (EnablePromptCollection) is the
|
||||
// SOLE control for capture enablement. Policy-level guardrail prompt_capture is
|
||||
// ignored for enablement — operators don't need to attach a capture guardrail
|
||||
// to a policy just to turn capture on for the account. Off by default.
|
||||
func TestSynthesizeServices_RealStore_PromptCaptureAccountIsSoleControl(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
// Account collection master switch OFF (default).
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
saveGuardrailAndPolicy(t, ctx, s, newSynthTestProvider())
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
cfg := decodeServiceGuardrailConfig(t, services[0])
|
||||
assert.Equal(t, []string{"gpt-5.4"}, cfg.ModelAllowlist,
|
||||
"model allowlist is a pure policy guardrail and must always reach the config")
|
||||
assert.False(t, cfg.PromptCapture.Enabled,
|
||||
"prompt capture must be off when the account toggle is off, even with a capture-enabled guardrail")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_PromptCaptureFlowsWhenAccountOptsIn proves
|
||||
// the account toggle is sufficient on its own — even with NO guardrail
|
||||
// attached to the policy, capture fires when the account opts in. Redact is
|
||||
// the OR of account + guardrail.
|
||||
func TestSynthesizeServices_RealStore_PromptCaptureFlowsWhenAccountOptsIn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.EnablePromptCollection = true
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
|
||||
// Save a provider and a policy with NO guardrails attached — proves the
|
||||
// account toggle is sufficient on its own.
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
cfg := decodeServiceGuardrailConfig(t, services[0])
|
||||
assert.True(t, cfg.PromptCapture.Enabled,
|
||||
"account toggle alone must enable capture; no guardrail attachment required")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_AccountRedactWithoutGuardrailRedact proves
|
||||
// the redact OR-merge from the account side: account RedactPii on, guardrail
|
||||
// redact off, capture on at both levels.
|
||||
func TestSynthesizeServices_RealStore_AccountRedactWithoutGuardrailRedact(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.EnablePromptCollection = true
|
||||
settings.RedactPii = true
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
guardrail := &types.Guardrail{
|
||||
ID: "ainguard-noredact",
|
||||
AccountID: testAccountID,
|
||||
Name: "capture-only",
|
||||
Checks: types.GuardrailChecks{
|
||||
PromptCapture: types.GuardrailPromptCapture{Enabled: true, RedactPii: false},
|
||||
},
|
||||
}
|
||||
require.NoError(t, s.SaveAgentNetworkGuardrail(ctx, guardrail))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", guardrail.ID)))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
cfg := decodeServiceGuardrailConfig(t, services[0])
|
||||
assert.True(t, cfg.PromptCapture.Enabled, "capture on (account + guardrail)")
|
||||
assert.True(t, cfg.PromptCapture.RedactPii, "account RedactPii must apply even when the guardrail leaves it off (OR)")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_NoGuardrail_CaptureOff pins the default:
|
||||
// with no guardrail referenced, the synth service's guardrail config has prompt
|
||||
// capture disabled and an empty allowlist. This is the "off by default" baseline
|
||||
// the account switch must preserve.
|
||||
func TestSynthesizeServices_RealStore_NoGuardrail_CaptureOff(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1, "exactly one synth service expected")
|
||||
|
||||
cfg := decodeServiceGuardrailConfig(t, services[0])
|
||||
assert.Empty(t, cfg.ModelAllowlist, "no guardrail → no allowlist")
|
||||
assert.False(t, cfg.PromptCapture.Enabled, "no guardrail → prompt capture off by default")
|
||||
assert.False(t, cfg.PromptCapture.RedactPii, "no guardrail → redact off by default")
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// TestSynthesizeServices_RealStore_LogCollectionOff_SuppressesAccessLog drives the
|
||||
// happy default: account settings ship with EnableLogCollection=false, so the
|
||||
// synthesised target opts out of access-log emission (DisableAccessLog=true) and
|
||||
// the proto mapping the proxy receives reflects that.
|
||||
func TestSynthesizeServices_RealStore_LogCollectionOff_SuppressesAccessLog(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1, "exactly one synth service expected")
|
||||
require.NotEmpty(t, services[0].Targets, "synth service must carry a target")
|
||||
assert.True(t, services[0].Targets[0].Options.DisableAccessLog,
|
||||
"EnableLogCollection=false (default) must produce DisableAccessLog=true on the synth target")
|
||||
|
||||
mapping := services[0].ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{})
|
||||
require.NotEmpty(t, mapping.GetPath(), "proto mapping must carry a path")
|
||||
assert.True(t, mapping.GetPath()[0].GetOptions().GetDisableAccessLog(),
|
||||
"proto mapping must propagate DisableAccessLog=true so the proxy suppresses access-log emission")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_LogCollectionOn_PermitsAccessLog asserts the
|
||||
// inverse: once the account opts in, the synth target leaves DisableAccessLog
|
||||
// at its default false and the proto wire stays unset.
|
||||
func TestSynthesizeServices_RealStore_LogCollectionOn_PermitsAccessLog(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.EnableLogCollection = true
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1, "exactly one synth service expected")
|
||||
require.NotEmpty(t, services[0].Targets, "synth service must carry a target")
|
||||
assert.False(t, services[0].Targets[0].Options.DisableAccessLog,
|
||||
"EnableLogCollection=true must leave DisableAccessLog=false on the synth target")
|
||||
|
||||
mapping := services[0].ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{})
|
||||
require.NotEmpty(t, mapping.GetPath(), "proto mapping must carry a path")
|
||||
assert.False(t, mapping.GetPath()[0].GetOptions().GetDisableAccessLog(),
|
||||
"proto mapping must propagate DisableAccessLog=false so access-log emission stays on")
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// parserRedactConfig mirrors the on-wire shape of the redact + capture knobs
|
||||
// that both llm_request_parser and llm_response_parser unmarshal. We don't
|
||||
// import the proxy-side packages from a management test (cross-module), so we
|
||||
// decode the JSON directly and assert on the fields that are part of the
|
||||
// synth contract.
|
||||
type parserRedactConfig struct {
|
||||
RedactPii bool `json:"redact_pii,omitempty"`
|
||||
CapturePrompt *bool `json:"capture_prompt,omitempty"` // present only on the request parser
|
||||
CaptureCompletion *bool `json:"capture_completion,omitempty"` // present only on the response parser
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_ParserConfigsCarryRedactPii is the
|
||||
// management-side contract test for the request/response parser redaction
|
||||
// wiring. When settings.RedactPii is true, the synthesised middleware chain
|
||||
// MUST stamp redact_pii=true on both llm_request_parser and llm_response_parser
|
||||
// configs — otherwise the parsers ship raw prompts / completions to the
|
||||
// access log even though the account has opted in. This is exactly the live
|
||||
// leak path that motivated the parser-side redaction in the first place.
|
||||
func TestSynthesizeServices_RealStore_ParserConfigsCarryRedactPii(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.RedactPii = true
|
||||
settings.EnablePromptCollection = true
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1, "exactly one synth service expected")
|
||||
|
||||
for _, parserID := range []string{middlewareIDLLMRequestParser, middlewareIDLLMResponseParser} {
|
||||
raw := decodeMiddlewareRawConfig(t, services[0], parserID)
|
||||
var cfg parserRedactConfig
|
||||
require.NoError(t, json.Unmarshal(raw, &cfg), "%s config must be valid JSON", parserID)
|
||||
assert.True(t, cfg.RedactPii, "%s config must carry redact_pii=true when settings.RedactPii is on (otherwise the parser ships raw prompts/completions to the access log)", parserID)
|
||||
}
|
||||
// The capture flag is set explicitly to enable_prompt_collection on each
|
||||
// parser. With it on here, both must allow emission.
|
||||
reqCfg := decodeParserConfig(t, services[0], middlewareIDLLMRequestParser)
|
||||
require.NotNil(t, reqCfg.CapturePrompt, "request parser must carry an explicit capture_prompt")
|
||||
assert.True(t, *reqCfg.CapturePrompt, "capture_prompt=true when EnablePromptCollection=true")
|
||||
respCfg := decodeParserConfig(t, services[0], middlewareIDLLMResponseParser)
|
||||
require.NotNil(t, respCfg.CaptureCompletion, "response parser must carry an explicit capture_completion")
|
||||
assert.True(t, *respCfg.CaptureCompletion, "capture_completion=true when EnablePromptCollection=true")
|
||||
}
|
||||
|
||||
// decodeParserConfig is a small helper around decodeMiddlewareRawConfig that
|
||||
// also unmarshals into parserRedactConfig.
|
||||
func decodeParserConfig(t *testing.T, svc *rpservice.Service, parserID string) parserRedactConfig {
|
||||
t.Helper()
|
||||
raw := decodeMiddlewareRawConfig(t, svc, parserID)
|
||||
var cfg parserRedactConfig
|
||||
require.NoError(t, json.Unmarshal(raw, &cfg), "%s config must be valid JSON", parserID)
|
||||
return cfg
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_ParserConfigsSuppressCaptureWhenLogCollectionOnly
|
||||
// is the contract test for the bug: enable_log_collection=true with
|
||||
// enable_prompt_collection=false MUST result in capture_prompt=false on the
|
||||
// request parser AND capture_completion=false on the response parser, so the
|
||||
// access-log row stays metadata-only (provider, model, tokens, cost) and
|
||||
// carries NO prompt input nor response output. Without this, operators who
|
||||
// want billing-style logs end up with raw user prompts and model outputs in
|
||||
// every access-log entry.
|
||||
func TestSynthesizeServices_RealStore_ParserConfigsSuppressCaptureWhenLogCollectionOnly(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.EnableLogCollection = true // operator wants logs ON
|
||||
settings.EnablePromptCollection = false // but NOT content capture
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
reqCfg := decodeParserConfig(t, services[0], middlewareIDLLMRequestParser)
|
||||
require.NotNil(t, reqCfg.CapturePrompt, "request parser must carry an explicit capture_prompt gate")
|
||||
assert.False(t, *reqCfg.CapturePrompt, "capture_prompt MUST be false when EnablePromptCollection is off — otherwise llm.request_prompt_raw leaks user input into the access log")
|
||||
|
||||
respCfg := decodeParserConfig(t, services[0], middlewareIDLLMResponseParser)
|
||||
require.NotNil(t, respCfg.CaptureCompletion, "response parser must carry an explicit capture_completion gate")
|
||||
assert.False(t, *respCfg.CaptureCompletion, "capture_completion MUST be false when EnablePromptCollection is off — otherwise llm.response_completion leaks model output into the access log")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_ParserConfigsOmitRedactPiiWhenOff proves
|
||||
// the inverse: with the account toggle off, the parser configs stay clean (no
|
||||
// redact_pii field, which the parsers treat as zero / no redaction). This is
|
||||
// the operator-opt-out path — the access log keeps raw prompts/completions
|
||||
// for debugging until the operator opts in.
|
||||
func TestSynthesizeServices_RealStore_ParserConfigsOmitRedactPiiWhenOff(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
|
||||
// Default settings: RedactPii = false.
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
for _, parserID := range []string{middlewareIDLLMRequestParser, middlewareIDLLMResponseParser} {
|
||||
raw := decodeMiddlewareRawConfig(t, services[0], parserID)
|
||||
// Inspect the decoded JSON directly: a struct decode would also pass
|
||||
// if redact_pii were present-but-false. The contract is that the key
|
||||
// is omitted entirely while the account toggle is off.
|
||||
var rawCfg map[string]json.RawMessage
|
||||
require.NoError(t, json.Unmarshal(raw, &rawCfg), "%s config must be valid JSON", parserID)
|
||||
assert.NotContains(t, rawCfg, "redact_pii",
|
||||
"%s config must omit redact_pii entirely while the account toggle is off", parserID)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// decodeServiceRouterConfig finds the llm_router middleware on the synthesised
|
||||
// service's single target and decodes its config — the model→provider routing
|
||||
// table the proxy authorises against.
|
||||
func decodeServiceRouterConfig(t *testing.T, svc *rpservice.Service) routerConfig {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, svc.Targets, "synth service must carry a target")
|
||||
for _, mw := range svc.Targets[0].Options.Middlewares {
|
||||
if mw.ID == middlewareIDLLMRouter {
|
||||
var cfg routerConfig
|
||||
require.NoError(t, json.Unmarshal(mw.ConfigJSON, &cfg), "router config must decode")
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
t.Fatal("llm_router middleware not present on synthesised service")
|
||||
return routerConfig{}
|
||||
}
|
||||
|
||||
// decodeMappingRouterConfig is the proto-wire equivalent: it pulls the
|
||||
// llm_router config off the ProxyMapping the proxy actually receives.
|
||||
func decodeMappingRouterConfig(t *testing.T, m *proto.ProxyMapping) routerConfig {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, m.GetPath(), "mapping must carry a path")
|
||||
for _, mw := range m.GetPath()[0].GetOptions().GetMiddlewares() {
|
||||
if mw.GetId() == middlewareIDLLMRouter {
|
||||
var cfg routerConfig
|
||||
require.NoError(t, json.Unmarshal(mw.GetConfigJson(), &cfg), "wire router config must decode")
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
t.Fatal("llm_router middleware not present on proxy mapping")
|
||||
return routerConfig{}
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_SurvivesStatusToggle drives synthesis through
|
||||
// a REAL sqlite store (Save → gorm/JSON serialize → reload → decrypt) instead of
|
||||
// a MockStore, so it exercises the field round-trip that a provider/policy edit
|
||||
// actually hits. Mock-based tests can't catch a field that dies in persistence;
|
||||
// this one can. It then performs the exact operation that reproduced the live
|
||||
// 403 — disable then re-enable the provider — and asserts the re-enabled state
|
||||
// is fully routable again.
|
||||
func TestSynthesizeServices_RealStore_SurvivesStatusToggle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
assertRoutable := func(t *testing.T, stage string) {
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err, stage)
|
||||
require.Len(t, services, 1, "%s: exactly one synth service expected", stage)
|
||||
svc := services[0]
|
||||
|
||||
assert.True(t, svc.Private, "%s: synth service must be Private after store round-trip", stage)
|
||||
assert.Equal(t, []string{"grp-eng"}, svc.AccessGroups, "%s: AccessGroups must survive the round-trip", stage)
|
||||
|
||||
m := svc.ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{})
|
||||
assert.True(t, m.GetPrivate(), "%s: proto mapping Private must be true (proxy gates tunnel-peer auth on it)", stage)
|
||||
|
||||
cfg := decodeServiceRouterConfig(t, svc)
|
||||
require.Len(t, cfg.Providers, 1, "%s: the enabled+linked provider must appear in the router config", stage)
|
||||
assert.Equal(t, []string{"gpt-5.4"}, cfg.Providers[0].Models, "%s: provider models must reach the route", stage)
|
||||
assert.Equal(t, []string{"grp-eng"}, cfg.Providers[0].AllowedGroupIDs, "%s: policy source groups must reach the route", stage)
|
||||
}
|
||||
|
||||
assertRoutable(t, "initial")
|
||||
|
||||
provider.Enabled = false
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
disabled, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err, "synthesis must not error with a disabled provider")
|
||||
for _, svc := range disabled {
|
||||
assert.Empty(t, decodeServiceRouterConfig(t, svc).Providers,
|
||||
"a disabled provider must not appear in the router config (otherwise it would route while off)")
|
||||
}
|
||||
|
||||
provider.Enabled = true
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
assertRoutable(t, "after disable->enable")
|
||||
}
|
||||
|
||||
// captureController is a proxy.Controller that records the mappings reconcile
|
||||
// pushes, so the test can inspect the exact wire payload — Private flag and
|
||||
// router config included.
|
||||
type captureController struct {
|
||||
rpproxy.Controller
|
||||
pushed []*proto.ProxyMapping
|
||||
}
|
||||
|
||||
func (c *captureController) GetOIDCValidationConfig() rpproxy.OIDCValidationConfig {
|
||||
return rpproxy.OIDCValidationConfig{}
|
||||
}
|
||||
|
||||
func (c *captureController) SendServiceUpdateToCluster(_ context.Context, _ string, update *proto.ProxyMapping, _ string) {
|
||||
c.pushed = append(c.pushed, update)
|
||||
}
|
||||
|
||||
// noopAccountManager satisfies the reconcile path's accountManager dependency.
|
||||
type noopAccountManager struct {
|
||||
account.Manager
|
||||
}
|
||||
|
||||
func (noopAccountManager) UpdateAccountPeers(context.Context, string, nbtypes.UpdateReason) {}
|
||||
|
||||
// TestReconcile_RealStore_PushesPrivateAfterStatusToggle reproduces the live
|
||||
// path end-to-end below the gRPC boundary: a real store + the real
|
||||
// managerImpl.reconcile + a capturing proxy controller. It runs the operation
|
||||
// that broke in production — provider disable then re-enable — and asserts the
|
||||
// mapping reconcile pushes to the cluster after re-enable is Private=true and
|
||||
// carries the routable provider. If reconcile ever pushes private=false (the
|
||||
// symptom that left UserGroups empty → no_authorised_provider), this fails.
|
||||
func TestReconcile_RealStore_PushesPrivateAfterStatusToggle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
ctrl := &captureController{}
|
||||
m := &managerImpl{
|
||||
store: s,
|
||||
accountManager: noopAccountManager{},
|
||||
proxyController: ctrl,
|
||||
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
m.reconcile(ctx, testAccountID) // initial, provider enabled
|
||||
|
||||
provider.Enabled = false
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
m.reconcile(ctx, testAccountID) // disabled
|
||||
|
||||
provider.Enabled = true
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
m.reconcile(ctx, testAccountID) // re-enabled — the reproduction step
|
||||
|
||||
require.NotEmpty(t, ctrl.pushed, "reconcile must push at least one mapping")
|
||||
last := ctrl.pushed[len(ctrl.pushed)-1]
|
||||
|
||||
assert.Equal(t, newSynthTestSettings().Endpoint(), last.GetDomain(), "synth domain on the wire")
|
||||
assert.True(t, last.GetPrivate(),
|
||||
"reconcile-pushed mapping after re-enable MUST be Private=true; a false here is the exact bug — the proxy skips ValidateTunnelPeer, UserGroups stays empty, and llm_router denies no_authorised_provider")
|
||||
|
||||
cfg := decodeMappingRouterConfig(t, last)
|
||||
require.Len(t, cfg.Providers, 1, "re-enabled provider must be back in the pushed router config")
|
||||
assert.Equal(t, []string{"gpt-5.4"}, cfg.Providers[0].Models, "model must be routable again after re-enable")
|
||||
assert.Equal(t, []string{"grp-eng"}, cfg.Providers[0].AllowedGroupIDs, "authorised groups must be present after re-enable")
|
||||
}
|
||||
1098
management/internals/modules/agentnetwork/synthesizer_test.go
Normal file
1098
management/internals/modules/agentnetwork/synthesizer_test.go
Normal file
File diff suppressed because it is too large
Load Diff
117
management/internals/modules/agentnetwork/types/accesslog.go
Normal file
117
management/internals/modules/agentnetwork/types/accesslog.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// AgentNetworkAccessLog is the dedicated, flattened agent-network access-log
|
||||
// row. Unlike the shared reverse-proxy AccessLogEntry (which kept LLM data in
|
||||
// an opaque metadata JSON blob), the LLM dimensions live in first-class,
|
||||
// indexed columns so the access-log surface can filter server-side by
|
||||
// user / group / provider / model / decision.
|
||||
type AgentNetworkAccessLog struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
ServiceID string `gorm:"index"`
|
||||
Timestamp time.Time `gorm:"index"`
|
||||
UserID string `gorm:"index"`
|
||||
SourceIP string
|
||||
Method string
|
||||
Host string
|
||||
Path string `gorm:"type:text"`
|
||||
Duration time.Duration
|
||||
StatusCode int `gorm:"index"`
|
||||
AuthMethod string
|
||||
BytesUpload int64
|
||||
BytesDownload int64
|
||||
|
||||
// Flattened LLM dimensions (queryable). Sourced from proxy metadata keys.
|
||||
Provider string `gorm:"index"` // vendor, e.g. "openai" (llm.provider)
|
||||
Model string `gorm:"index"` // llm.model
|
||||
SessionID string `gorm:"index"` // llm.session_id — groups a conversation / coding session
|
||||
ResolvedProviderID string `gorm:"index"` // llm.resolved_provider_id
|
||||
SelectedPolicyID string `gorm:"index"` // llm.selected_policy_id
|
||||
Decision string `gorm:"index"` // llm_policy.decision (allow/deny)
|
||||
DenyReason string // llm_policy.reason (raw code, mapped in the UI)
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
TotalTokens int64
|
||||
CostUSD float64
|
||||
Stream bool
|
||||
|
||||
// Prompt capture. Only populated when prompt collection is enabled
|
||||
// (account master switch AND policy guardrail). Heavy free text.
|
||||
RequestPrompt string `gorm:"type:text"`
|
||||
ResponseCompletion string `gorm:"type:text"`
|
||||
|
||||
CreatedAt time.Time
|
||||
|
||||
// GroupIDs is the authorising group ids for this entry, hydrated from the
|
||||
// group child table on read. Not a column.
|
||||
GroupIDs []string `gorm:"-"`
|
||||
}
|
||||
|
||||
// TableName keeps agent-network access logs in their own table, separate from
|
||||
// the reverse-proxy AccessLogEntry table.
|
||||
func (AgentNetworkAccessLog) TableName() string { return "agent_network_access_log" }
|
||||
|
||||
// ToAPIResponse renders the flattened entry as the API representation.
|
||||
func (a *AgentNetworkAccessLog) ToAPIResponse() api.AgentNetworkAccessLog {
|
||||
out := api.AgentNetworkAccessLog{
|
||||
Id: a.ID,
|
||||
ServiceId: a.ServiceID,
|
||||
Timestamp: a.Timestamp,
|
||||
StatusCode: a.StatusCode,
|
||||
DurationMs: int(a.Duration.Milliseconds()),
|
||||
InputTokens: a.InputTokens,
|
||||
OutputTokens: a.OutputTokens,
|
||||
TotalTokens: a.TotalTokens,
|
||||
CostUsd: a.CostUSD,
|
||||
Stream: &a.Stream,
|
||||
}
|
||||
|
||||
out.UserId = strPtr(a.UserID)
|
||||
out.SourceIp = strPtr(a.SourceIP)
|
||||
out.Method = strPtr(a.Method)
|
||||
out.Host = strPtr(a.Host)
|
||||
out.Path = strPtr(a.Path)
|
||||
out.Provider = strPtr(a.Provider)
|
||||
out.Model = strPtr(a.Model)
|
||||
out.SessionId = strPtr(a.SessionID)
|
||||
out.ResolvedProviderId = strPtr(a.ResolvedProviderID)
|
||||
out.SelectedPolicyId = strPtr(a.SelectedPolicyID)
|
||||
out.Decision = strPtr(a.Decision)
|
||||
out.DenyReason = strPtr(a.DenyReason)
|
||||
out.RequestPrompt = strPtr(a.RequestPrompt)
|
||||
out.ResponseCompletion = strPtr(a.ResponseCompletion)
|
||||
|
||||
if len(a.GroupIDs) > 0 {
|
||||
groups := a.GroupIDs
|
||||
out.GroupIds = &groups
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// strPtr returns a pointer to s, or nil when s is empty — so empty optional
|
||||
// fields are omitted from the JSON rather than serialised as "".
|
||||
func strPtr(s string) *string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
// AgentNetworkAccessLogGroup is the normalised many-to-many row linking a log
|
||||
// entry to one authorising group, so the access-log endpoint can filter by
|
||||
// group with a simple `group_id IN (...)` join instead of substring-matching a
|
||||
// CSV column.
|
||||
type AgentNetworkAccessLogGroup struct {
|
||||
LogID string `gorm:"primaryKey"`
|
||||
GroupID string `gorm:"primaryKey;index"`
|
||||
AccountID string `gorm:"index"`
|
||||
}
|
||||
|
||||
// TableName names the access-log group child table.
|
||||
func (AgentNetworkAccessLogGroup) TableName() string { return "agent_network_access_log_group" }
|
||||
@@ -0,0 +1,213 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
// AccessLogDefaultPageSize is the default number of records per page.
|
||||
AccessLogDefaultPageSize = 50
|
||||
// AccessLogMaxPageSize is the maximum number of records allowed per page.
|
||||
AccessLogMaxPageSize = 100
|
||||
|
||||
accessLogDefaultSortBy = "timestamp"
|
||||
accessLogDefaultSortOrder = "desc"
|
||||
|
||||
// usageOverviewDefaultLookback bounds an unbounded usage-overview query so
|
||||
// it never aggregates an account's entire history into memory.
|
||||
usageOverviewDefaultLookback = 90 * 24 * time.Hour
|
||||
// usageOverviewMaxRange caps how far back an explicit range may reach.
|
||||
usageOverviewMaxRange = 366 * 24 * time.Hour
|
||||
)
|
||||
|
||||
// ApplyUsageOverviewBounds bounds a missing or over-wide date range so the
|
||||
// in-memory usage aggregation can't load an account's full usage history. An
|
||||
// absent range defaults to the last usageOverviewDefaultLookback; a range wider
|
||||
// than usageOverviewMaxRange is clamped from the (possibly defaulted) end.
|
||||
func (f *AgentNetworkAccessLogFilter) ApplyUsageOverviewBounds(now time.Time) {
|
||||
end := now
|
||||
if f.EndDate != nil {
|
||||
end = *f.EndDate
|
||||
}
|
||||
f.EndDate = &end
|
||||
if f.StartDate == nil {
|
||||
start := end.Add(-usageOverviewDefaultLookback)
|
||||
f.StartDate = &start
|
||||
return
|
||||
}
|
||||
if end.Sub(*f.StartDate) > usageOverviewMaxRange {
|
||||
start := end.Add(-usageOverviewMaxRange)
|
||||
f.StartDate = &start
|
||||
}
|
||||
}
|
||||
|
||||
// accessLogSortFields maps the API sort_by values to their database columns.
|
||||
var accessLogSortFields = map[string]string{
|
||||
"timestamp": "timestamp",
|
||||
"model": "model",
|
||||
"provider": "provider",
|
||||
"status_code": "status_code",
|
||||
"duration": "duration",
|
||||
"cost_usd": "cost_usd",
|
||||
"total_tokens": "total_tokens",
|
||||
"user_id": "user_id",
|
||||
"decision": "decision",
|
||||
}
|
||||
|
||||
// AgentNetworkAccessLogFilter holds pagination, filtering and sorting
|
||||
// parameters for the agent-network access-log listing. Group / provider /
|
||||
// model are multi-valued (the UI uses multi-select; an entry matches when it
|
||||
// matches any selected value).
|
||||
type AgentNetworkAccessLogFilter struct {
|
||||
Page int
|
||||
PageSize int
|
||||
|
||||
SortBy string
|
||||
SortOrder string
|
||||
|
||||
Search *string // log id, host, path, model, user email/name
|
||||
UserID *string // exact user id (the dashboard sends the picked user's id)
|
||||
SessionID *string // exact session id — groups one conversation / coding session
|
||||
GroupIDs []string // authorising group ids (match any)
|
||||
ProviderIDs []string // resolved provider ids (match any)
|
||||
Models []string // models (match any)
|
||||
Decision *string // policy decision (allow/deny)
|
||||
PathPrefix *string // request path prefix (path LIKE 'prefix%')
|
||||
StartDate *time.Time // timestamp >= start_date
|
||||
EndDate *time.Time // timestamp <= end_date
|
||||
}
|
||||
|
||||
// ParseFromRequest fills the filter from the request query parameters. It
|
||||
// returns a validation error when a supplied start_date / end_date is present
|
||||
// but not valid RFC3339: silently dropping a malformed date would broaden the
|
||||
// query (and, for the usage overview, fall back to the default window).
|
||||
func (f *AgentNetworkAccessLogFilter) ParseFromRequest(r *http.Request) error {
|
||||
q := r.URL.Query()
|
||||
|
||||
f.Page = parseAccessLogPositiveInt(q.Get("page"), 1)
|
||||
f.PageSize = min(parseAccessLogPositiveInt(q.Get("page_size"), AccessLogDefaultPageSize), AccessLogMaxPageSize)
|
||||
|
||||
f.SortBy = parseAccessLogSortField(q.Get("sort_by"))
|
||||
f.SortOrder = parseAccessLogSortOrder(q.Get("sort_order"))
|
||||
|
||||
f.Search = parseAccessLogOptionalString(q.Get("search"))
|
||||
f.UserID = parseAccessLogOptionalString(q.Get("user_id"))
|
||||
f.SessionID = parseAccessLogOptionalString(q.Get("session_id"))
|
||||
f.Decision = parseAccessLogOptionalString(q.Get("decision"))
|
||||
f.PathPrefix = parseAccessLogOptionalString(q.Get("path"))
|
||||
// Multi-value filters accept either repeated params (?group_id=a&group_id=b)
|
||||
// or a single comma-separated value (?group_id=a,b) so both the OpenAPI
|
||||
// array form and the dashboard's single-value query builder work.
|
||||
f.GroupIDs = splitMultiValue(q["group_id"])
|
||||
f.ProviderIDs = splitMultiValue(q["provider_id"])
|
||||
f.Models = splitMultiValue(q["model"])
|
||||
|
||||
var err error
|
||||
if f.StartDate, err = parseAccessLogOptionalRFC3339(q.Get("start_date")); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "invalid start_date: %v", err)
|
||||
}
|
||||
if f.EndDate, err = parseAccessLogOptionalRFC3339(q.Get("end_date")); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "invalid end_date: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSortColumn returns the database column for the active sort field.
|
||||
func (f *AgentNetworkAccessLogFilter) GetSortColumn() string {
|
||||
if col, ok := accessLogSortFields[f.SortBy]; ok {
|
||||
return col
|
||||
}
|
||||
return accessLogSortFields[accessLogDefaultSortBy]
|
||||
}
|
||||
|
||||
// GetSortOrder returns the normalised sort order ("ASC"/"DESC").
|
||||
func (f *AgentNetworkAccessLogFilter) GetSortOrder() string {
|
||||
if strings.EqualFold(f.SortOrder, "asc") {
|
||||
return "ASC"
|
||||
}
|
||||
return "DESC"
|
||||
}
|
||||
|
||||
// GetLimit returns the page size, defaulting/clamping when unset.
|
||||
func (f *AgentNetworkAccessLogFilter) GetLimit() int {
|
||||
if f.PageSize <= 0 {
|
||||
return AccessLogDefaultPageSize
|
||||
}
|
||||
return min(f.PageSize, AccessLogMaxPageSize)
|
||||
}
|
||||
|
||||
// GetOffset returns the zero-based row offset for the active page. Page is
|
||||
// user-controlled, so the multiplication is guarded against int overflow.
|
||||
func (f *AgentNetworkAccessLogFilter) GetOffset() int {
|
||||
limit := f.GetLimit()
|
||||
if f.Page <= 1 || limit <= 0 {
|
||||
return 0
|
||||
}
|
||||
if f.Page-1 > math.MaxInt/limit {
|
||||
return math.MaxInt - (math.MaxInt % limit)
|
||||
}
|
||||
return (f.Page - 1) * limit
|
||||
}
|
||||
|
||||
func parseAccessLogPositiveInt(s string, def int) int {
|
||||
if v, err := strconv.Atoi(strings.TrimSpace(s)); err == nil && v > 0 {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func parseAccessLogSortField(s string) string {
|
||||
if _, ok := accessLogSortFields[s]; ok {
|
||||
return s
|
||||
}
|
||||
return accessLogDefaultSortBy
|
||||
}
|
||||
|
||||
func parseAccessLogSortOrder(s string) string {
|
||||
if strings.EqualFold(s, "asc") {
|
||||
return "asc"
|
||||
}
|
||||
return accessLogDefaultSortOrder
|
||||
}
|
||||
|
||||
func parseAccessLogOptionalString(s string) *string {
|
||||
if s = strings.TrimSpace(s); s != "" {
|
||||
return &s
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAccessLogOptionalRFC3339(s string) (*time.Time, error) {
|
||||
if s = strings.TrimSpace(s); s == "" {
|
||||
return nil, nil //nolint:nilnil // not provided: no value and no error
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
// splitMultiValue flattens repeated query params and comma-separated values
|
||||
// into a single trimmed, blank-free list. Returns nil when nothing remains so
|
||||
// callers can skip the filter entirely.
|
||||
func splitMultiValue(values []string) []string {
|
||||
out := make([]string, 0, len(values))
|
||||
for _, raw := range values {
|
||||
for _, v := range strings.Split(raw, ",") {
|
||||
if v = strings.TrimSpace(v); v != "" {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
106
management/internals/modules/agentnetwork/types/budgetrule.go
Normal file
106
management/internals/modules/agentnetwork/types/budgetrule.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// AccountBudgetRule is an account-level, limit-only rule bound to groups
|
||||
// and/or users. It mirrors the policy budget experience without any routing:
|
||||
// it carries the same cap shape as a policy (PolicyLimits) but never selects a
|
||||
// provider. Rules apply across policies as an always-on ceiling — every
|
||||
// applicable rule binds (min-wins), so a rule can only tighten a caller's
|
||||
// effective limit, never loosen it.
|
||||
//
|
||||
// TargetGroups matches when it intersects the caller's groups; TargetUsers
|
||||
// binds a specific user directly. Empty TargetGroups and TargetUsers means the
|
||||
// rule applies to every caller (the account-wide default).
|
||||
type AccountBudgetRule struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Enabled bool
|
||||
TargetGroups []string `gorm:"serializer:json;column:target_groups"`
|
||||
TargetUsers []string `gorm:"serializer:json;column:target_users"`
|
||||
Limits PolicyLimits `gorm:"serializer:json;column:limits"`
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// TableName puts budget rules in their own table.
|
||||
func (AccountBudgetRule) TableName() string { return "agent_network_budget_rules" }
|
||||
|
||||
// NewAccountBudgetRule returns a new rule with a freshly minted ID.
|
||||
func NewAccountBudgetRule(accountID string) *AccountBudgetRule {
|
||||
now := time.Now().UTC()
|
||||
return &AccountBudgetRule{
|
||||
ID: "ainbud_" + xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the rule, including its target slices.
|
||||
func (r *AccountBudgetRule) Copy() *AccountBudgetRule {
|
||||
c := *r
|
||||
c.TargetGroups = append([]string(nil), r.TargetGroups...)
|
||||
c.TargetUsers = append([]string(nil), r.TargetUsers...)
|
||||
return &c
|
||||
}
|
||||
|
||||
// EventMeta renders the rule for the activity log.
|
||||
func (r *AccountBudgetRule) EventMeta() map[string]any {
|
||||
return map[string]any{
|
||||
"name": r.Name,
|
||||
"enabled": r.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
// FromAPIRequest applies the request payload onto the receiver.
|
||||
func (r *AccountBudgetRule) FromAPIRequest(req *api.AgentNetworkBudgetRuleRequest) {
|
||||
r.Name = req.Name
|
||||
if req.Enabled != nil {
|
||||
r.Enabled = *req.Enabled
|
||||
}
|
||||
if req.TargetGroups != nil {
|
||||
r.TargetGroups = append([]string(nil), (*req.TargetGroups)...)
|
||||
} else {
|
||||
r.TargetGroups = []string{}
|
||||
}
|
||||
if req.TargetUsers != nil {
|
||||
r.TargetUsers = append([]string(nil), (*req.TargetUsers)...)
|
||||
} else {
|
||||
r.TargetUsers = []string{}
|
||||
}
|
||||
r.Limits = limitsFromAPI(req.Limits)
|
||||
}
|
||||
|
||||
// ToAPIResponse renders the rule as the API representation.
|
||||
func (r *AccountBudgetRule) ToAPIResponse() *api.AgentNetworkBudgetRule {
|
||||
groups := r.TargetGroups
|
||||
if groups == nil {
|
||||
groups = []string{}
|
||||
}
|
||||
users := r.TargetUsers
|
||||
if users == nil {
|
||||
users = []string{}
|
||||
}
|
||||
created := r.CreatedAt
|
||||
updated := r.UpdatedAt
|
||||
return &api.AgentNetworkBudgetRule{
|
||||
Id: r.ID,
|
||||
Name: r.Name,
|
||||
Enabled: r.Enabled,
|
||||
TargetGroups: groups,
|
||||
TargetUsers: users,
|
||||
Limits: limitsToAPI(r.Limits),
|
||||
CreatedAt: &created,
|
||||
UpdatedAt: &updated,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package types
|
||||
|
||||
import "time"
|
||||
|
||||
// ConsumptionDimension classifies which kind of identity a consumption
|
||||
// row counts against. The proxy-side enforcement layer ticks one row
|
||||
// per dimension per request — typically one user row plus one group
|
||||
// row.
|
||||
type ConsumptionDimension string
|
||||
|
||||
const (
|
||||
// DimensionUser counts tokens / spend for a single end user. The
|
||||
// dim_id column carries the netbird user id (or peer.ID when the
|
||||
// caller is a tunnel-peer principal).
|
||||
DimensionUser ConsumptionDimension = "user"
|
||||
// DimensionGroup counts tokens / spend for a single source group
|
||||
// across every member of that group. The dim_id column carries
|
||||
// the netbird group id.
|
||||
DimensionGroup ConsumptionDimension = "group"
|
||||
)
|
||||
|
||||
// Consumption is a per-dimension token + USD counter for a fixed
|
||||
// aligned window. The (account, dim_kind, dim_id, window_seconds,
|
||||
// window_start) tuple is the primary key; rows are rolled forward by
|
||||
// the proxy's post-flight RecordLLMUsage path on every request.
|
||||
//
|
||||
// The same dim_id (e.g. a group id) gets one row per distinct
|
||||
// window_seconds length in scope across the account's policies,
|
||||
// because two policies with different window lengths read independent
|
||||
// counters even though they share the dimension. Two policies with
|
||||
// identical window_seconds on the same dimension share one counter
|
||||
// (correct: their caps are checked against the same shared bucket).
|
||||
type Consumption struct {
|
||||
AccountID string `gorm:"primaryKey;type:varchar(255)"`
|
||||
DimensionKind ConsumptionDimension `gorm:"primaryKey;type:varchar(16);column:dim_kind"`
|
||||
DimensionID string `gorm:"primaryKey;type:varchar(255);column:dim_id"`
|
||||
WindowSeconds int64 `gorm:"primaryKey;column:window_seconds"`
|
||||
WindowStartUTC time.Time `gorm:"primaryKey;column:window_start_utc"`
|
||||
TokensInput int64 `gorm:"column:tokens_input"`
|
||||
TokensOutput int64 `gorm:"column:tokens_output"`
|
||||
CostUSD float64 `gorm:"column:cost_usd"`
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// TableName forces a stable name independent of GORM's pluraliser.
|
||||
func (Consumption) TableName() string { return "agent_network_consumption" }
|
||||
|
||||
// ConsumptionKey identifies a single consumption counter within an account:
|
||||
// the (dim_kind, dim_id, window_seconds, window_start) part of the row's
|
||||
// primary key. Used to batch-read and batch-increment many counters for one
|
||||
// request in a single store round-trip / transaction.
|
||||
type ConsumptionKey struct {
|
||||
Kind ConsumptionDimension
|
||||
DimID string
|
||||
WindowSeconds int64
|
||||
WindowStartUTC time.Time
|
||||
}
|
||||
|
||||
// WindowStart returns the aligned UTC start of the window of length
|
||||
// windowSeconds that contains t. Aligned to the unix epoch so the
|
||||
// same bucket boundary is computed deterministically across processes.
|
||||
func WindowStart(t time.Time, windowSeconds int64) time.Time {
|
||||
if windowSeconds <= 0 {
|
||||
return t.UTC()
|
||||
}
|
||||
step := windowSeconds * int64(time.Second)
|
||||
bucketed := t.UTC().UnixNano() / step * step
|
||||
return time.Unix(0, bucketed).UTC()
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestWindowStart_AlignedToUnixEpoch is the multi-node-convergence
|
||||
// guarantee: any two proxies computing WindowStart(now, s) for the
|
||||
// same s must land on the same boundary. The implementation aligns
|
||||
// to the unix epoch (UTC) rather than local time, calendar weeks, or
|
||||
// process start time — none of which are shared across nodes.
|
||||
//
|
||||
// Table covers the load-bearing window lengths (5m, 1h, 24h, 30d)
|
||||
// plus a few odd values that still need to align cleanly.
|
||||
func TestWindowStart_AlignedToUnixEpoch(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
instant time.Time
|
||||
windowSeconds int64
|
||||
want time.Time
|
||||
}{
|
||||
{
|
||||
name: "5m window — drops seconds inside the bucket",
|
||||
instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.UTC),
|
||||
windowSeconds: 300,
|
||||
want: time.Date(2026, 5, 6, 13, 45, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "1h window — drops minutes / seconds, keeps the hour",
|
||||
instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.UTC),
|
||||
windowSeconds: 3600,
|
||||
want: time.Date(2026, 5, 6, 13, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "24h window aligns to UTC midnight",
|
||||
instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.UTC),
|
||||
windowSeconds: 86_400,
|
||||
want: time.Date(2026, 5, 6, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "30d (2_592_000s) window aligns to the 30d epoch grid, not month boundaries",
|
||||
instant: time.Date(2026, 5, 6, 0, 0, 0, 0, time.UTC),
|
||||
windowSeconds: 2_592_000,
|
||||
// 2026-05-06 UTC = 1778025600s; 1778025600 / 2592000 = 685
|
||||
// 685 * 2592000 = 1775520000s = 2026-04-07 00:00:00 UTC
|
||||
want: time.Date(2026, 4, 7, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "non-UTC input still anchors on UTC epoch boundaries",
|
||||
instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.FixedZone("CEST", 2*3600)),
|
||||
windowSeconds: 86_400,
|
||||
// 2026-05-06 13:47:23 CEST = 11:47:23 UTC → bucket 2026-05-06 00:00:00 UTC
|
||||
want: time.Date(2026, 5, 6, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := WindowStart(tc.instant, tc.windowSeconds)
|
||||
assert.True(t, got.Equal(tc.want),
|
||||
"WindowStart(%v, %ds) = %v, want %v", tc.instant, tc.windowSeconds, got, tc.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWindowStart_WithinWindowConverges proves the determinism
|
||||
// contract: any two timestamps inside the same window land on the
|
||||
// exact same boundary. Two proxy nodes serving requests 7s apart
|
||||
// must agree on which counter row to upsert.
|
||||
func TestWindowStart_WithinWindowConverges(t *testing.T) {
|
||||
t1 := time.Date(2026, 5, 6, 14, 0, 0, 0, time.UTC)
|
||||
t2 := t1.Add(7 * time.Second)
|
||||
t3 := t1.Add(59*time.Minute + 59*time.Second)
|
||||
|
||||
a := WindowStart(t1, 3600)
|
||||
b := WindowStart(t2, 3600)
|
||||
c := WindowStart(t3, 3600)
|
||||
|
||||
assert.True(t, a.Equal(b), "two timestamps 7s apart in the same 1h window must align to the same boundary")
|
||||
assert.True(t, a.Equal(c), "the very last second of a 1h window still lands on the SAME bucket as the first second")
|
||||
}
|
||||
|
||||
// TestWindowStart_AcrossWindowsDiverges is the symmetric guarantee:
|
||||
// two timestamps separated by a window's worth of time MUST land on
|
||||
// different boundaries. Without this, a 24h window's "rollover"
|
||||
// would never reset the counter.
|
||||
func TestWindowStart_AcrossWindowsDiverges(t *testing.T) {
|
||||
t1 := time.Date(2026, 5, 6, 23, 59, 59, 0, time.UTC)
|
||||
t2 := t1.Add(2 * time.Second) // 2026-05-07 00:00:01
|
||||
|
||||
a := WindowStart(t1, 86_400)
|
||||
b := WindowStart(t2, 86_400)
|
||||
assert.False(t, a.Equal(b),
|
||||
"timestamps straddling a 24h-window boundary must land on different buckets — otherwise daily caps never reset")
|
||||
}
|
||||
|
||||
// TestWindowStart_DifferentWindowsHaveDifferentBuckets locks the
|
||||
// design fork "two policies with different window_seconds on the same
|
||||
// group produce independent counters". A 24h boundary at noon is NOT
|
||||
// the same as the 30d boundary that contains it.
|
||||
func TestWindowStart_DifferentWindowsHaveDifferentBuckets(t *testing.T) {
|
||||
now := time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC)
|
||||
short := WindowStart(now, 86_400)
|
||||
long := WindowStart(now, 2_592_000)
|
||||
assert.False(t, short.Equal(long),
|
||||
"the 24h bucket and 30d bucket containing the same instant must differ — independent counters require independent keys")
|
||||
}
|
||||
|
||||
// TestWindowStart_SubMinuteAndMinuteAlignment locks sub-hour windows.
|
||||
// A 5-minute window must align to multiples of 300s from the unix
|
||||
// epoch — minute marks 0/5/10/.../55 within an hour, deterministic
|
||||
// across nodes regardless of clock drift.
|
||||
func TestWindowStart_SubMinuteAndMinuteAlignment(t *testing.T) {
|
||||
t1 := time.Date(2026, 5, 6, 14, 12, 30, 0, time.UTC)
|
||||
t2 := time.Date(2026, 5, 6, 14, 14, 59, 0, time.UTC)
|
||||
t3 := time.Date(2026, 5, 6, 14, 15, 0, 0, time.UTC)
|
||||
|
||||
a := WindowStart(t1, 300)
|
||||
b := WindowStart(t2, 300)
|
||||
c := WindowStart(t3, 300)
|
||||
|
||||
assert.True(t, a.Equal(b),
|
||||
"14:12:30 and 14:14:59 fall in the same 5m bucket starting at 14:10:00")
|
||||
assert.True(t, a.Equal(time.Date(2026, 5, 6, 14, 10, 0, 0, time.UTC)),
|
||||
"5m bucket containing 14:12 starts at 14:10 — aligned to multiples of 300s from unix epoch")
|
||||
assert.False(t, a.Equal(c),
|
||||
"14:15:00 is the start of the next 5m bucket — must not fold into the previous one")
|
||||
}
|
||||
|
||||
// TestWindowStart_ZeroWindowReturnsInputUTC covers the defensive
|
||||
// path: caller hands a zero / negative window (shouldn't happen, but
|
||||
// might mid-refactor). The function returns the input as UTC rather
|
||||
// than dividing by zero.
|
||||
func TestWindowStart_ZeroWindowReturnsInputUTC(t *testing.T) {
|
||||
now := time.Date(2026, 5, 6, 12, 30, 45, 0, time.FixedZone("CEST", 2*3600))
|
||||
got := WindowStart(now, 0)
|
||||
assert.True(t, got.Equal(now.UTC()), "zero window must not panic — return input as UTC")
|
||||
}
|
||||
120
management/internals/modules/agentnetwork/types/guardrail.go
Normal file
120
management/internals/modules/agentnetwork/types/guardrail.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// GuardrailChecks is the configurable parameter set persisted with each
|
||||
// guardrail. Stored as a JSON blob to keep the table flat.
|
||||
type GuardrailChecks struct {
|
||||
ModelAllowlist GuardrailModelAllowlist `json:"model_allowlist"`
|
||||
PromptCapture GuardrailPromptCapture `json:"prompt_capture"`
|
||||
}
|
||||
|
||||
type GuardrailModelAllowlist struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Models []string `json:"models"`
|
||||
}
|
||||
|
||||
type GuardrailPromptCapture struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
RedactPii bool `json:"redact_pii"`
|
||||
}
|
||||
|
||||
// Guardrail is an Agent Network reusable guardrail set persisted per account.
|
||||
type Guardrail struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Description string
|
||||
Checks GuardrailChecks `gorm:"serializer:json"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// TableName uses an explicit name so guardrail rows live in their own
|
||||
// table.
|
||||
func (Guardrail) TableName() string { return "agent_network_guardrails" }
|
||||
|
||||
// NewGuardrail returns a new Guardrail with a freshly minted ID.
|
||||
func NewGuardrail(accountID string) *Guardrail {
|
||||
now := time.Now().UTC()
|
||||
return &Guardrail{
|
||||
ID: "ainguard_" + xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Checks: GuardrailChecks{ModelAllowlist: GuardrailModelAllowlist{Models: []string{}}},
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
// FromAPIRequest applies the request payload onto the receiver.
|
||||
func (g *Guardrail) FromAPIRequest(req *api.AgentNetworkGuardrailRequest) {
|
||||
g.Name = req.Name
|
||||
if req.Description != nil {
|
||||
g.Description = *req.Description
|
||||
}
|
||||
g.Checks = checksFromAPI(req.Checks)
|
||||
}
|
||||
|
||||
// ToAPIResponse renders the guardrail as the API representation.
|
||||
func (g *Guardrail) ToAPIResponse() *api.AgentNetworkGuardrail {
|
||||
created := g.CreatedAt
|
||||
updated := g.UpdatedAt
|
||||
return &api.AgentNetworkGuardrail{
|
||||
Id: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Checks: checksToAPI(g.Checks),
|
||||
CreatedAt: &created,
|
||||
UpdatedAt: &updated,
|
||||
}
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the guardrail.
|
||||
func (g *Guardrail) Copy() *Guardrail {
|
||||
clone := *g
|
||||
if g.Checks.ModelAllowlist.Models != nil {
|
||||
clone.Checks.ModelAllowlist.Models = append([]string(nil), g.Checks.ModelAllowlist.Models...)
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
|
||||
// EventMeta is the audit-log payload for activity events.
|
||||
func (g *Guardrail) EventMeta() map[string]any {
|
||||
return map[string]any{"name": g.Name}
|
||||
}
|
||||
|
||||
func checksFromAPI(c api.AgentNetworkGuardrailChecks) GuardrailChecks {
|
||||
models := append([]string(nil), c.ModelAllowlist.Models...)
|
||||
if models == nil {
|
||||
models = []string{}
|
||||
}
|
||||
return GuardrailChecks{
|
||||
ModelAllowlist: GuardrailModelAllowlist{
|
||||
Enabled: c.ModelAllowlist.Enabled,
|
||||
Models: models,
|
||||
},
|
||||
PromptCapture: GuardrailPromptCapture{
|
||||
Enabled: c.PromptCapture.Enabled,
|
||||
RedactPii: c.PromptCapture.RedactPii,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func checksToAPI(c GuardrailChecks) api.AgentNetworkGuardrailChecks {
|
||||
models := c.ModelAllowlist.Models
|
||||
if models == nil {
|
||||
models = []string{}
|
||||
}
|
||||
out := api.AgentNetworkGuardrailChecks{}
|
||||
out.ModelAllowlist.Enabled = c.ModelAllowlist.Enabled
|
||||
out.ModelAllowlist.Models = models
|
||||
out.PromptCapture.Enabled = c.PromptCapture.Enabled
|
||||
out.PromptCapture.RedactPii = c.PromptCapture.RedactPii
|
||||
return out
|
||||
}
|
||||
192
management/internals/modules/agentnetwork/types/policy.go
Normal file
192
management/internals/modules/agentnetwork/types/policy.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// Policy is an Agent Network policy persisted per account. A policy
|
||||
// authorises members of SourceGroups to reach the listed
|
||||
// DestinationProviderIDs under the attached GuardrailIDs and Limits.
|
||||
//
|
||||
// Token and budget limits live on the Policy itself (Limits field);
|
||||
// guardrails carry only model allowlist and prompt capture.
|
||||
type Policy struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Description string
|
||||
Enabled bool
|
||||
SourceGroups []string `gorm:"serializer:json;column:source_groups"`
|
||||
DestinationProviderIDs []string `gorm:"serializer:json;column:destination_provider_ids"`
|
||||
GuardrailIDs []string `gorm:"serializer:json;column:guardrail_ids"`
|
||||
Limits PolicyLimits `gorm:"serializer:json;column:limits"`
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// PolicyLimits aggregates the token and budget caps attached directly
|
||||
// to a policy. Both halves are always present; their Enabled flags
|
||||
// control whether the proxy enforces them.
|
||||
type PolicyLimits struct {
|
||||
TokenLimit PolicyTokenLimit `json:"token_limit"`
|
||||
BudgetLimit PolicyBudgetLimit `json:"budget_limit"`
|
||||
}
|
||||
|
||||
// PolicyTokenLimit is a token-count cap evaluated over an aligned
|
||||
// window of WindowSeconds seconds. GroupCap is applied to each
|
||||
// source group independently — every group in the policy's
|
||||
// SourceGroups gets its own bucket of GroupCap tokens. UserCap
|
||||
// applies independently to each individual user. A zero cap means
|
||||
// uncapped. WindowSeconds must be at least 60 (one minute) when the
|
||||
// limit is enabled.
|
||||
type PolicyTokenLimit struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
GroupCap int64 `json:"group_cap"`
|
||||
UserCap int64 `json:"user_cap"`
|
||||
WindowSeconds int64 `json:"window_seconds"`
|
||||
}
|
||||
|
||||
// PolicyBudgetLimit is a USD spend cap evaluated over an aligned
|
||||
// window of WindowSeconds seconds. GroupCapUsd is applied to each
|
||||
// source group independently — every group in the policy's
|
||||
// SourceGroups gets its own bucket of GroupCapUsd USD. UserCapUsd
|
||||
// applies independently to each individual user. A zero cap means
|
||||
// uncapped. WindowSeconds must be at least 60 (one minute) when the
|
||||
// limit is enabled.
|
||||
type PolicyBudgetLimit struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
GroupCapUsd float64 `json:"group_cap_usd"`
|
||||
UserCapUsd float64 `json:"user_cap_usd"`
|
||||
WindowSeconds int64 `json:"window_seconds"`
|
||||
}
|
||||
|
||||
// TableName forces a unique GORM table to avoid collision with the access
|
||||
// control Policy type, which also resolves to "policies" by default.
|
||||
func (Policy) TableName() string { return "agent_network_policies" }
|
||||
|
||||
// NewPolicy returns a new Policy with a freshly minted ID.
|
||||
func NewPolicy(accountID string) *Policy {
|
||||
now := time.Now().UTC()
|
||||
return &Policy{
|
||||
ID: "ainpol_" + xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
// FromAPIRequest applies the request payload onto the receiver.
|
||||
func (p *Policy) FromAPIRequest(req *api.AgentNetworkPolicyRequest) {
|
||||
p.Name = req.Name
|
||||
if req.Description != nil {
|
||||
p.Description = *req.Description
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
p.Enabled = *req.Enabled
|
||||
}
|
||||
p.SourceGroups = append([]string(nil), req.SourceGroups...)
|
||||
p.DestinationProviderIDs = append([]string(nil), req.DestinationProviderIds...)
|
||||
if req.GuardrailIds != nil {
|
||||
p.GuardrailIDs = append([]string(nil), (*req.GuardrailIds)...)
|
||||
} else {
|
||||
p.GuardrailIDs = []string{}
|
||||
}
|
||||
if req.Limits != nil {
|
||||
p.Limits = limitsFromAPI(*req.Limits)
|
||||
} else {
|
||||
p.Limits = PolicyLimits{}
|
||||
}
|
||||
}
|
||||
|
||||
// ToAPIResponse renders the policy as the API representation.
|
||||
func (p *Policy) ToAPIResponse() *api.AgentNetworkPolicy {
|
||||
src := p.SourceGroups
|
||||
if src == nil {
|
||||
src = []string{}
|
||||
}
|
||||
dst := p.DestinationProviderIDs
|
||||
if dst == nil {
|
||||
dst = []string{}
|
||||
}
|
||||
guardrails := p.GuardrailIDs
|
||||
if guardrails == nil {
|
||||
guardrails = []string{}
|
||||
}
|
||||
created := p.CreatedAt
|
||||
updated := p.UpdatedAt
|
||||
return &api.AgentNetworkPolicy{
|
||||
Id: p.ID,
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
Enabled: p.Enabled,
|
||||
SourceGroups: src,
|
||||
DestinationProviderIds: dst,
|
||||
GuardrailIds: guardrails,
|
||||
Limits: limitsToAPI(p.Limits),
|
||||
CreatedAt: &created,
|
||||
UpdatedAt: &updated,
|
||||
}
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the policy.
|
||||
func (p *Policy) Copy() *Policy {
|
||||
clone := *p
|
||||
if p.SourceGroups != nil {
|
||||
clone.SourceGroups = append([]string(nil), p.SourceGroups...)
|
||||
}
|
||||
if p.DestinationProviderIDs != nil {
|
||||
clone.DestinationProviderIDs = append([]string(nil), p.DestinationProviderIDs...)
|
||||
}
|
||||
if p.GuardrailIDs != nil {
|
||||
clone.GuardrailIDs = append([]string(nil), p.GuardrailIDs...)
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
|
||||
// EventMeta is the audit-log payload for activity events.
|
||||
func (p *Policy) EventMeta() map[string]any {
|
||||
return map[string]any{
|
||||
"name": p.Name,
|
||||
"enabled": p.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
func limitsFromAPI(in api.AgentNetworkPolicyLimits) PolicyLimits {
|
||||
return PolicyLimits{
|
||||
TokenLimit: PolicyTokenLimit{
|
||||
Enabled: in.TokenLimit.Enabled,
|
||||
GroupCap: in.TokenLimit.GroupCap,
|
||||
UserCap: in.TokenLimit.UserCap,
|
||||
WindowSeconds: in.TokenLimit.WindowSeconds,
|
||||
},
|
||||
BudgetLimit: PolicyBudgetLimit{
|
||||
Enabled: in.BudgetLimit.Enabled,
|
||||
GroupCapUsd: in.BudgetLimit.GroupCapUsd,
|
||||
UserCapUsd: in.BudgetLimit.UserCapUsd,
|
||||
WindowSeconds: in.BudgetLimit.WindowSeconds,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func limitsToAPI(in PolicyLimits) api.AgentNetworkPolicyLimits {
|
||||
return api.AgentNetworkPolicyLimits{
|
||||
TokenLimit: api.AgentNetworkPolicyTokenLimit{
|
||||
Enabled: in.TokenLimit.Enabled,
|
||||
GroupCap: in.TokenLimit.GroupCap,
|
||||
UserCap: in.TokenLimit.UserCap,
|
||||
WindowSeconds: in.TokenLimit.WindowSeconds,
|
||||
},
|
||||
BudgetLimit: api.AgentNetworkPolicyBudgetLimit{
|
||||
Enabled: in.BudgetLimit.Enabled,
|
||||
GroupCapUsd: in.BudgetLimit.GroupCapUsd,
|
||||
UserCapUsd: in.BudgetLimit.UserCapUsd,
|
||||
WindowSeconds: in.BudgetLimit.WindowSeconds,
|
||||
},
|
||||
}
|
||||
}
|
||||
252
management/internals/modules/agentnetwork/types/provider.go
Normal file
252
management/internals/modules/agentnetwork/types/provider.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
// ProviderModel is one row in the provider's models list. The operator
|
||||
// pins the per-1k input/output price for cost tracking; ID is the
|
||||
// model identifier the upstream provider expects on the wire.
|
||||
type ProviderModel struct {
|
||||
ID string `json:"id"`
|
||||
InputPer1k float64 `json:"input_per_1k"`
|
||||
OutputPer1k float64 `json:"output_per_1k"`
|
||||
}
|
||||
|
||||
// Provider is an Agent Network AI provider record persisted per account.
|
||||
// The proxy cluster fronting the account lives on the per-account
|
||||
// agent-network Settings row, not on the Provider — every provider in
|
||||
// an account routes through the same cluster.
|
||||
type Provider struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
ProviderID string `gorm:"index:idx_agent_network_provider"`
|
||||
Name string
|
||||
// UpstreamURL is the full upstream URL (e.g. https://api.openai.com)
|
||||
// the operator selected.
|
||||
UpstreamURL string `gorm:"column:upstream_url"`
|
||||
APIKey string `gorm:"column:api_key"`
|
||||
// ExtraValues holds operator-typed values for catalog-declared
|
||||
// ExtraHeaders (see catalog.Provider.ExtraHeaders). Keyed by
|
||||
// header name (e.g. "x-portkey-config"); a non-empty value is
|
||||
// stamped on every upstream request to this provider via the
|
||||
// proxy's identity-inject middleware (anti-spoof Remove + Add).
|
||||
// Empty / missing keys = no header stamped. Stored as a JSON
|
||||
// blob so the schema doesn't grow per-catalog-entry.
|
||||
ExtraValues map[string]string `gorm:"serializer:json;column:extra_values"`
|
||||
// Models is the operator's curated list of models exposed by this
|
||||
// provider together with their per-1k input/output prices (USD).
|
||||
// Empty means all catalog models are allowed at catalog prices.
|
||||
Models []ProviderModel `gorm:"serializer:json"`
|
||||
Enabled bool
|
||||
// SessionPrivateKey + SessionPublicKey are the ed25519 keypair the
|
||||
// synthesised reverse-proxy service uses to sign / verify session
|
||||
// JWTs after a successful OIDC handshake. Generated once on
|
||||
// provider create and never rotated by the manager so existing
|
||||
// session cookies survive provider edits. SessionPrivateKey is
|
||||
// encrypted at rest via EncryptSensitiveData /
|
||||
// DecryptSensitiveData; SessionPublicKey is plain.
|
||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
// IdentityHeaderUserID + IdentityHeaderGroups are the operator-
|
||||
// chosen wire header names for HeaderPair-style identity
|
||||
// injection on catalog entries that flag the shape as
|
||||
// Customizable (e.g. Bifrost, where the operator picks between
|
||||
// the always-on x-bf-lh- log-metadata family and the
|
||||
// label-declared x-bf-dim- telemetry family). Empty value
|
||||
// disables stamping for that dimension; the inject middleware
|
||||
// already no-ops on empty header names. Catalog entries with
|
||||
// Customizable=false ignore these fields and use the static
|
||||
// header names defined in their HeaderPairInjection block.
|
||||
IdentityHeaderUserID string `gorm:"column:identity_header_user_id"`
|
||||
IdentityHeaderGroups string `gorm:"column:identity_header_groups"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// TableName uses an explicit name so the Agent Network provider rows live
|
||||
// in their own table, separate from any future "providers"-named entity.
|
||||
func (Provider) TableName() string { return "agent_network_providers" }
|
||||
|
||||
// NewProvider returns a new Provider with a freshly minted ID.
|
||||
func NewProvider(accountID string) *Provider {
|
||||
now := time.Now().UTC()
|
||||
return &Provider{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
// FromAPIRequest applies the request payload onto the receiver. The api_key
|
||||
// is only overwritten when the caller provided one — empty/nil leaves the
|
||||
// existing key intact, so updates can omit it.
|
||||
func (p *Provider) FromAPIRequest(req *api.AgentNetworkProviderRequest) {
|
||||
p.ProviderID = req.ProviderId
|
||||
p.Name = req.Name
|
||||
p.UpstreamURL = req.UpstreamUrl
|
||||
if req.ApiKey != nil && strings.TrimSpace(*req.ApiKey) != "" {
|
||||
p.APIKey = *req.ApiKey
|
||||
}
|
||||
if req.ExtraValues != nil {
|
||||
// Replace the whole map (rather than merge) so unsetting a
|
||||
// value on the dashboard actually clears it. Empty strings
|
||||
// are dropped so we don't waste a row on no-op values.
|
||||
next := make(map[string]string, len(*req.ExtraValues))
|
||||
for k, v := range *req.ExtraValues {
|
||||
v = strings.TrimSpace(v)
|
||||
if v != "" {
|
||||
next[k] = v
|
||||
}
|
||||
}
|
||||
if len(next) == 0 {
|
||||
p.ExtraValues = nil
|
||||
} else {
|
||||
p.ExtraValues = next
|
||||
}
|
||||
}
|
||||
p.Models = p.Models[:0]
|
||||
if req.Models != nil {
|
||||
for _, m := range *req.Models {
|
||||
p.Models = append(p.Models, ProviderModel{
|
||||
ID: m.Id,
|
||||
InputPer1k: m.InputPer1k,
|
||||
OutputPer1k: m.OutputPer1k,
|
||||
})
|
||||
}
|
||||
}
|
||||
if p.Models == nil {
|
||||
p.Models = []ProviderModel{}
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
p.Enabled = *req.Enabled
|
||||
}
|
||||
// Identity-header overrides for catalogs flagged Customizable.
|
||||
// nil pointer = "field omitted on the wire" → leave the stored
|
||||
// value untouched (per the openapi description). Empty string is
|
||||
// an explicit clear that disables stamping for this dimension.
|
||||
if req.IdentityHeaderUserId != nil {
|
||||
p.IdentityHeaderUserID = strings.TrimSpace(*req.IdentityHeaderUserId)
|
||||
}
|
||||
if req.IdentityHeaderGroups != nil {
|
||||
p.IdentityHeaderGroups = strings.TrimSpace(*req.IdentityHeaderGroups)
|
||||
}
|
||||
}
|
||||
|
||||
// ToAPIResponse renders the provider as the API representation. The API
|
||||
// key is intentionally never surfaced.
|
||||
func (p *Provider) ToAPIResponse() *api.AgentNetworkProvider {
|
||||
models := make([]api.AgentNetworkProviderModel, 0, len(p.Models))
|
||||
for _, m := range p.Models {
|
||||
models = append(models, api.AgentNetworkProviderModel{
|
||||
Id: m.ID,
|
||||
InputPer1k: m.InputPer1k,
|
||||
OutputPer1k: m.OutputPer1k,
|
||||
})
|
||||
}
|
||||
created := p.CreatedAt
|
||||
updated := p.UpdatedAt
|
||||
resp := &api.AgentNetworkProvider{
|
||||
Id: p.ID,
|
||||
ProviderId: p.ProviderID,
|
||||
Name: p.Name,
|
||||
UpstreamUrl: p.UpstreamURL,
|
||||
Models: models,
|
||||
Enabled: p.Enabled,
|
||||
CreatedAt: &created,
|
||||
UpdatedAt: &updated,
|
||||
}
|
||||
if len(p.ExtraValues) > 0 {
|
||||
out := make(map[string]string, len(p.ExtraValues))
|
||||
for k, v := range p.ExtraValues {
|
||||
out[k] = v
|
||||
}
|
||||
resp.ExtraValues = &out
|
||||
}
|
||||
if p.IdentityHeaderUserID != "" {
|
||||
v := p.IdentityHeaderUserID
|
||||
resp.IdentityHeaderUserId = &v
|
||||
}
|
||||
if p.IdentityHeaderGroups != "" {
|
||||
v := p.IdentityHeaderGroups
|
||||
resp.IdentityHeaderGroups = &v
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the provider.
|
||||
func (p *Provider) Copy() *Provider {
|
||||
clone := *p
|
||||
if p.Models != nil {
|
||||
clone.Models = append([]ProviderModel(nil), p.Models...)
|
||||
}
|
||||
if p.ExtraValues != nil {
|
||||
clone.ExtraValues = make(map[string]string, len(p.ExtraValues))
|
||||
for k, v := range p.ExtraValues {
|
||||
clone.ExtraValues[k] = v
|
||||
}
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
|
||||
// EventMeta is the audit-log payload for activity events.
|
||||
func (p *Provider) EventMeta() map[string]any {
|
||||
return map[string]any{
|
||||
"name": p.Name,
|
||||
"provider_id": p.ProviderID,
|
||||
}
|
||||
}
|
||||
|
||||
// EncryptSensitiveData encrypts the upstream API key and the session
|
||||
// signing key in place.
|
||||
func (p *Provider) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
if p.APIKey != "" {
|
||||
encrypted, err := enc.Encrypt(p.APIKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt agent network provider api key: %w", err)
|
||||
}
|
||||
p.APIKey = encrypted
|
||||
}
|
||||
if p.SessionPrivateKey != "" {
|
||||
encrypted, err := enc.Encrypt(p.SessionPrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt agent network provider session key: %w", err)
|
||||
}
|
||||
p.SessionPrivateKey = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecryptSensitiveData decrypts the upstream API key and the session
|
||||
// signing key in place.
|
||||
func (p *Provider) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
if p.APIKey != "" {
|
||||
decrypted, err := enc.Decrypt(p.APIKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt agent network provider api key: %w", err)
|
||||
}
|
||||
p.APIKey = decrypted
|
||||
}
|
||||
if p.SessionPrivateKey != "" {
|
||||
decrypted, err := enc.Decrypt(p.SessionPrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt agent network provider session key: %w", err)
|
||||
}
|
||||
p.SessionPrivateKey = decrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
78
management/internals/modules/agentnetwork/types/settings.go
Normal file
78
management/internals/modules/agentnetwork/types/settings.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// DefaultAccessLogRetentionDays is the retention applied to new accounts'
|
||||
// agent-network access logs. Usage records are not subject to this — they are
|
||||
// the long-term aggregate and are retained independently.
|
||||
const DefaultAccessLogRetentionDays = 30
|
||||
|
||||
// Settings is the per-account agent-network configuration row. One
|
||||
// row per account. Cluster + Subdomain are immutable once written and
|
||||
// produce the public endpoint agents call (`<subdomain>.<cluster>`).
|
||||
type Settings struct {
|
||||
AccountID string `gorm:"primaryKey"`
|
||||
Cluster string
|
||||
Subdomain string `gorm:"index:idx_agent_network_settings_cluster_subdomain"`
|
||||
|
||||
// Account-level collection controls sourced by the synthesizer.
|
||||
// EnableLogCollection gates the per-request access-log trail and defaults
|
||||
// ON for new accounts. EnablePromptCollection is the master gate for
|
||||
// request/response prompt capture (AND-gated with the policy-level
|
||||
// guardrail). RedactPii enables PII redaction on captured prompts;
|
||||
// effective redaction is account OR policy.
|
||||
EnableLogCollection bool
|
||||
EnablePromptCollection bool
|
||||
RedactPii bool
|
||||
|
||||
// AccessLogRetentionDays bounds how long full access-log rows are kept; a
|
||||
// periodic sweep deletes older rows. <= 0 means keep indefinitely. Usage
|
||||
// records are unaffected.
|
||||
AccessLogRetentionDays int
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// TableName puts the rows in their own table to keep the agent-network
|
||||
// schema cohesive.
|
||||
func (Settings) TableName() string { return "agent_network_settings" }
|
||||
|
||||
// Endpoint returns the bare hostname agents reach this account at:
|
||||
// `<subdomain>.<cluster>`.
|
||||
func (s *Settings) Endpoint() string {
|
||||
return s.Subdomain + "." + s.Cluster
|
||||
}
|
||||
|
||||
// ToAPIResponse renders the settings as the API representation.
|
||||
func (s *Settings) ToAPIResponse() *api.AgentNetworkSettings {
|
||||
created := s.CreatedAt
|
||||
updated := s.UpdatedAt
|
||||
retention := s.AccessLogRetentionDays
|
||||
return &api.AgentNetworkSettings{
|
||||
Cluster: s.Cluster,
|
||||
Subdomain: s.Subdomain,
|
||||
Endpoint: s.Endpoint(),
|
||||
EnableLogCollection: s.EnableLogCollection,
|
||||
EnablePromptCollection: s.EnablePromptCollection,
|
||||
RedactPii: s.RedactPii,
|
||||
AccessLogRetentionDays: &retention,
|
||||
CreatedAt: &created,
|
||||
UpdatedAt: &updated,
|
||||
}
|
||||
}
|
||||
|
||||
// FromAPIRequest applies the mutable settings fields from the request. Cluster
|
||||
// and Subdomain are immutable and intentionally not touched here.
|
||||
func (s *Settings) FromAPIRequest(req *api.AgentNetworkSettingsRequest) {
|
||||
s.EnableLogCollection = req.EnableLogCollection
|
||||
s.EnablePromptCollection = req.EnablePromptCollection
|
||||
s.RedactPii = req.RedactPii
|
||||
if req.AccessLogRetentionDays != nil {
|
||||
s.AccessLogRetentionDays = *req.AccessLogRetentionDays
|
||||
}
|
||||
}
|
||||
47
management/internals/modules/agentnetwork/types/usage.go
Normal file
47
management/internals/modules/agentnetwork/types/usage.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// AgentNetworkUsage is the stripped, always-collected per-request usage record
|
||||
// powering the Usage overview. Unlike AgentNetworkAccessLog it carries no
|
||||
// request detail (host/path/source IP/prompt) — only the dimensions needed to
|
||||
// aggregate and filter spend by user / group / provider / model over time.
|
||||
//
|
||||
// It is written unconditionally on every served agent-network request,
|
||||
// independent of the account's EnableLogCollection toggle: when log collection
|
||||
// is off the proxy ships a stripped, usage-only entry and management still
|
||||
// records the usage row (but skips the full AgentNetworkAccessLog row).
|
||||
type AgentNetworkUsage struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Timestamp time.Time `gorm:"index"`
|
||||
UserID string `gorm:"index"`
|
||||
ResolvedProviderID string `gorm:"index"`
|
||||
Provider string // vendor, e.g. "openai"
|
||||
Model string `gorm:"index"`
|
||||
SessionID string `gorm:"index"` // llm.session_id — groups a conversation / coding session
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
TotalTokens int64
|
||||
CostUSD float64
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// TableName keeps usage records in their own stripped table. Named
|
||||
// distinctly (…_request_usage) to avoid colliding with any pre-existing
|
||||
// agent_network_usage table in a shared database.
|
||||
func (AgentNetworkUsage) TableName() string { return "agent_network_request_usage" }
|
||||
|
||||
// AgentNetworkUsageGroup is the normalised many-to-many row linking a usage
|
||||
// record to one authorising group, mirroring AgentNetworkAccessLogGroup so the
|
||||
// usage overview can filter by group with a `group_id IN (...)` join.
|
||||
type AgentNetworkUsageGroup struct {
|
||||
UsageID string `gorm:"primaryKey"`
|
||||
GroupID string `gorm:"primaryKey;index"`
|
||||
AccountID string `gorm:"index"`
|
||||
}
|
||||
|
||||
// TableName names the usage group child table.
|
||||
func (AgentNetworkUsageGroup) TableName() string { return "agent_network_request_usage_group" }
|
||||
@@ -0,0 +1,96 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// UsageGranularity is the time-bucket width for the usage overview. New values
|
||||
// can be added here and handled in bucketStart without touching the store.
|
||||
type UsageGranularity string
|
||||
|
||||
const (
|
||||
UsageGranularityDay UsageGranularity = "day"
|
||||
UsageGranularityWeek UsageGranularity = "week"
|
||||
UsageGranularityMonth UsageGranularity = "month"
|
||||
)
|
||||
|
||||
// ParseUsageGranularity maps the API query value to a granularity, defaulting
|
||||
// to day for empty/unknown input.
|
||||
func ParseUsageGranularity(s string) UsageGranularity {
|
||||
switch UsageGranularity(s) {
|
||||
case UsageGranularityWeek:
|
||||
return UsageGranularityWeek
|
||||
case UsageGranularityMonth:
|
||||
return UsageGranularityMonth
|
||||
default:
|
||||
return UsageGranularityDay
|
||||
}
|
||||
}
|
||||
|
||||
// AgentNetworkUsageBucket is one aggregated usage time bucket. PeriodStart is
|
||||
// the UTC start of the bucket as YYYY-MM-DD.
|
||||
type AgentNetworkUsageBucket struct {
|
||||
PeriodStart string
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
TotalTokens int64
|
||||
CostUSD float64
|
||||
}
|
||||
|
||||
// ToAPIResponse renders the bucket as the API representation.
|
||||
func (b *AgentNetworkUsageBucket) ToAPIResponse() api.AgentNetworkUsageBucket {
|
||||
return api.AgentNetworkUsageBucket{
|
||||
PeriodStart: b.PeriodStart,
|
||||
InputTokens: b.InputTokens,
|
||||
OutputTokens: b.OutputTokens,
|
||||
TotalTokens: b.TotalTokens,
|
||||
CostUsd: b.CostUSD,
|
||||
}
|
||||
}
|
||||
|
||||
// bucketStart truncates t (in UTC) to the start of its bucket for the given
|
||||
// granularity. Week buckets start on Monday (ISO week).
|
||||
func bucketStart(t time.Time, g UsageGranularity) time.Time {
|
||||
t = t.UTC()
|
||||
switch g {
|
||||
case UsageGranularityWeek:
|
||||
// Monday-start week. time.Weekday: Sunday=0..Saturday=6.
|
||||
offset := (int(t.Weekday()) + 6) % 7
|
||||
day := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||
return day.AddDate(0, 0, -offset)
|
||||
case UsageGranularityMonth:
|
||||
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
default: // day
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
}
|
||||
|
||||
// AggregateUsageByGranularity buckets the usage rows by the requested
|
||||
// granularity and returns the buckets ordered oldest-first. Aggregation is done
|
||||
// in Go (rather than per-engine SQL date_trunc) so granularities stay portable
|
||||
// across SQLite/Postgres/MySQL and easy to extend.
|
||||
func AggregateUsageByGranularity(rows []*AgentNetworkUsage, g UsageGranularity) []*AgentNetworkUsageBucket {
|
||||
byPeriod := make(map[string]*AgentNetworkUsageBucket)
|
||||
for _, r := range rows {
|
||||
key := bucketStart(r.Timestamp, g).Format("2006-01-02")
|
||||
b := byPeriod[key]
|
||||
if b == nil {
|
||||
b = &AgentNetworkUsageBucket{PeriodStart: key}
|
||||
byPeriod[key] = b
|
||||
}
|
||||
b.InputTokens += r.InputTokens
|
||||
b.OutputTokens += r.OutputTokens
|
||||
b.TotalTokens += r.TotalTokens
|
||||
b.CostUSD += r.CostUSD
|
||||
}
|
||||
|
||||
out := make([]*AgentNetworkUsageBucket, 0, len(byPeriod))
|
||||
for _, b := range byPeriod {
|
||||
out = append(out, b)
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool { return out[i].PeriodStart < out[j].PeriodStart })
|
||||
return out
|
||||
}
|
||||
109
management/internals/modules/agentnetwork/wire_shape_test.go
Normal file
109
management/internals/modules/agentnetwork/wire_shape_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// TestSynthesizedService_WireShape locks down the proto shape that
|
||||
// flows from the synthesizer through ToProtoMapping to the proxy.
|
||||
// Drift between this test and what the proxy expects manifests as
|
||||
// "service not matching" — the proxy receives a mapping but can't
|
||||
// register an SNI/HTTP route from it.
|
||||
func TestSynthesizedService_WireShape(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
provider := newSynthTestProvider()
|
||||
policy := newSynthTestPolicy(provider.ID, "grp-eng", "")
|
||||
|
||||
expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(),
|
||||
[]*types.Provider{provider},
|
||||
[]*types.Policy{policy},
|
||||
[]*types.Guardrail{})
|
||||
|
||||
services, err := SynthesizeServices(ctx, mockStore, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
svc := services[0]
|
||||
mapping := svc.ToProtoMapping(rpservice.Create, "test-token", proxy.OIDCValidationConfig{})
|
||||
|
||||
// Identifiers — account-scoped service ID, settings-derived domain.
|
||||
assert.Equal(t, "agent-net-svc-acct-1", mapping.GetId(), "stable account-scoped virtual service ID")
|
||||
assert.Equal(t, testAccountID, mapping.GetAccountId(), "account id round-trips")
|
||||
assert.Equal(t, testEndpoint, mapping.GetDomain(), "domain matches settings.Endpoint() output")
|
||||
|
||||
// Mode + listen port — addMapping at proxy/server.go switches on Mode.
|
||||
assert.Equal(t, "http", mapping.GetMode(), "synthesised services are HTTP mode")
|
||||
assert.Equal(t, int32(0), mapping.GetListenPort(), "no custom listen port for HTTP services")
|
||||
|
||||
// Auth token + private/tunnel shape: agent-network endpoints authenticate
|
||||
// inbound agents via ValidateTunnelPeer against AccessGroups, not OIDC.
|
||||
assert.Equal(t, "test-token", mapping.GetAuthToken(), "auth token round-trips for proxy CreateProxyPeer")
|
||||
assert.True(t, mapping.GetPrivate(), "synthesised services are private (tunnel-peer auth via AccessGroups)")
|
||||
require.NotNil(t, mapping.GetAuth(), "auth payload carries the session key")
|
||||
assert.False(t, mapping.GetAuth().GetOidc(), "OIDC is off for tunnel-auth agent-network services")
|
||||
|
||||
// Path mappings — proxy/server.go::setupHTTPMapping early-returns when
|
||||
// len(mapping.GetPath()) == 0, so this is a critical assertion.
|
||||
require.Len(t, mapping.GetPath(), 1, "exactly one path mapping for the cluster target")
|
||||
pm := mapping.GetPath()[0]
|
||||
assert.Equal(t, "/", pm.GetPath(), "default path is '/'")
|
||||
assert.Equal(t, "https://noop.invalid/", pm.GetTarget(),
|
||||
"target URL is the placeholder; the router middleware rewrites it per request")
|
||||
require.NotNil(t, pm.GetOptions(), "target options must be populated so direct_upstream + middleware chain reach the proxy")
|
||||
assert.True(t, pm.GetOptions().GetDirectUpstream(), "synth targets imply direct_upstream so the proxy dials via the host stack")
|
||||
assert.True(t, pm.GetOptions().GetAgentNetwork(), "agent_network flag must travel on the wire so the proxy can tag access logs")
|
||||
|
||||
mws := pm.GetOptions().GetMiddlewares()
|
||||
require.Len(t, mws, 8, "eight middlewares reach the proxy: request_parser, router, limit_check, identity_inject, guardrail, limit_record, cost_meter, response_parser")
|
||||
|
||||
assert.Equal(t, middlewareIDLLMRequestParser, mws[0].GetId(), "first middleware id")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[0].GetSlot(), "request parser slot")
|
||||
|
||||
assert.Equal(t, middlewareIDLLMRouter, mws[1].GetId(), "second middleware id")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[1].GetSlot(), "router slot")
|
||||
require.NotEmpty(t, mws[1].GetConfigJson(), "router config must travel on the wire")
|
||||
var routerCfg routerConfig
|
||||
require.NoError(t, json.Unmarshal(mws[1].GetConfigJson(), &routerCfg), "router config decodes")
|
||||
require.Len(t, routerCfg.Providers, 1, "the only enabled provider reaches the router")
|
||||
assert.Equal(t, provider.ID, routerCfg.Providers[0].ID, "router provider id matches synth provider")
|
||||
assert.Equal(t, "Bearer sk-test-key", routerCfg.Providers[0].AuthHeaderValue,
|
||||
"openai catalog template substitutes the API key on the wire")
|
||||
|
||||
assert.Equal(t, middlewareIDLLMLimitCheck, mws[2].GetId(),
|
||||
"limit_check runs after the router so the resolved provider id is available, before identity_inject so a deny doesn't pay the header-stamp cost")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[2].GetSlot())
|
||||
|
||||
assert.Equal(t, middlewareIDLLMIdentityInject, mws[3].GetId(), "fourth middleware id")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[3].GetSlot(), "identity inject slot")
|
||||
require.NotEmpty(t, mws[3].GetConfigJson(), "identity inject config JSON must travel on the wire")
|
||||
|
||||
assert.Equal(t, middlewareIDLLMGuardrail, mws[4].GetId(), "fifth middleware id")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[4].GetSlot(), "guardrail slot")
|
||||
require.NotEmpty(t, mws[4].GetConfigJson(), "guardrail middleware config JSON must travel on the wire")
|
||||
|
||||
assert.Equal(t, middlewareIDLLMLimitRecord, mws[5].GetId(),
|
||||
"limit_record sits FIRST in the response section so it RUNS LAST at runtime — slot order on the response leg is reverse-of-slice")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE, mws[5].GetSlot())
|
||||
|
||||
assert.Equal(t, middlewareIDCostMeter, mws[6].GetId(), "seventh middleware id")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE, mws[6].GetSlot(), "cost meter slot")
|
||||
|
||||
assert.Equal(t, middlewareIDLLMResponseParser, mws[7].GetId(), "eighth middleware id")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE, mws[7].GetSlot(), "response parser slot")
|
||||
}
|
||||
@@ -220,12 +220,36 @@ func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, er
|
||||
func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
||||
existingPeerID, err := m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
||||
if err == nil && existingPeerID != "" {
|
||||
// Peer already exists
|
||||
// Same pubkey already registered — idempotent.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dedupe stale embedded peer records for the same (account, cluster).
|
||||
// The proxy generates a fresh WireGuard keypair on every startup
|
||||
// (proxy/internal/roundtrip/netbird.go), so without this sweep the
|
||||
// prior embedded peer would linger forever — holding its CGNAT IP
|
||||
// allocation, polluting other peers' rosters, and (most visibly)
|
||||
// leaving the synth DNS pointing at the dead address. The
|
||||
// (account, cluster) tuple identifies "the embedded peer for this
|
||||
// proxy instance at this cluster"; any record matching that tuple
|
||||
// with a different pubkey is by definition stale and must go.
|
||||
staleIDs, err := m.findStaleEmbeddedProxyPeers(ctx, accountID, cluster, peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan for stale embedded proxy peers: %w", err)
|
||||
}
|
||||
if len(staleIDs) > 0 {
|
||||
// userID="" + checkConnected=false: the deletion is initiated
|
||||
// by management itself on behalf of the freshly-registering
|
||||
// proxy, not by an end user; the stale peer may still be
|
||||
// marked Connected from its prior session, but its session is
|
||||
// dead by definition (its key no longer exists).
|
||||
if err := m.DeletePeers(ctx, accountID, staleIDs, "", false); err != nil {
|
||||
return fmt.Errorf("delete stale embedded proxy peers %v: %w", staleIDs, err)
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("proxy-%s", xid.New().String())
|
||||
peer := &peer.Peer{
|
||||
newPeer := &peer.Peer{
|
||||
Ephemeral: true,
|
||||
ProxyMeta: peer.ProxyMeta{
|
||||
Cluster: cluster,
|
||||
@@ -242,10 +266,36 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee
|
||||
},
|
||||
}
|
||||
|
||||
_, _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
|
||||
_, _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", newPeer, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create proxy peer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findStaleEmbeddedProxyPeers returns the peer IDs of embedded proxy peer
|
||||
// records in accountID that target the same cluster but carry a different
|
||||
// WireGuard pubkey than the freshly-registering one. Used by CreateProxyPeer
|
||||
// to garbage-collect stale records left behind when the proxy restarts with a
|
||||
// regenerated keypair.
|
||||
func (m *managerImpl) findStaleEmbeddedProxyPeers(ctx context.Context, accountID, cluster, newKey string) ([]string, error) {
|
||||
account, err := m.store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var stale []string
|
||||
for _, p := range account.Peers {
|
||||
if p == nil || !p.ProxyMeta.Embedded {
|
||||
continue
|
||||
}
|
||||
if p.ProxyMeta.Cluster != cluster {
|
||||
continue
|
||||
}
|
||||
if p.Key == newKey {
|
||||
continue
|
||||
}
|
||||
stale = append(stale, p.ID)
|
||||
}
|
||||
return stale, nil
|
||||
}
|
||||
|
||||
@@ -39,6 +39,10 @@ type AccessLogEntry struct {
|
||||
BytesDownload int64 `gorm:"index"`
|
||||
Protocol AccessLogProtocol `gorm:"index"`
|
||||
Metadata map[string]string `gorm:"serializer:json"`
|
||||
// AgentNetwork marks the entry as emitted by a synthesised agent-network
|
||||
// service. Sourced from proto.AccessLog.AgentNetwork the proxy stamps
|
||||
// before shipping. Indexed so the agent-network log surface filters cheaply.
|
||||
AgentNetwork bool `gorm:"index"`
|
||||
}
|
||||
|
||||
// FromProto creates an AccessLogEntry from a proto.AccessLog
|
||||
@@ -58,6 +62,7 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
|
||||
a.BytesDownload = serviceLog.GetBytesDownload()
|
||||
a.Protocol = AccessLogProtocol(serviceLog.GetProtocol())
|
||||
a.Metadata = maps.Clone(serviceLog.GetMetadata())
|
||||
a.AgentNetwork = serviceLog.GetAgentNetwork()
|
||||
|
||||
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
|
||||
if addr, err := netip.ParseAddr(sourceIP); err == nil {
|
||||
|
||||
@@ -2,12 +2,15 @@ package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
@@ -16,6 +19,28 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// Metadata keys the proxy stamps on agent-network access-log entries. These
|
||||
// mirror the constants in proxy/internal/middleware/keys.go and form the wire
|
||||
// contract between the proxy and management; management flattens them into
|
||||
// queryable columns. Keep in sync with the proxy side.
|
||||
const (
|
||||
metaKeyProvider = "llm.provider"
|
||||
metaKeyModel = "llm.model"
|
||||
metaKeyResolvedProviderID = "llm.resolved_provider_id"
|
||||
metaKeySelectedPolicyID = "llm.selected_policy_id"
|
||||
metaKeyPolicyDecision = "llm_policy.decision"
|
||||
metaKeyPolicyReason = "llm_policy.reason"
|
||||
metaKeyInputTokens = "llm.input_tokens" //nolint:gosec // metadata key name, not a credential
|
||||
metaKeyOutputTokens = "llm.output_tokens" //nolint:gosec // metadata key name, not a credential
|
||||
metaKeyTotalTokens = "llm.total_tokens" //nolint:gosec // metadata key name, not a credential
|
||||
metaKeyCostUSDTotal = "cost.usd_total"
|
||||
metaKeyStream = "llm.stream"
|
||||
metaKeySessionID = "llm.session_id"
|
||||
metaKeyAuthorisingGroups = "llm.authorising_groups"
|
||||
metaKeyRequestPrompt = "llm.request_prompt"
|
||||
metaKeyResponseCompletion = "llm.response_completion"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
@@ -31,8 +56,14 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, geo g
|
||||
}
|
||||
}
|
||||
|
||||
// SaveAccessLog saves an access log entry to the database after enriching it
|
||||
// SaveAccessLog saves an access log entry to the database after enriching it.
|
||||
// Agent-network entries are flattened into their own dedicated table (queryable
|
||||
// LLM columns + group child rows) instead of the shared reverse-proxy table.
|
||||
func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.AccessLogEntry) error {
|
||||
if logEntry.AgentNetwork {
|
||||
return m.saveAgentNetworkAccessLog(ctx, logEntry)
|
||||
}
|
||||
|
||||
if m.geo != nil && logEntry.GeoLocation.ConnectionIP != nil {
|
||||
location, err := m.geo.Lookup(logEntry.GeoLocation.ConnectionIP)
|
||||
if err != nil {
|
||||
@@ -61,6 +92,184 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveAgentNetworkAccessLog flattens the metadata-bearing access-log entry and
|
||||
// persists it in two parts:
|
||||
//
|
||||
// - The stripped usage record is written unconditionally — usage/cost is
|
||||
// collected on every request regardless of the account's log-collection
|
||||
// toggle (the proxy ships a usage-only entry when logging is disabled).
|
||||
// - The full access-log row (with request detail + prompt) is written only
|
||||
// when the account's EnableLogCollection setting is on. This setting read
|
||||
// is the authoritative gate; the proxy-side strip is defense in depth.
|
||||
func (m *managerImpl) saveAgentNetworkAccessLog(ctx context.Context, logEntry *accesslogs.AccessLogEntry) error {
|
||||
entry, groups := flattenAgentNetworkLog(logEntry)
|
||||
|
||||
usage, usageGroups := usageFromFlattenedLog(entry, groups)
|
||||
if err := m.store.CreateAgentNetworkUsage(ctx, usage, usageGroups); err != nil {
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"account_id": entry.AccountID,
|
||||
"model": entry.Model,
|
||||
}).Errorf("failed to save agent-network usage: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
settings, err := m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, entry.AccountID)
|
||||
if err != nil {
|
||||
// No settings row (or a transient read error) means we can't confirm
|
||||
// log collection is enabled — usage is already saved, so skip the full
|
||||
// row rather than fail the whole ingest.
|
||||
log.WithContext(ctx).Debugf("skipping full agent-network access-log row for account %s: %v", entry.AccountID, err)
|
||||
return nil
|
||||
}
|
||||
if !settings.EnableLogCollection {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.store.CreateAgentNetworkAccessLog(ctx, entry, groups); err != nil {
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"account_id": entry.AccountID,
|
||||
"service_id": entry.ServiceID,
|
||||
"model": entry.Model,
|
||||
"status": entry.StatusCode,
|
||||
}).Errorf("failed to save agent-network access log: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// flattenAgentNetworkLog converts a reverse-proxy AccessLogEntry (whose LLM
|
||||
// dimensions live in the opaque Metadata map) into the flattened
|
||||
// agent-network row + authorising-group child rows.
|
||||
func flattenAgentNetworkLog(e *accesslogs.AccessLogEntry) (*agentNetworkTypes.AgentNetworkAccessLog, []agentNetworkTypes.AgentNetworkAccessLogGroup) {
|
||||
meta := e.Metadata
|
||||
|
||||
var sourceIP string
|
||||
if e.GeoLocation.ConnectionIP != nil {
|
||||
sourceIP = e.GeoLocation.ConnectionIP.String()
|
||||
}
|
||||
|
||||
entry := &agentNetworkTypes.AgentNetworkAccessLog{
|
||||
ID: e.ID,
|
||||
AccountID: e.AccountID,
|
||||
ServiceID: e.ServiceID,
|
||||
Timestamp: e.Timestamp,
|
||||
UserID: e.UserId,
|
||||
SourceIP: sourceIP,
|
||||
Method: e.Method,
|
||||
Host: e.Host,
|
||||
Path: e.Path,
|
||||
Duration: e.Duration,
|
||||
StatusCode: e.StatusCode,
|
||||
AuthMethod: e.AuthMethodUsed,
|
||||
BytesUpload: e.BytesUpload,
|
||||
BytesDownload: e.BytesDownload,
|
||||
|
||||
Provider: meta[metaKeyProvider],
|
||||
Model: meta[metaKeyModel],
|
||||
SessionID: meta[metaKeySessionID],
|
||||
ResolvedProviderID: meta[metaKeyResolvedProviderID],
|
||||
SelectedPolicyID: meta[metaKeySelectedPolicyID],
|
||||
Decision: meta[metaKeyPolicyDecision],
|
||||
DenyReason: meta[metaKeyPolicyReason],
|
||||
InputTokens: parseMetaInt(meta, metaKeyInputTokens),
|
||||
OutputTokens: parseMetaInt(meta, metaKeyOutputTokens),
|
||||
TotalTokens: parseMetaInt(meta, metaKeyTotalTokens),
|
||||
CostUSD: parseMetaFloat(meta, metaKeyCostUSDTotal),
|
||||
Stream: parseMetaBool(meta, metaKeyStream),
|
||||
RequestPrompt: meta[metaKeyRequestPrompt],
|
||||
ResponseCompletion: meta[metaKeyResponseCompletion],
|
||||
}
|
||||
|
||||
var groups []agentNetworkTypes.AgentNetworkAccessLogGroup
|
||||
for _, gid := range parseGroupCSV(meta[metaKeyAuthorisingGroups]) {
|
||||
groups = append(groups, agentNetworkTypes.AgentNetworkAccessLogGroup{
|
||||
LogID: entry.ID,
|
||||
GroupID: gid,
|
||||
AccountID: entry.AccountID,
|
||||
})
|
||||
}
|
||||
return entry, groups
|
||||
}
|
||||
|
||||
// usageFromFlattenedLog derives the stripped usage record (and its group child
|
||||
// rows) from an already-flattened access-log entry. The usage row shares the
|
||||
// log's ID so the two correlate.
|
||||
func usageFromFlattenedLog(e *agentNetworkTypes.AgentNetworkAccessLog, groups []agentNetworkTypes.AgentNetworkAccessLogGroup) (*agentNetworkTypes.AgentNetworkUsage, []agentNetworkTypes.AgentNetworkUsageGroup) {
|
||||
usage := &agentNetworkTypes.AgentNetworkUsage{
|
||||
ID: e.ID,
|
||||
AccountID: e.AccountID,
|
||||
Timestamp: e.Timestamp,
|
||||
UserID: e.UserID,
|
||||
ResolvedProviderID: e.ResolvedProviderID,
|
||||
Provider: e.Provider,
|
||||
Model: e.Model,
|
||||
SessionID: e.SessionID,
|
||||
InputTokens: e.InputTokens,
|
||||
OutputTokens: e.OutputTokens,
|
||||
TotalTokens: e.TotalTokens,
|
||||
CostUSD: e.CostUSD,
|
||||
}
|
||||
|
||||
usageGroups := make([]agentNetworkTypes.AgentNetworkUsageGroup, 0, len(groups))
|
||||
for _, g := range groups {
|
||||
usageGroups = append(usageGroups, agentNetworkTypes.AgentNetworkUsageGroup{
|
||||
UsageID: usage.ID,
|
||||
GroupID: g.GroupID,
|
||||
AccountID: g.AccountID,
|
||||
})
|
||||
}
|
||||
return usage, usageGroups
|
||||
}
|
||||
|
||||
// parseMetaInt parses a non-negative token count. Negative or unparseable
|
||||
// values are clamped to 0 so a malformed metric can't persist a negative
|
||||
// counter.
|
||||
func parseMetaInt(meta map[string]string, key string) int64 {
|
||||
if v, err := strconv.ParseInt(strings.TrimSpace(meta[key]), 10, 64); err == nil && v >= 0 {
|
||||
return v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// parseMetaFloat parses a non-negative, finite cost. Negative, NaN, Inf, or
|
||||
// unparseable values are clamped to 0 so a malformed metric can't poison the
|
||||
// stored cost.
|
||||
func parseMetaFloat(meta map[string]string, key string) float64 {
|
||||
if v, err := strconv.ParseFloat(strings.TrimSpace(meta[key]), 64); err == nil && v >= 0 && !math.IsInf(v, 0) {
|
||||
return v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func parseMetaBool(meta map[string]string, key string) bool {
|
||||
v, _ := strconv.ParseBool(strings.TrimSpace(meta[key]))
|
||||
return v
|
||||
}
|
||||
|
||||
// parseGroupCSV splits the comma-separated authorising-group id list the proxy
|
||||
// emits, trimming blanks and de-duplicating. Dedup matters because the group
|
||||
// rows are keyed by (log_id, group_id) / (usage_id, group_id): a repeated id
|
||||
// in the CSV would otherwise produce a duplicate primary key and fail the
|
||||
// insert transaction.
|
||||
func parseGroupCSV(raw string) []string {
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(raw, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
seen := make(map[string]struct{}, len(parts))
|
||||
for _, p := range parts {
|
||||
if p = strings.TrimSpace(p); p != "" {
|
||||
if _, dup := seen[p]; dup {
|
||||
continue
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
|
||||
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
|
||||
@@ -66,6 +66,51 @@ type TargetOptions struct {
|
||||
// reachable without WireGuard (public APIs, LAN services, localhost
|
||||
// sidecars). Default false.
|
||||
DirectUpstream bool `json:"direct_upstream,omitempty"`
|
||||
// Middlewares carries per-target agent-network middleware configs. Empty
|
||||
// for private and operator-defined services; populated only by the
|
||||
// agent-network synthesizer.
|
||||
Middlewares []MiddlewareConfig `gorm:"serializer:json" json:"middlewares,omitempty"`
|
||||
CaptureMaxRequestBytes int64 `json:"capture_max_request_bytes,omitempty"`
|
||||
CaptureMaxResponseBytes int64 `json:"capture_max_response_bytes,omitempty"`
|
||||
CaptureContentTypes []string `gorm:"serializer:json" json:"capture_content_types,omitempty"`
|
||||
// AgentNetwork marks targets synthesised from Agent Network state. The
|
||||
// proxy uses it to gate agent-network-specific behaviour (access log
|
||||
// tagging, observability, etc.).
|
||||
AgentNetwork bool `json:"agent_network,omitempty"`
|
||||
// DisableAccessLog suppresses the per-request access-log emission for this
|
||||
// target. Defaults false to preserve access-log behaviour for every
|
||||
// non-agent-network target. The agent-network synthesizer sets this true
|
||||
// only when the account's EnableLogCollection toggle is off.
|
||||
DisableAccessLog bool `json:"disable_access_log,omitempty"`
|
||||
}
|
||||
|
||||
// MiddlewareSlot mirrors proto.MiddlewareSlot / middleware.Slot.
|
||||
type MiddlewareSlot string
|
||||
|
||||
const (
|
||||
MiddlewareSlotOnRequest MiddlewareSlot = "on_request"
|
||||
MiddlewareSlotOnResponse MiddlewareSlot = "on_response"
|
||||
MiddlewareSlotTerminal MiddlewareSlot = "terminal"
|
||||
)
|
||||
|
||||
// MiddlewareFailMode mirrors proto.MiddlewareConfig_FailMode.
|
||||
type MiddlewareFailMode string
|
||||
|
||||
const (
|
||||
MiddlewareFailOpen MiddlewareFailMode = "fail_open"
|
||||
MiddlewareFailClosed MiddlewareFailMode = "fail_closed"
|
||||
)
|
||||
|
||||
// MiddlewareConfig is the per-target configuration for a single
|
||||
// middleware instance. Mirrors proto.MiddlewareConfig.
|
||||
type MiddlewareConfig struct {
|
||||
ID string `json:"id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Slot MiddlewareSlot `json:"slot"`
|
||||
ConfigJSON []byte `json:"config_json,omitempty"`
|
||||
FailMode MiddlewareFailMode `json:"fail_mode,omitempty"`
|
||||
TimeoutMs int32 `json:"timeout_ms,omitempty"`
|
||||
CanMutate bool `json:"can_mutate"`
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
@@ -504,21 +549,75 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
|
||||
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 &&
|
||||
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream &&
|
||||
len(opts.Middlewares) == 0 && opts.CaptureMaxRequestBytes == 0 &&
|
||||
opts.CaptureMaxResponseBytes == 0 && len(opts.CaptureContentTypes) == 0 &&
|
||||
!opts.AgentNetwork && !opts.DisableAccessLog {
|
||||
return nil
|
||||
}
|
||||
popts := &proto.PathTargetOptions{
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
DirectUpstream: opts.DirectUpstream,
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
DirectUpstream: opts.DirectUpstream,
|
||||
AgentNetwork: opts.AgentNetwork,
|
||||
DisableAccessLog: opts.DisableAccessLog,
|
||||
}
|
||||
if opts.RequestTimeout != 0 {
|
||||
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
||||
}
|
||||
if len(opts.Middlewares) > 0 {
|
||||
popts.Middlewares = middlewaresToProto(opts.Middlewares)
|
||||
}
|
||||
popts.CaptureMaxRequestBytes = opts.CaptureMaxRequestBytes
|
||||
popts.CaptureMaxResponseBytes = opts.CaptureMaxResponseBytes
|
||||
if len(opts.CaptureContentTypes) > 0 {
|
||||
popts.CaptureContentTypes = append([]string(nil), opts.CaptureContentTypes...)
|
||||
}
|
||||
return popts
|
||||
}
|
||||
|
||||
// middlewaresToProto converts the internal middleware slice to the proto
|
||||
// representation sent to the proxy via the mapping stream.
|
||||
func middlewaresToProto(in []MiddlewareConfig) []*proto.MiddlewareConfig {
|
||||
out := make([]*proto.MiddlewareConfig, 0, len(in))
|
||||
for _, m := range in {
|
||||
pm := &proto.MiddlewareConfig{
|
||||
Id: m.ID,
|
||||
Enabled: m.Enabled,
|
||||
Slot: middlewareSlotToProto(m.Slot),
|
||||
ConfigJson: append([]byte(nil), m.ConfigJSON...),
|
||||
CanMutate: m.CanMutate,
|
||||
FailMode: middlewareFailModeToProto(m.FailMode),
|
||||
}
|
||||
if m.TimeoutMs > 0 {
|
||||
pm.Timeout = durationpb.New(time.Duration(m.TimeoutMs) * time.Millisecond)
|
||||
}
|
||||
out = append(out, pm)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func middlewareSlotToProto(s MiddlewareSlot) proto.MiddlewareSlot {
|
||||
switch s {
|
||||
case MiddlewareSlotOnRequest:
|
||||
return proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST
|
||||
case MiddlewareSlotOnResponse:
|
||||
return proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE
|
||||
case MiddlewareSlotTerminal:
|
||||
return proto.MiddlewareSlot_MIDDLEWARE_SLOT_TERMINAL
|
||||
default:
|
||||
return proto.MiddlewareSlot_MIDDLEWARE_SLOT_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func middlewareFailModeToProto(m MiddlewareFailMode) proto.MiddlewareConfig_FailMode {
|
||||
if m == MiddlewareFailClosed {
|
||||
return proto.MiddlewareConfig_FAIL_CLOSED
|
||||
}
|
||||
return proto.MiddlewareConfig_FAIL_OPEN
|
||||
}
|
||||
|
||||
// l4TargetOptionsToProto converts L4-relevant target options to proto.
|
||||
func l4TargetOptionsToProto(target *Target) *proto.PathTargetOptions {
|
||||
if !target.ProxyProtocol && target.Options.RequestTimeout == 0 && target.Options.SessionIdleTimeout == 0 {
|
||||
|
||||
@@ -26,9 +26,11 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
@@ -120,7 +122,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount, s.AgentNetworkManager())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -223,11 +225,35 @@ func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetServiceManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
proxyService.SetAgentNetworkSynthesizer(newAgentNetworkSynthesizer(s.Store()))
|
||||
proxyService.SetAgentNetworkLimitsService(s.AgentNetworkManager())
|
||||
})
|
||||
return proxyService
|
||||
})
|
||||
}
|
||||
|
||||
// agentNetworkSynthesizerAdapter implements nbgrpc.AgentNetworkSynthesizer by
|
||||
// delegating to the agentnetwork package's store-backed synthesiser.
|
||||
type agentNetworkSynthesizerAdapter struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
func newAgentNetworkSynthesizer(s store.Store) *agentNetworkSynthesizerAdapter {
|
||||
return &agentNetworkSynthesizerAdapter{store: s}
|
||||
}
|
||||
|
||||
func (a *agentNetworkSynthesizerAdapter) SynthesizeServicesForCluster(ctx context.Context, clusterAddr string) ([]*rpservice.Service, error) {
|
||||
return agentnetwork.SynthesizeServicesForCluster(ctx, a.store, clusterAddr)
|
||||
}
|
||||
|
||||
func (a *agentNetworkSynthesizerAdapter) SynthesizeServicesForAccount(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
|
||||
return agentnetwork.SynthesizeServices(ctx, a.store, accountID)
|
||||
}
|
||||
|
||||
func (a *agentNetworkSynthesizerAdapter) SynthesizeServiceForDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||
return agentnetwork.SynthesizeServiceForDomain(ctx, a.store, domain)
|
||||
}
|
||||
|
||||
func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
|
||||
return Create(s, func() nbgrpc.ProxyOIDCConfig {
|
||||
return nbgrpc.ProxyOIDCConfig{
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
@@ -194,6 +195,24 @@ func (s *BaseServer) NetworksManager() networks.Manager {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AgentNetworkManager() agentnetwork.Manager {
|
||||
return Create(s, func() agentnetwork.Manager {
|
||||
mgr := agentnetwork.NewManager(
|
||||
s.Store(),
|
||||
s.PermissionsManager(),
|
||||
s.AccountManager(),
|
||||
s.ServiceProxyController(),
|
||||
)
|
||||
// Sweep expired agent-network access logs per account retention,
|
||||
// reusing the reverse-proxy cleanup interval config.
|
||||
mgr.StartAccessLogCleanup(
|
||||
context.Background(),
|
||||
s.Config.ReverseProxy.AccessLogCleanupIntervalHours,
|
||||
)
|
||||
return mgr
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ZonesManager() zones.Manager {
|
||||
return Create(s, func() zones.Manager {
|
||||
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
|
||||
|
||||
@@ -13,7 +13,7 @@ const (
|
||||
reconnThreshold = 5 * time.Minute
|
||||
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
|
||||
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||
metaChangeLimit = 5 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||
)
|
||||
|
||||
type lfConfig struct {
|
||||
@@ -139,7 +139,7 @@ func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
|
||||
state.lastSeen = now
|
||||
}
|
||||
|
||||
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
|
||||
func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
|
||||
h := fnv.New64a()
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
@@ -147,14 +147,6 @@ func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
macs := uint64(0)
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
for _, r := range na.Mac {
|
||||
macs += uint64(r)
|
||||
}
|
||||
}
|
||||
|
||||
return h.Sum64() + macs
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
@@ -164,9 +164,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
KernelVersion: "5.15.0-76-generic",
|
||||
Hostname: "prod-server-database-01",
|
||||
SystemSerialNumber: "PC-1234567890",
|
||||
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
|
||||
}
|
||||
pubip := "8.8.8.8"
|
||||
|
||||
var resultString string
|
||||
var resultUint uint64
|
||||
@@ -175,7 +173,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = builderString(meta, pubip)
|
||||
resultString = builderString(meta)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -183,7 +181,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = fnvHashToString(meta, pubip)
|
||||
resultString = fnvHashToString(meta)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -191,7 +189,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultUint = metaHash(meta, pubip)
|
||||
resultUint = metaHash(meta)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -199,29 +197,20 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
_ = resultUint
|
||||
}
|
||||
|
||||
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
func fnvHashToString(meta nbpeer.PeerSystemMeta) string {
|
||||
h := fnv.New64a()
|
||||
|
||||
if len(meta.NetworkAddresses) != 0 {
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
h.Write([]byte(na.Mac))
|
||||
}
|
||||
}
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
return strconv.FormatUint(h.Sum64(), 16)
|
||||
}
|
||||
|
||||
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
mac := getMacAddress(meta.NetworkAddresses)
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
|
||||
len(pubip) + len(mac) + 6
|
||||
func builderString(meta nbpeer.PeerSystemMeta) string {
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + 4
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(estimatedSize)
|
||||
@@ -235,23 +224,10 @@ func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
b.WriteString(meta.Hostname)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.SystemSerialNumber)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(pubip)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func getMacAddress(nas []nbpeer.NetworkAddress) string {
|
||||
if len(nas) == 0 {
|
||||
return ""
|
||||
}
|
||||
macs := make([]string, 0, len(nas))
|
||||
for _, na := range nas {
|
||||
macs = append(macs, na.Mac)
|
||||
}
|
||||
return strings.Join(macs, "/")
|
||||
}
|
||||
|
||||
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
|
||||
filter := newLoginFilterWithCfg(testAdvancedCfg())
|
||||
numKeys := 100000
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -35,6 +36,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
proxyauth "github.com/netbirdio/netbird/proxy/auth"
|
||||
@@ -60,6 +62,23 @@ type ProxyTokenChecker interface {
|
||||
}
|
||||
|
||||
// ProxyServiceServer implements the ProxyService gRPC server
|
||||
// AgentNetworkSynthesizer produces in-memory reverse-proxy services from
|
||||
// Agent Network provider/policy state for the proxy snapshot path; synthesised
|
||||
// services never appear in the reverseproxy_services table.
|
||||
type AgentNetworkSynthesizer interface {
|
||||
SynthesizeServicesForCluster(ctx context.Context, clusterAddr string) ([]*rpservice.Service, error)
|
||||
SynthesizeServicesForAccount(ctx context.Context, accountID string) ([]*rpservice.Service, error)
|
||||
SynthesizeServiceForDomain(ctx context.Context, domain string) (*rpservice.Service, error)
|
||||
}
|
||||
|
||||
// AgentNetworkLimitsService is the minimal slice of agentnetwork.Manager the
|
||||
// gRPC layer needs for CheckLLMPolicyLimits + RecordLLMUsage — kept narrow so
|
||||
// the grpc package doesn't take a hard import on the full manager.
|
||||
type AgentNetworkLimitsService interface {
|
||||
SelectPolicyForRequest(ctx context.Context, in agentnetwork.PolicySelectionInput) (*agentnetwork.PolicySelectionResult, error)
|
||||
RecordUsage(ctx context.Context, in agentnetwork.RecordUsageInput) error
|
||||
}
|
||||
|
||||
type ProxyServiceServer struct {
|
||||
proto.UnimplementedProxyServiceServer
|
||||
|
||||
@@ -72,6 +91,14 @@ type ProxyServiceServer struct {
|
||||
mu sync.RWMutex
|
||||
// Manager for reverse proxy operations
|
||||
serviceManager rpservice.Manager
|
||||
// agentNetworkSynth produces synthesised reverse-proxy services from
|
||||
// Agent Network state. Optional — when nil the snapshot path only ships
|
||||
// persisted services.
|
||||
agentNetworkSynth AgentNetworkSynthesizer
|
||||
// agentNetworkLimits handles the pre-flight selection (CheckLLMPolicyLimits)
|
||||
// and the post-flight consumption write (RecordLLMUsage). Optional — when
|
||||
// nil both RPCs return Unimplemented.
|
||||
agentNetworkLimits AgentNetworkLimitsService
|
||||
// ProxyController for service updates and cluster management
|
||||
proxyController proxy.Controller
|
||||
|
||||
@@ -209,6 +236,127 @@ func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
|
||||
s.serviceManager = manager
|
||||
}
|
||||
|
||||
// SetAgentNetworkSynthesizer wires the agent-network service synthesiser.
|
||||
// Optional — when nil the snapshot path skips agent-network synthesis. The
|
||||
// modules layer injects this after both the proxy server and the agent-network
|
||||
// manager are constructed.
|
||||
func (s *ProxyServiceServer) SetAgentNetworkSynthesizer(synth AgentNetworkSynthesizer) {
|
||||
s.mu.Lock()
|
||||
s.agentNetworkSynth = synth
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetAgentNetworkLimitsService wires the policy-selection + post-flight
|
||||
// consumption sink. Pass nil to disable; both RPCs return Unimplemented while
|
||||
// unset so partial wiring surfaces during integration.
|
||||
func (s *ProxyServiceServer) SetAgentNetworkLimitsService(svc AgentNetworkLimitsService) {
|
||||
s.mu.Lock()
|
||||
s.agentNetworkLimits = svc
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// agentNetworkSynthesizer returns the synthesiser under read lock.
|
||||
func (s *ProxyServiceServer) agentNetworkSynthesizer() AgentNetworkSynthesizer {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.agentNetworkSynth
|
||||
}
|
||||
|
||||
// CheckLLMPolicyLimits is the pre-flight policy gate the proxy calls before
|
||||
// forwarding an LLM request upstream. Delegates to the agent-network selector,
|
||||
// which scores applicable policies by remaining headroom and returns the
|
||||
// policy that pays for this request (or a deny when all are exhausted).
|
||||
func (s *ProxyServiceServer) CheckLLMPolicyLimits(ctx context.Context, req *proto.CheckLLMPolicyLimitsRequest) (*proto.CheckLLMPolicyLimitsResponse, error) {
|
||||
s.mu.RLock()
|
||||
svc := s.agentNetworkLimits
|
||||
s.mu.RUnlock()
|
||||
if svc == nil {
|
||||
return nil, status.Errorf(codes.Unimplemented, "agent-network limits service not configured on management")
|
||||
}
|
||||
if req.GetAccountId() == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "account_id is required")
|
||||
}
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := svc.SelectPolicyForRequest(ctx, agentnetwork.PolicySelectionInput{
|
||||
AccountID: req.GetAccountId(),
|
||||
UserID: req.GetUserId(),
|
||||
GroupIDs: req.GetGroupIds(),
|
||||
ProviderID: req.GetProviderId(),
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("select policy for request: %v", err)
|
||||
return nil, status.Error(codes.Internal, "select policy failed")
|
||||
}
|
||||
|
||||
if !res.Allow {
|
||||
return &proto.CheckLLMPolicyLimitsResponse{
|
||||
Decision: "deny",
|
||||
SelectedPolicyId: res.SelectedPolicyID,
|
||||
AttributionGroupId: res.AttributionGroupID,
|
||||
WindowSeconds: res.WindowSeconds,
|
||||
DenyCode: res.DenyCode,
|
||||
DenyReason: res.DenyReason,
|
||||
}, nil
|
||||
}
|
||||
return &proto.CheckLLMPolicyLimitsResponse{
|
||||
Decision: "allow",
|
||||
SelectedPolicyId: res.SelectedPolicyID,
|
||||
AttributionGroupId: res.AttributionGroupID,
|
||||
WindowSeconds: res.WindowSeconds,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RecordLLMUsage increments the per-(dimension, window) consumption counter for
|
||||
// the user and optional attribution group after a served request. Returns
|
||||
// Unimplemented when the agent-network limits service hasn't been wired.
|
||||
func (s *ProxyServiceServer) RecordLLMUsage(ctx context.Context, req *proto.RecordLLMUsageRequest) (*proto.RecordLLMUsageResponse, error) {
|
||||
s.mu.RLock()
|
||||
svc := s.agentNetworkLimits
|
||||
s.mu.RUnlock()
|
||||
if svc == nil {
|
||||
return nil, status.Errorf(codes.Unimplemented, "agent-network limits service not configured on management")
|
||||
}
|
||||
|
||||
accountID := req.GetAccountId()
|
||||
if accountID == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "account_id is required")
|
||||
}
|
||||
if err := enforceAccountScope(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokensIn := req.GetTokensInput()
|
||||
tokensOut := req.GetTokensOutput()
|
||||
costUSD := req.GetCostUsd()
|
||||
|
||||
// Reject impossible counters at the boundary instead of recording them:
|
||||
// a negative window, negative tokens, or a negative / non-finite cost
|
||||
// would otherwise decrement or poison the persisted consumption totals.
|
||||
if req.GetWindowSeconds() < 0 || tokensIn < 0 || tokensOut < 0 || costUSD < 0 || math.IsNaN(costUSD) || math.IsInf(costUSD, 0) {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "usage counters must be non-negative and finite")
|
||||
}
|
||||
|
||||
// Book the policy-window dimensions (when a policy cap bound this request)
|
||||
// and every applicable account budget rule's window in a single batched
|
||||
// transaction.
|
||||
if err := svc.RecordUsage(ctx, agentnetwork.RecordUsageInput{
|
||||
AccountID: accountID,
|
||||
UserID: req.GetUserId(),
|
||||
AttributionGroupID: req.GetGroupId(),
|
||||
GroupIDs: req.GetGroupIds(),
|
||||
WindowSeconds: req.GetWindowSeconds(),
|
||||
TokensIn: tokensIn,
|
||||
TokensOut: tokensOut,
|
||||
CostUSD: costUSD,
|
||||
}); err != nil {
|
||||
log.WithContext(ctx).Errorf("record usage: %v", err)
|
||||
return nil, status.Error(codes.Internal, "record usage failed")
|
||||
}
|
||||
return &proto.RecordLLMUsageResponse{}, nil
|
||||
}
|
||||
|
||||
// SetProxyController sets the proxy controller. Must be called before serving.
|
||||
func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) {
|
||||
s.mu.Lock()
|
||||
@@ -623,12 +771,40 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
return nil, fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
|
||||
if synth := s.agentNetworkSynthesizer(); synth != nil {
|
||||
var synthesised []*rpservice.Service
|
||||
var serr error
|
||||
// Account-scoped connections synthesise only their own account, so the
|
||||
// snapshot can never carry another tenant's mappings (which embed the
|
||||
// upstream auth header derived from that tenant's provider API key).
|
||||
// Global connections still see the whole cluster.
|
||||
if conn.accountID != nil {
|
||||
synthesised, serr = synth.SynthesizeServicesForAccount(ctx, *conn.accountID)
|
||||
} else {
|
||||
synthesised, serr = synth.SynthesizeServicesForCluster(ctx, conn.address)
|
||||
}
|
||||
if serr != nil {
|
||||
// Surface a real synthesis failure instead of silently shipping an
|
||||
// incomplete snapshot (which would drop the account's agent-network
|
||||
// routes). Consistent with the persisted-services error above; the
|
||||
// proxy retries the snapshot on connection error.
|
||||
return nil, fmt.Errorf("synthesise agent-network services: %w", serr)
|
||||
}
|
||||
services = append(services, synthesised...)
|
||||
}
|
||||
|
||||
oidcCfg := s.GetOIDCValidationConfig()
|
||||
var mappings []*proto.ProxyMapping
|
||||
for _, service := range services {
|
||||
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
||||
continue
|
||||
}
|
||||
// Defense in depth: an account-scoped proxy must never receive another
|
||||
// account's mapping, matching the per-account filtering the incremental
|
||||
// update path already applies.
|
||||
if conn.accountID != nil && service.AccountID != *conn.accountID {
|
||||
continue
|
||||
}
|
||||
|
||||
m := service.ToProtoMapping(rpservice.Create, "", oidcCfg)
|
||||
if !proxyAcceptsMapping(conn, m) {
|
||||
@@ -1617,7 +1793,29 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||
return s.serviceManager.GetServiceByDomain(ctx, domain)
|
||||
service, err := s.serviceManager.GetServiceByDomain(ctx, domain)
|
||||
if err == nil {
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// Fall back to the Agent Network synthesiser scoped directly to the domain's
|
||||
// account. Synthesised services are never persisted, so they must resolve
|
||||
// here for OIDC / session / tunnel-peer flows against agent-network
|
||||
// endpoints. Resolving by domain synthesises only the owning account rather
|
||||
// than every tenant on the cluster.
|
||||
if synth := s.agentNetworkSynthesizer(); synth != nil {
|
||||
svc, serr := synth.SynthesizeServiceForDomain(ctx, domain)
|
||||
if serr != nil {
|
||||
// A real synthesis failure must surface, not be masked by the
|
||||
// original store miss — otherwise a transient DB error looks like
|
||||
// "no such service".
|
||||
return nil, fmt.Errorf("synthesize agent-network service for %s: %w", domain, serr)
|
||||
}
|
||||
if svc != nil {
|
||||
return svc, nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
|
||||
|
||||
@@ -254,7 +254,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
metahashed := metaHash(peerMeta)
|
||||
if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
||||
@@ -306,7 +306,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
metahash := metaHash(peerMeta, realIP.String())
|
||||
metahash := metaHash(peerMeta)
|
||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
|
||||
@@ -732,7 +732,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
}
|
||||
|
||||
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
metahashed := metaHash(peerMeta)
|
||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.logBlockedPeers {
|
||||
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
|
||||
@@ -788,7 +788,11 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
ExtraDNSLabels: loginReq.GetDnsLabels(),
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
if errors.Is(err, internalStatus.ErrNoAuthMethodProvided) {
|
||||
log.WithContext(ctx).Tracef("failed logging in peer %s: %s", peerKey, err)
|
||||
} else {
|
||||
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
}
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
@@ -1205,7 +1209,7 @@ func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*pr
|
||||
return nil, msg
|
||||
}
|
||||
|
||||
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
|
||||
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()), realIP)
|
||||
if err != nil {
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
@@ -1254,7 +1258,10 @@ func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*prot
|
||||
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
|
||||
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
|
||||
for _, postureCheck := range postureChecks {
|
||||
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
|
||||
check := toProtocolCheck(postureCheck)
|
||||
if check != nil {
|
||||
protoChecks = append(protoChecks, check)
|
||||
}
|
||||
}
|
||||
|
||||
return protoChecks
|
||||
@@ -1278,5 +1285,9 @@ func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
|
||||
}
|
||||
}
|
||||
|
||||
if len(protoCheck.Files) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return protoCheck
|
||||
}
|
||||
|
||||
@@ -1889,12 +1889,12 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
||||
// concurrent stream that started earlier loses the optimistic-lock race
|
||||
// in MarkPeerConnected and bails without writing.
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, accountID, syncTime.UnixNano(), netMap); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
@@ -1914,13 +1914,13 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
|
||||
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP, UpdateAccountPeers: true}, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user