Compare commits

..

24 Commits

Author SHA1 Message Date
mlsmaycon
668af0dc4f [agent-network] Co-locate HTTP handlers in the module (RegisterEndpoints)
Move the agent-network HTTP handlers from server/http/handlers/agentnetwork into
the module at internals/modules/agentnetwork/handlers (package handlers) and
rename the entrypoint AddEndpoints -> RegisterEndpoints, matching the
reverse-proxy module convention. Wiring in http/handler.go updated accordingly.
2026-06-27 02:58:21 +02:00
mlsmaycon
5f130959ea [agent-network] Relocate agentnetwork package to internals/modules
Move management/server/agentnetwork (and its catalog/, labelgen/, types/
subpackages) to management/internals/modules/agentnetwork, alongside the
reverse-proxy module, and rewrite all importers. Pure relocation: package names,
the synthesizer + affectedpeers registration hook, and store access (shared
store.Store) are unchanged, so no import cycle is introduced (affectedpeers
still depends only on the agentnetwork/types leaf).
2026-06-27 02:55:27 +02:00
mlsmaycon
5644279888 [management] Refine session expiration handling to support 3-state encoding for SSO deadlines 2026-06-27 01:40:41 +02:00
mlsmaycon
f22ac6d271 [agent-network] Polish module docs: remove internal review scaffolding, fix links, verify diagrams
Strip PR-review framing, commit references, absolute paths, and stale internal
references from the agent-network module docs; fix broken relative links; verify
all diagrams against the current architecture. Remove the internal AI-reviewer
prompt file.
2026-06-27 01:39:57 +02:00
mlsmaycon
9f485be2f9 [agent-network] Remove e2e shell-script suite from this branch
The end-to-end shell scripts under scripts/e2e/ are maintained in a separate
testing suite and are not part of this change set.
2026-06-27 01:39:57 +02:00
mlsmaycon
c83e46fbe1 [management] Set LastSeen on injected proxy peer in realstack test (MySQL strict-mode)
The injected embedded proxy peer had a PeerStatus with a zero LastSeen, which
serializes to '0000-00-00' and is rejected by MySQL in strict mode (SQLite
tolerates it). Set LastSeen to a valid time so SaveAccount succeeds on both
engines.
2026-06-27 01:31:20 +02:00
mlsmaycon
405607c584 [agent-network] Fix codespell typos and exclude false positives
- labelgen word pool: vermillion -> vermilion, racoon -> raccoon.
- codespell ignore list: add flate (Go compress/flate package), recordin
  (a test-local identifier), and unparseable (a valid alternative spelling used
  consistently across identifiers + a metadata-value constant).
2026-06-27 01:15:09 +02:00
mlsmaycon
29f55d4255 [agent-network] End-to-end test suite, module docs, and deployment preset 2026-06-27 00:43:35 +02:00
mlsmaycon
3993fa32e4 [proxy] IPv6 in-place apply and TCP accept-loop hardening on netstack listeners 2026-06-27 00:43:35 +02:00
mlsmaycon
6ade3839aa [proxy] LLM parsers, pricing, and builtin middlewares (OpenAI, Anthropic, Vertex AI, AWS Bedrock)
Request/response parsers and SSE/event-stream metering, the embedded pricing
table, and the builtin middleware set: request parser, router, policy
limit-check/record, cost meter, guardrail, identity inject, response parser.
Includes the path-routed providers — Google Vertex AI (keyfile:: service-account
OAuth minting) and AWS Bedrock (bearer auth, invoke/converse/streaming, optional
/bedrock prefix) — plus the Models allowlist and unmeterable-publisher deny.
2026-06-27 00:43:35 +02:00
mlsmaycon
d4d158a8f3 [proxy] Reverse-proxy middleware framework, chain, and request plumbing
The per-target middleware chain (slots, dispatcher, mutation gate, metadata
merger), body capture, access-log terminal sink, and the proxy wiring that
builds + runs chains for synthesized agent-network services.
2026-06-27 00:43:35 +02:00
mlsmaycon
6613d194ef [management] Fix agent-network proxy-peer fan-out on affected-peer recompute
The affected-peers resolver loaded only persisted reverse-proxy services, but
agent-network services are synthesized on demand and never persisted. As a
result the embedded proxy peer was never folded into the affected set when a
client's group changed, so the proxy received no network-map update for a newly
authorised client and rejected its handshake until a full resync (restart).

loadProxyServices now merges the synthesized agent-network services (injected
via a registration hook to avoid an import cycle), so proxy peers learn newly
authorised clients immediately.
2026-06-27 00:43:07 +02:00
mlsmaycon
769e12840d [agent-network] Management: store, manager, synthesizer, policy engine, provider catalog, HTTP/gRPC API
Adds the account-scoped agent-network module: provider/policy/budget CRUD and
store, the reverse-proxy service synthesizer, policy selection + limit
enforcement, the provider catalog (incl. Vertex AI and AWS Bedrock entries),
and the management HTTP + proxy gRPC surfaces.
2026-06-27 00:43:07 +02:00
mlsmaycon
350a96c640 [agent-network] Shared proto, OpenAPI schema, and generated types 2026-06-27 00:43:07 +02:00
dmitri-netbird
615631567a small gh workflow fixes (#6546)
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-26 19:59:15 +02:00
Pascal Fischer
f4daf59bcd [management] bring back client version check on login filter hash (#6552) 2026-06-26 16:36:50 +02:00
Maycon Santos
ff2787e184 [management] Optimize affected posture checks and add logs (#6522) 2026-06-25 17:15:28 +02:00
Pascal Fischer
e20b62ad65 [management] simplify affected peers ignore disabled (#6540) 2026-06-25 16:30:40 +02:00
Riccardo Manfrin
18b38943aa disable connect panel on disabled auto connect (#6542) 2026-06-25 16:20:19 +02:00
Pascal Fischer
a400828b89 [management] move some logs to trace (#6541) 2026-06-25 15:16:54 +02:00
Pascal Fischer
e2bb328a34 [management] less strict metaHash when blocking peers (#6531) 2026-06-25 15:02:43 +02:00
Pascal Fischer
221b9c012c [management] validate posture checks on meta change before account update (#6527) 2026-06-25 15:02:04 +02:00
Viktor Liu
17b2044596 [client] Skip re-resolving cached management cache domains (#6518) 2026-06-23 17:55:57 +02:00
Bethuel Mmbaga
07101c59ac [management] Reschedule inactivity expiration when a peer disconnects (#6523) 2026-06-23 17:44:32 +03:00
255 changed files with 40911 additions and 4178 deletions

View File

@@ -45,7 +45,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags 'devcert privileged' -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/testutil/privileged)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0

View File

@@ -48,14 +48,14 @@ jobs:
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -tags privileged -timeout 1m -failfast ./base62/...
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -tags privileged -timeout 8m -failfast -v -p 1 ./client/...
time go test -tags privileged -timeout 1m -failfast ./dns/...
time go test -tags privileged -timeout 1m -failfast ./encryption/...
time go test -tags privileged -timeout 1m -failfast ./formatter/...
time go test -tags privileged -timeout 1m -failfast ./client/iface/...
time go test -tags privileged -timeout 1m -failfast ./route/...
time go test -tags privileged -timeout 1m -failfast ./sharedsock/...
time go test -tags privileged -timeout 1m -failfast ./util/...
time go test -tags privileged -timeout 1m -failfast ./version/...
time go test -timeout 8m -failfast -v -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -158,7 +158,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
@@ -229,7 +229,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags "devcert privileged" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server -e /client/testutil/privileged)
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
'
test_relay:
@@ -579,10 +579,11 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
env:
GIT_BRANCH: ${{ github.ref_name }}
api_benchmark:
name: "Management / Benchmark (API)"
@@ -673,12 +674,13 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags=benchmark \
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/server/http/...
env:
GIT_BRANCH: ${{ github.ref_name }}
api_integration_test:
name: "Management / Integration"

View File

@@ -68,7 +68,7 @@ jobs:
run: |
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
$cmd = "$goExe test -tags `"devcert privileged`" -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
- name: test

View File

@@ -21,7 +21,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals,flate,recordin,unparseable
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -1,4 +1,4 @@
.PHONY: lint lint-all lint-install setup-hooks test-unit test-privileged
.PHONY: lint lint-all lint-install setup-hooks
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
# Install golangci-lint locally if needed
@@ -25,15 +25,3 @@ setup-hooks:
@git config core.hooksPath .githooks
@chmod +x .githooks/pre-push
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
# Host-safe unit tests: excludes the privileged-tagged tests (root / system-mutating).
# Runs as a normal user with no sudo and leaves host networking untouched.
test-unit:
@go test -tags devcert -timeout 10m ./...
# Privileged suite: runs the `privileged`-tagged tests inside a --privileged
# --cap-add=NET_ADMIN container via the ory/dockertest harness. Requires Docker.
# Narrow the run with env vars, e.g.:
# PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
test-privileged:
@go test -tags 'devcert privileged' -timeout 30m -run TestRunPrivilegedSuiteInDocker -v ./client/testutil/privileged/...

View File

@@ -1,196 +0,0 @@
//go:build privileged
package cmd
import (
"context"
"fmt"
"os"
"runtime"
"testing"
"time"
"github.com/kardianos/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
statusPollInterval = 500 * time.Millisecond
)
// waitForServiceStatus waits for service to reach expected status with timeout
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
return false, err
}
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
defer timeoutCancel()
ticker := time.NewTicker(statusPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
case <-ticker.C:
status, err := s.Status()
if err != nil {
// Continue polling on transient errors
continue
}
if status == expectedStatus {
return true, nil
}
}
}
}
// TestServiceLifecycle tests the complete service lifecycle
func TestServiceLifecycle(t *testing.T) {
// TODO: Add support for Windows and macOS
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
}
if os.Getenv("CONTAINER") == "true" {
t.Skip("Skipping service lifecycle test in container environment")
}
originalServiceName := serviceName
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
defer func() {
serviceName = originalServiceName
}()
tempDir := t.TempDir()
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {
installCmd.SetContext(ctx)
err := installCmd.RunE(installCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
status, err := s.Status()
assert.NoError(t, err)
assert.NotEqual(t, service.StatusUnknown, status)
})
t.Run("Start", func(t *testing.T) {
startCmd.SetContext(ctx)
err := startCmd.RunE(startCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Restart", func(t *testing.T) {
restartCmd.SetContext(ctx)
err := restartCmd.RunE(restartCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Reconfigure", func(t *testing.T) {
originalLogLevel := logLevel
logLevel = "debug"
defer func() {
logLevel = originalLogLevel
}()
reconfigureCmd.SetContext(ctx)
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Stop", func(t *testing.T) {
stopCmd.SetContext(ctx)
err := stopCmd.RunE(stopCmd, []string{})
require.NoError(t, err)
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
require.NoError(t, err)
assert.True(t, stopped)
})
t.Run("Uninstall", func(t *testing.T) {
uninstallCmd.SetContext(ctx)
err := uninstallCmd.RunE(uninstallCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
_, err = s.Status()
assert.Error(t, err)
})
}

View File

@@ -1,12 +1,16 @@
package cmd
import (
"context"
"fmt"
"os"
"os/signal"
"runtime"
"syscall"
"testing"
"time"
"github.com/kardianos/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -27,6 +31,186 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
statusPollInterval = 500 * time.Millisecond
)
// waitForServiceStatus waits for service to reach expected status with timeout
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
return false, err
}
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
defer timeoutCancel()
ticker := time.NewTicker(statusPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
case <-ticker.C:
status, err := s.Status()
if err != nil {
// Continue polling on transient errors
continue
}
if status == expectedStatus {
return true, nil
}
}
}
}
// TestServiceLifecycle tests the complete service lifecycle
func TestServiceLifecycle(t *testing.T) {
// TODO: Add support for Windows and macOS
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
}
if os.Getenv("CONTAINER") == "true" {
t.Skip("Skipping service lifecycle test in container environment")
}
originalServiceName := serviceName
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
defer func() {
serviceName = originalServiceName
}()
tempDir := t.TempDir()
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {
installCmd.SetContext(ctx)
err := installCmd.RunE(installCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
status, err := s.Status()
assert.NoError(t, err)
assert.NotEqual(t, service.StatusUnknown, status)
})
t.Run("Start", func(t *testing.T) {
startCmd.SetContext(ctx)
err := startCmd.RunE(startCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Restart", func(t *testing.T) {
restartCmd.SetContext(ctx)
err := restartCmd.RunE(restartCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Reconfigure", func(t *testing.T) {
originalLogLevel := logLevel
logLevel = "debug"
defer func() {
logLevel = originalLogLevel
}()
reconfigureCmd.SetContext(ctx)
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Stop", func(t *testing.T) {
stopCmd.SetContext(ctx)
err := stopCmd.RunE(stopCmd, []string{})
require.NoError(t, err)
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
require.NoError(t, err)
assert.True(t, stopped)
})
t.Run("Uninstall", func(t *testing.T) {
uninstallCmd.SetContext(ctx)
err := uninstallCmd.RunE(uninstallCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
_, err = s.Status()
assert.Error(t, err)
})
}
// TestServiceEnvVars tests environment variable parsing
func TestServiceEnvVars(t *testing.T) {
tests := []struct {

View File

@@ -1,5 +1,3 @@
//go:build privileged
package iptables
import (

View File

@@ -1,4 +1,4 @@
//go:build !android && privileged
//go:build !android
package iptables

View File

@@ -1,5 +1,3 @@
//go:build privileged
package nftables
import (

View File

@@ -1,4 +1,4 @@
//go:build !android && privileged
//go:build !android
package nftables

View File

@@ -1,5 +1,3 @@
//go:build privileged
package iface
import (

View File

@@ -1,4 +1,4 @@
//go:build linux && !android && privileged
//go:build linux && !android
package wgproxy

View File

@@ -1,4 +1,4 @@
//go:build !linux || !privileged
//go:build !linux
package wgproxy

View File

@@ -1,4 +1,4 @@
//go:build linux && !android && privileged
//go:build linux && !android
package wgproxy
@@ -26,6 +26,64 @@ func compareUDPAddr(addr1, addr2 net.Addr) bool {
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
}
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
wgPort := 51850
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
wgPort := 51851
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
func TestRedirectAs_UDP_IPv4(t *testing.T) {
wgPort := 51852
@@ -198,64 +256,6 @@ func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *
}
}
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
wgPort := 51850
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
wgPort := 51851
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
func TestRedirectAs_Multiple_Switches(t *testing.T) {
wgPort := 51856

View File

@@ -51,13 +51,20 @@ type cachedRecord struct {
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
// guarded by mutex.
type Resolver struct {
records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
// failedResolves records the last failed initial resolve per domain so a
// domain that never resolves isn't retried on every server-domains update
// until refreshBackoff elapses. Entries are cleared on success and pruned
// to the current server-domains set.
failedResolves map[domain.Domain]time.Time
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
@@ -76,9 +83,10 @@ type Resolver struct {
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
failedResolves: make(map[domain.Domain]time.Time),
cacheTTL: resolveCacheTTL(),
}
}
@@ -173,7 +181,9 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
// entry for that qtype. When one family hard-errors while the other succeeds,
// the resolved family is still cached but AddDomain returns an error so the
// caller retries the incomplete resolve rather than treating it as complete.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
@@ -203,6 +213,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
if errA != nil || errAAAA != nil {
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
}
return nil
}
@@ -462,6 +476,7 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
delete(m.failedResolves, d)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -505,6 +520,7 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
currentDomains := m.GetCachedDomains()
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
m.pruneFailedResolves(allDomains)
}
m.addNewDomains(ctx, newDomains)
@@ -577,13 +593,85 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches all domains from the update
// addNewDomains resolves and caches domains that are not yet in the cache,
// running the lookups concurrently. Domains already cached are skipped and left
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
// synchronously: once NetBird owns the OS resolver the resolve runs through the
// handler chain and would otherwise dial the managed upstreams under the engine
// sync lock on every update.
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
var wg sync.WaitGroup
seen := make(map[domain.Domain]struct{}, len(newDomains))
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
if _, dup := seen[newDomain]; dup {
continue
}
seen[newDomain] = struct{}{}
if !m.needsResolve(newDomain) {
continue
}
wg.Add(1)
go func(d domain.Domain) {
defer wg.Done()
if err := m.AddDomain(ctx, d); err != nil {
m.markResolveFailed(d)
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
return
}
m.clearResolveFailed(d)
log.Debugf("added/updated management cache domain=%s", d.SafeString())
}(newDomain)
}
wg.Wait()
}
// needsResolve reports whether d should be resolved now. A recent failed or
// incomplete resolve gates retries on the backoff even when one family is
// already cached, so a transiently-failed family is retried instead of being
// treated as fully resolved. Otherwise a domain with any cached record is left
// to the stale-while-revalidate refresh path.
func (m *Resolver) needsResolve(d domain.Domain) bool {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.RLock()
defer m.mutex.RUnlock()
if failedAt, ok := m.failedResolves[d]; ok {
return time.Since(failedAt) >= refreshBackoff
}
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
if _, ok := m.records[q]; ok {
return false
}
}
return true
}
func (m *Resolver) markResolveFailed(d domain.Domain) {
m.mutex.Lock()
m.failedResolves[d] = time.Now()
m.mutex.Unlock()
}
func (m *Resolver) clearResolveFailed(d domain.Domain) {
m.mutex.Lock()
delete(m.failedResolves, d)
m.mutex.Unlock()
}
// pruneFailedResolves drops failure markers for domains no longer present in
// the server-domains set, keeping the map bounded to the current set (a
// failed-only domain has no cached record, so RemoveDomain never sees it).
func (m *Resolver) pruneFailedResolves(domains domain.List) {
m.mutex.Lock()
defer m.mutex.Unlock()
for d := range m.failedResolves {
if !slices.Contains(domains, d) {
delete(m.failedResolves, d)
}
}
}

View File

@@ -21,6 +21,7 @@ type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
qErr map[string]error
err error
hasRoot bool
onLookup func()
@@ -30,6 +31,7 @@ func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
qErr: map[string]error{},
hasRoot: true,
}
}
@@ -47,6 +49,9 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
f.calls[key]++
answers := f.answers[key]
err := f.err
if err == nil {
err = f.qErr[key]
}
onLookup := f.onLookup
f.mu.Unlock()
@@ -75,6 +80,12 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
}
}
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
f.mu.Lock()
defer f.mu.Unlock()
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()

View File

@@ -0,0 +1,183 @@
package mgmt
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/shared/management/domain"
)
// A domain already in the cache must not be re-resolved on a subsequent server
// domains update; it is left to the stale-while-revalidate refresh path.
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must resolve the domain")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"cached domain must not be re-resolved on a subsequent update")
}
// New domains in a single update must resolve concurrently rather than serially.
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
var inflight, maxInflight atomic.Int32
chain.onLookup = func() {
n := inflight.Add(1)
for {
old := maxInflight.Load()
if n <= old || maxInflight.CompareAndSwap(old, n) {
break
}
}
time.Sleep(50 * time.Millisecond)
inflight.Add(-1)
}
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
for _, d := range relays {
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
}
r.SetChainResolver(chain, 50)
start := time.Now()
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
require.NoError(t, err)
elapsed := time.Since(start)
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
}
// A domain that fails to resolve must not be retried on every update; the
// failure backoff suppresses re-resolution until it expires.
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must attempt the resolve")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"failed resolve must back off and not retry on the next update")
}
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
// the same host) must be resolved once per update, not once per occurrence.
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{
Stuns: []domain.Domain{"dup.example.com"},
Turns: []domain.Domain{"dup.example.com"},
}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
"a domain appearing under multiple server-domain types must resolve once")
}
// A failure marker must be dropped once its domain leaves the server-domains set
// so the map stays bounded to the current set.
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
require.True(t, marked, "failed resolve must be recorded")
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
}
// When one family hard-errors while the other resolves, the domain is cached
// for the working family but recorded as incomplete so the failed family is
// retried under backoff instead of being treated as fully resolved forever.
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
d := domain.Domain("relay.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
require.True(t, aCached, "the working family must still be cached")
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
r.mutex.Lock()
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
r.mutex.Unlock()
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
}
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
// not a failure: the domain must not be marked for retry, otherwise it would be
// re-resolved on every sync.
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
d := domain.Domain("v4only.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
}

View File

@@ -1,485 +0,0 @@
//go:build privileged
package dns
import (
"context"
"fmt"
"net/netip"
"os"
"testing"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
testCases := []struct {
name string
initUpstreamMap []handlerWrapper
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap []handlerWrapper
expectedLocalQs []dns.Question
}{
{
name: "Initial Config Should Succeed",
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
{
domain: nbdns.RootZone,
priority: PriorityDefault,
},
},
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: "netbird.cloud",
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
},
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{{
domain: ".",
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun230%d", n),
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
err = wgIface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = wgIface.Close()
if err != nil {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Fatal(err)
}
err = dnsServer.Initialize()
if err != nil {
t.Fatal(err)
}
defer func() {
err = dnsServer.hostManager.restoreHostDNS()
if err != nil {
t.Log(err)
}
}()
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
if testCase.shouldFail {
return
}
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
}
for _, expected := range testCase.expectedUpstreamMap {
found := false
for _, got := range dnsServer.dnsMuxHandlers {
if got.domain == expected.domain && got.priority == expected.priority {
found = true
break
}
}
if !found {
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
}
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
for _, q := range testCase.expectedLocalQs {
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
Question: []dns.Question{q},
})
}
if len(testCase.expectedLocalQs) > 0 {
assert.NotNil(t, responseMSG, "response message should not be nil")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
}
})
}
}
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
if err != nil {
t.Errorf("create stdnet: %v", err)
return
}
privKey, _ := wgtypes.GeneratePrivateKey()
opts := iface.WGIFaceOpts{
IFaceName: "utun2301",
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
}
err = wgIface.Create()
if err != nil {
t.Errorf("create and init wireguard interface: %v", err)
return
}
defer func() {
if err = wgIface.Close(); err != nil {
t.Logf("close wireguard interface: %v", err)
}
}()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Errorf("create DNS server: %v", err)
return
}
err = dnsServer.Initialize()
if err != nil {
t.Errorf("run DNS server: %v", err)
return
}
defer func() {
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
t.Logf("restore DNS settings on the host: %v", err)
return
}
}()
dnsServer.dnsMuxHandlers = []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityUpstream,
},
}
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
}
// Start the server with regular configuration
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update2 := update
update2.ServiceEnable = false
// Disable the server, stop the listener
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update3 := update2
update3.NameServerGroups = update3.NameServerGroups[:1]
// But service still get updates and we checking that we handle
// internal state in the right way
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
}

View File

@@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -22,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test"
@@ -102,6 +104,466 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger())
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
testCases := []struct {
name string
initUpstreamMap []handlerWrapper
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap []handlerWrapper
expectedLocalQs []dns.Question
}{
{
name: "Initial Config Should Succeed",
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
{
domain: nbdns.RootZone,
priority: PriorityDefault,
},
},
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: "netbird.cloud",
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
},
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{{
domain: ".",
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun230%d", n),
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
err = wgIface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = wgIface.Close()
if err != nil {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Fatal(err)
}
err = dnsServer.Initialize()
if err != nil {
t.Fatal(err)
}
defer func() {
err = dnsServer.hostManager.restoreHostDNS()
if err != nil {
t.Log(err)
}
}()
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
if testCase.shouldFail {
return
}
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
}
for _, expected := range testCase.expectedUpstreamMap {
found := false
for _, got := range dnsServer.dnsMuxHandlers {
if got.domain == expected.domain && got.priority == expected.priority {
found = true
break
}
}
if !found {
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
}
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
for _, q := range testCase.expectedLocalQs {
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
Question: []dns.Question{q},
})
}
if len(testCase.expectedLocalQs) > 0 {
assert.NotNil(t, responseMSG, "response message should not be nil")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
}
})
}
}
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
if err != nil {
t.Errorf("create stdnet: %v", err)
return
}
privKey, _ := wgtypes.GeneratePrivateKey()
opts := iface.WGIFaceOpts{
IFaceName: "utun2301",
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
}
err = wgIface.Create()
if err != nil {
t.Errorf("create and init wireguard interface: %v", err)
return
}
defer func() {
if err = wgIface.Close(); err != nil {
t.Logf("close wireguard interface: %v", err)
}
}()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Errorf("create DNS server: %v", err)
return
}
err = dnsServer.Initialize()
if err != nil {
t.Errorf("run DNS server: %v", err)
return
}
defer func() {
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
t.Logf("restore DNS settings on the host: %v", err)
return
}
}()
dnsServer.dnsMuxHandlers = []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityUpstream,
},
}
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
}
// Start the server with regular configuration
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update2 := update
update2.ServiceEnable = false
// Disable the server, stop the listener
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update3 := update2
update3.NameServerGroups = update3.NameServerGroups[:1]
// But service still get updates and we checking that we handle
// internal state in the right way
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
}
func TestDNSServerStartStop(t *testing.T) {
testCases := []struct {
name string

View File

@@ -1,565 +0,0 @@
//go:build privileged
package internal
import (
"context"
"fmt"
"net"
"runtime"
"strings"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
)
func TestEngine_SSH(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&EngineConfig{
WgIfaceName: "utun101",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
},
MobileDependency{},
)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
err = engine.Start(nil, nil)
require.NoError(t, err)
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
peerWithSSH := &mgmtProto.RemotePeerConfig{
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.21/24"},
SshConfig: &mgmtProto.SSHConfig{
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
},
}
// SSH server is not enabled so SSH config of a remote peer should be ignored
networkMap := &mgmtProto.NetworkMap{
Serial: 6,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
// SSH server is enabled, therefore SSH config should be applied
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now remove peer
networkMap = &mgmtProto.NetworkMap{
Serial: 8,
RemotePeers: []*mgmtProto.RemotePeerConfig{},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now disable SSH server
networkMap = &mgmtProto.NetworkMap{
Serial: 9,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_Sync(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
t.Fatal(err)
}
}
return nil
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
return
}
peer1 := &mgmtProto.RemotePeerConfig{
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.10/24"},
}
peer2 := &mgmtProto.RemotePeerConfig{
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.11/24"},
}
peer3 := &mgmtProto.RemotePeerConfig{
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.12/24"},
}
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
updates <- &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{
Serial: 10,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
RemotePeersIsEmpty: false,
},
}
timeout := time.After(time.Second * 2)
for {
select {
case <-timeout:
t.Fatalf("timeout while waiting for test to finish")
return
default:
}
if getPeers(engine) == 3 && engine.networkSerial == 10 {
break
}
}
}
func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel)
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
sigServer, signalAddr, err := startSignal(t)
if err != nil {
t.Fatal(err)
return
}
defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
if err != nil {
t.Fatal(err)
return
}
defer mgmtServer.GracefulStop()
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
mu := sync.Mutex{}
engines := []*Engine{}
numPeers := 10
wg := sync.WaitGroup{}
wg.Add(numPeers)
// create and start peers
for i := 0; i < numPeers; i++ {
j := i
go func() {
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
if err != nil {
wg.Done()
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
return
}
engine.dnsServer = &dns.MockServer{}
mu.Lock()
defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start(nil, nil)
if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err)
wg.Done()
return
}
engines = append(engines, engine)
wg.Done()
}()
}
// wait until all have been created and started
wg.Wait()
if len(engines) != numPeers {
t.Fatal("not all peers were started")
}
// check whether all the peer have expected peers connected
expectedConnected := numPeers * (numPeers - 1)
// adjust according to timeouts
timeout := 50 * time.Second
timeoutChan := time.After(timeout)
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
loop:
for {
select {
case <-timeoutChan:
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
break loop
case <-ticker.C:
totalConnected := 0
for _, engine := range engines {
totalConnected += getConnectedPeers(engine)
}
if totalConnected == expectedConnected {
log.Infof("total connected=%d", totalConnected)
break loop
}
log.Infof("total connected=%d", totalConnected)
}
}
// cleanup test
for n, peerEngine := range engines {
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
errStop := peerEngine.mgmClient.Close()
if errStop != nil {
log.Infoln("got error trying to close management clients from engine: ", errStop)
}
errStop = peerEngine.Stop()
if errStop != nil {
log.Infoln("got error trying to close testing peers engine: ", errStop)
}
}
}
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
if err != nil {
return nil, err
}
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
var ifaceName string
if runtime.GOOS == "darwin" {
ifaceName = fmt.Sprintf("utun1%d", i)
} else {
ifaceName = fmt.Sprintf("wt%d", i)
}
wgPort := 33100 + i
conf := &EngineConfig{
WgIfaceName: ifaceName,
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
WgPrivateKey: key,
WgPort: wgPort,
MTU: iface.DefaultMTU,
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{}), nil
e.ctx = ctx
return e, err
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &config.Host{
Proto: "http",
URI: "localhost:10000",
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
func getConnectedPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
i := 0
for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id)
if conn.IsConnected() {
i++
}
}
return i
}
func getPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return len(e.peerStore.PeersPubKey())
}

View File

@@ -6,18 +6,37 @@ import (
"net"
"net/netip"
"os"
"runtime"
"strings"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
@@ -31,7 +50,18 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/monotime"
"github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
@@ -39,9 +69,25 @@ import (
"github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
)
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
type MockWGIface struct {
CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
@@ -188,6 +234,129 @@ func TestMain(m *testing.M) {
os.Exit(code)
}
func TestEngine_SSH(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&EngineConfig{
WgIfaceName: "utun101",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
},
MobileDependency{},
)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
err = engine.Start(nil, nil)
require.NoError(t, err)
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
peerWithSSH := &mgmtProto.RemotePeerConfig{
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.21/24"},
SshConfig: &mgmtProto.SSHConfig{
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
},
}
// SSH server is not enabled so SSH config of a remote peer should be ignored
networkMap := &mgmtProto.NetworkMap{
Serial: 6,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
// SSH server is enabled, therefore SSH config should be applied
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now remove peer
networkMap = &mgmtProto.NetworkMap{
Serial: 8,
RemotePeers: []*mgmtProto.RemotePeerConfig{},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now disable SSH server
networkMap = &mgmtProto.NetworkMap{
Serial: 9,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_SSHUpdateLogic(t *testing.T) {
// Test that SSH server start/stop logic works based on config
engine := &Engine{
@@ -462,6 +631,97 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}
}
func TestEngine_Sync(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
t.Fatal(err)
}
}
return nil
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
return
}
peer1 := &mgmtProto.RemotePeerConfig{
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.10/24"},
}
peer2 := &mgmtProto.RemotePeerConfig{
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.11/24"},
}
peer3 := &mgmtProto.RemotePeerConfig{
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.12/24"},
}
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
updates <- &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{
Serial: 10,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
RemotePeersIsEmpty: false,
},
}
timeout := time.After(time.Second * 2)
for {
select {
case <-timeout:
t.Fatalf("timeout while waiting for test to finish")
return
default:
}
if getPeers(engine) == 3 && engine.networkSerial == 10 {
break
}
}
}
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
testCases := []struct {
name string
@@ -845,6 +1105,104 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}
}
func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel)
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
sigServer, signalAddr, err := startSignal(t)
if err != nil {
t.Fatal(err)
return
}
defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
if err != nil {
t.Fatal(err)
return
}
defer mgmtServer.GracefulStop()
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
mu := sync.Mutex{}
engines := []*Engine{}
numPeers := 10
wg := sync.WaitGroup{}
wg.Add(numPeers)
// create and start peers
for i := 0; i < numPeers; i++ {
j := i
go func() {
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
if err != nil {
wg.Done()
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
return
}
engine.dnsServer = &dns.MockServer{}
mu.Lock()
defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start(nil, nil)
if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err)
wg.Done()
return
}
engines = append(engines, engine)
wg.Done()
}()
}
// wait until all have been created and started
wg.Wait()
if len(engines) != numPeers {
t.Fatal("not all peers was started")
}
// check whether all the peer have expected peers connected
expectedConnected := numPeers * (numPeers - 1)
// adjust according to timeouts
timeout := 50 * time.Second
timeoutChan := time.After(timeout)
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
loop:
for {
select {
case <-timeoutChan:
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
break loop
case <-ticker.C:
totalConnected := 0
for _, engine := range engines {
totalConnected += getConnectedPeers(engine)
}
if totalConnected == expectedConnected {
log.Infof("total connected=%d", totalConnected)
break loop
}
log.Infof("total connected=%d", totalConnected)
}
}
// cleanup test
for n, peerEngine := range engines {
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
errStop := peerEngine.mgmClient.Close()
if errStop != nil {
log.Infoln("got error trying to close management clients from engine: ", errStop)
}
errStop = peerEngine.Stop()
if errStop != nil {
log.Infoln("got error trying to close testing peers engine: ", errStop)
}
}
}
func Test_ParseNATExternalIPMappings(t *testing.T) {
ifaceList, err := net.Interfaces()
if err != nil {
@@ -1168,6 +1526,187 @@ func TestCompareNetIPLists(t *testing.T) {
}
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
if err != nil {
return nil, err
}
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
var ifaceName string
if runtime.GOOS == "darwin" {
ifaceName = fmt.Sprintf("utun1%d", i)
} else {
ifaceName = fmt.Sprintf("wt%d", i)
}
wgPort := 33100 + i
conf := &EngineConfig{
WgIfaceName: ifaceName,
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
WgPrivateKey: key,
WgPort: wgPort,
MTU: iface.DefaultMTU,
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{}), nil
e.ctx = ctx
return e, err
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &config.Host{
Proto: "http",
URI: "localhost:10000",
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
func getConnectedPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
i := 0
for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id)
if conn.IsConnected() {
i++
}
}
return i
}
func getPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return len(e.peerStore.PeersPubKey())
}
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
t.Helper()
b, err := netiputil.EncodePrefix(p)

View File

@@ -1,5 +1,3 @@
//go:build privileged
package routemanager
import (

View File

@@ -1,69 +0,0 @@
//go:build linux && !android
package systemops
import (
"fmt"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEntryExists(t *testing.T) {
tempDir := t.TempDir()
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
content := []string{
"1000 reserved",
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
"9999 other_table",
}
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
file, err := os.Open(tempFilePath)
require.NoError(t, err)
defer func() {
assert.NoError(t, file.Close())
}()
tests := []struct {
name string
id int
shouldExist bool
err error
}{
{
name: "ExistsWithNetbirdPrefix",
id: 7120,
shouldExist: true,
err: nil,
},
{
name: "ExistsWithDifferentName",
id: 1000,
shouldExist: true,
err: ErrTableIDExists,
},
{
name: "DoesNotExist",
id: 1234,
shouldExist: false,
err: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
exists, err := entryExists(file, tc.id)
if tc.err != nil {
assert.ErrorIs(t, err, tc.err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.shouldExist, exists)
})
}
}

View File

@@ -1,191 +0,0 @@
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && privileged
package systemops
import (
"fmt"
"net"
"net/netip"
"os/exec"
"regexp"
"runtime"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
},
}...)
}
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
var intf *net.Interface
var nexthop Nexthop
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
r := New(nil, nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
baseIP = netip.MustParseAddr("192.0.2.0")
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
if runtime.GOOS == "darwin" {
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return intf
}
prefix, err := netip.ParsePrefix(ipAddressCIDR)
require.NoError(t, err, "Failed to parse prefix")
netIntf, err := net.InterfaceByName(intf)
require.NoError(t, err, "Failed to get interface by name")
nexthop := Nexthop{netip.Addr{}, netIntf}
r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")
t.Cleanup(func() {
err := r.removeFromRouteTable(prefix, nexthop)
assert.NoError(t, err, "Failed to remove route from table")
})
return intf
}
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
t.Helper()
var originalNexthop net.IP
if dstCIDR == "0.0.0.0/0" {
var err error
originalNexthop, err = fetchOriginalGateway()
if err != nil {
t.Logf("Failed to fetch original gateway: %v", err)
}
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
t.Logf("Failed to delete route: %v, output: %s", err, output)
}
}
t.Cleanup(func() {
if originalNexthop != nil {
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
assert.NoError(t, err, "Failed to restore original route")
}
})
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
require.NoError(t, err, "Failed to add route")
t.Cleanup(func() {
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
assert.NoError(t, err, "Failed to remove route")
})
}
func fetchOriginalGateway() (net.IP, error) {
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
if err != nil {
return nil, err
}
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
if len(matches) == 0 {
return nil, fmt.Errorf("gateway not found")
}
return net.ParseIP(matches[1]), nil
}
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
t.Helper()
if runtime.GOOS == "darwin" {
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
}
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
tunName := strings.TrimSpace(string(output))
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
intf, err := net.InterfaceByName(tunName)
require.NoError(t, err, "Failed to get interface by name")
t.Cleanup(func() {
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
}
})
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
}

View File

@@ -3,24 +3,79 @@
package systemops
import (
"fmt"
"net"
"net/netip"
"os/exec"
"regexp"
"runtime"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/route"
)
// Interface names used by the shared routing test fixtures. Kept untagged (no
// privileged build tag) so the non-privileged test files in this package compile.
//
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedVPNint = "utun100"
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedExternalInt = "lo0"
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedInternalInt = "lo0"
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
},
}...)
}
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
var intf *net.Interface
var nexthop Nexthop
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
r := New(nil, nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
baseIP = netip.MustParseAddr("192.0.2.0")
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
}
func TestBits(t *testing.T) {
tests := []struct {
name string
@@ -67,3 +122,122 @@ func TestBits(t *testing.T) {
})
}
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
if runtime.GOOS == "darwin" {
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return intf
}
prefix, err := netip.ParsePrefix(ipAddressCIDR)
require.NoError(t, err, "Failed to parse prefix")
netIntf, err := net.InterfaceByName(intf)
require.NoError(t, err, "Failed to get interface by name")
nexthop := Nexthop{netip.Addr{}, netIntf}
r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")
t.Cleanup(func() {
err := r.removeFromRouteTable(prefix, nexthop)
assert.NoError(t, err, "Failed to remove route from table")
})
return intf
}
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
t.Helper()
var originalNexthop net.IP
if dstCIDR == "0.0.0.0/0" {
var err error
originalNexthop, err = fetchOriginalGateway()
if err != nil {
t.Logf("Failed to fetch original gateway: %v", err)
}
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
t.Logf("Failed to delete route: %v, output: %s", err, output)
}
}
t.Cleanup(func() {
if originalNexthop != nil {
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
assert.NoError(t, err, "Failed to restore original route")
}
})
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
require.NoError(t, err, "Failed to add route")
t.Cleanup(func() {
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
assert.NoError(t, err, "Failed to remove route")
})
}
func fetchOriginalGateway() (net.IP, error) {
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
if err != nil {
return nil, err
}
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
if len(matches) == 0 {
return nil, fmt.Errorf("gateway not found")
}
return net.ParseIP(matches[1]), nil
}
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
t.Helper()
if runtime.GOOS == "darwin" {
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
}
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
tunName := strings.TrimSpace(string(output))
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
intf, err := net.InterfaceByName(tunName)
require.NoError(t, err, "Failed to get interface by name")
t.Cleanup(func() {
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
}
})
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
}

View File

@@ -1,17 +0,0 @@
//go:build !android && !ios
package systemops
import (
"context"
"net"
)
// dialer is shared by the per-platform routing test cases. Kept untagged (no
// privileged build tag) so the non-privileged test files compile on every platform.
//
//nolint:unused // consumed by the privileged-tagged routing tests
type dialer interface {
Dial(network, address string) (net.Conn, error)
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

View File

@@ -1,4 +1,4 @@
//go:build !android && !ios && privileged
//go:build !android && !ios
package systemops
@@ -26,6 +26,11 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
type dialer interface {
Dial(network, address string) (net.Conn, error)
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func TestAddVPNRoute(t *testing.T) {
testCases := []struct {
name string
@@ -510,3 +515,125 @@ func setupTestEnv(t *testing.T) {
// unique route in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
}
func TestIsVpnRoute(t *testing.T) {
tests := []struct {
name string
addr string
vpnRoutes []string
localRoutes []string
expectedVpn bool
expectedPrefix netip.Prefix
}{
{
name: "Match in VPN routes",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Match in local routes",
addr: "10.1.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
},
{
name: "No match",
addr: "172.16.0.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Default route ignored",
addr: "192.168.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Default route matches but ignored",
addr: "172.16.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Longest prefix match local",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.0.0/16"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match local multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
},
{
name: "Longest prefix match vpn",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.0.0/16"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match vpn multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
},
{
name: "Duplicate prefix in both",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := netip.ParseAddr(tt.addr)
if err != nil {
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
}
var vpnRoutes, localRoutes []netip.Prefix
for _, route := range tt.vpnRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
}
vpnRoutes = append(vpnRoutes, prefix)
}
for _, route := range tt.localRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse local route %s: %v", route, err)
}
localRoutes = append(localRoutes, prefix)
}
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
})
}
}

View File

@@ -1,132 +0,0 @@
//go:build !android && !ios
package systemops
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsVpnRoute(t *testing.T) {
tests := []struct {
name string
addr string
vpnRoutes []string
localRoutes []string
expectedVpn bool
expectedPrefix netip.Prefix
}{
{
name: "Match in VPN routes",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Match in local routes",
addr: "10.1.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
},
{
name: "No match",
addr: "172.16.0.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Default route ignored",
addr: "192.168.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Default route matches but ignored",
addr: "172.16.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Longest prefix match local",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.0.0/16"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match local multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
},
{
name: "Longest prefix match vpn",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.0.0/16"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match vpn multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
},
{
name: "Duplicate prefix in both",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := netip.ParseAddr(tt.addr)
if err != nil {
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
}
var vpnRoutes, localRoutes []netip.Prefix
for _, route := range tt.vpnRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
}
vpnRoutes = append(vpnRoutes, prefix)
}
for _, route := range tt.localRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse local route %s: %v", route, err)
}
localRoutes = append(localRoutes, prefix)
}
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
})
}
}

View File

@@ -1,10 +1,13 @@
//go:build linux && !android && privileged
//go:build !android
package systemops
import (
"errors"
"fmt"
"net"
"os"
"strings"
"syscall"
"testing"
@@ -15,6 +18,10 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
var expectedVPNint = "wgtest0"
var expectedExternalInt = "dummyext0"
var expectedInternalInt = "dummyint0"
func init() {
testCases = append(testCases, []testCase{
{
@@ -26,6 +33,62 @@ func init() {
}...)
}
func TestEntryExists(t *testing.T) {
tempDir := t.TempDir()
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
content := []string{
"1000 reserved",
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
"9999 other_table",
}
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
file, err := os.Open(tempFilePath)
require.NoError(t, err)
defer func() {
assert.NoError(t, file.Close())
}()
tests := []struct {
name string
id int
shouldExist bool
err error
}{
{
name: "ExistsWithNetbirdPrefix",
id: 7120,
shouldExist: true,
err: nil,
},
{
name: "ExistsWithDifferentName",
id: 1000,
shouldExist: true,
err: ErrTableIDExists,
},
{
name: "DoesNotExist",
id: 1234,
shouldExist: false,
err: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
exists, err := entryExists(file, tc.id)
if tc.err != nil {
assert.ErrorIs(t, err, tc.err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.shouldExist, exists)
})
}
}
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()

View File

@@ -1,15 +0,0 @@
//go:build linux && !android
package systemops
// Interface names used by the shared routing test fixtures. Kept untagged (no
// privileged build tag) so the non-privileged test files in this package compile.
//
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedVPNint = "wgtest0"
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedExternalInt = "dummyext0"
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedInternalInt = "dummyint0"

View File

@@ -1,83 +0,0 @@
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
package systemops
import (
"net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// Shared, non-privileged routing test fixtures. The privileged TestRouting (and its
// per-platform init() appenders) consume these; they live here so the unprivileged
// BSD/darwin test files compile without the privileged build tag.
type PacketExpectation struct {
SrcIP net.IP
DstIP net.IP
SrcPort int
DstPort int
UDP bool
TCP bool
}
//nolint:unused // consumed by the privileged-tagged routing tests
type testCase struct {
name string
expectedInterface string
dialer dialer
expectedPacket PacketExpectation
}
//nolint:unused // consumed by the privileged-tagged routing tests
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To external host with custom dialer via physical interface",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To duplicate internal route with custom dialer via physical interface",
expectedInterface: expectedInternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
expectedInterface: expectedInternalInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To unique vpn route with custom dialer via physical interface",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
},
{
name: "To unique vpn route without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
},
}
//nolint:unused // consumed by the privileged-tagged routing tests
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
return PacketExpectation{
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
UDP: true,
}
}

View File

@@ -1,4 +1,4 @@
//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly) && privileged
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
package systemops
@@ -20,6 +20,63 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
type PacketExpectation struct {
SrcIP net.IP
DstIP net.IP
SrcPort int
DstPort int
UDP bool
TCP bool
}
type testCase struct {
name string
expectedInterface string
dialer dialer
expectedPacket PacketExpectation
}
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To external host with custom dialer via physical interface",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To duplicate internal route with custom dialer via physical interface",
expectedInterface: expectedInternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
expectedInterface: expectedInternalInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To unique vpn route with custom dialer via physical interface",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
},
{
name: "To unique vpn route without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
},
}
func TestRouting(t *testing.T) {
nbnet.Init()
for _, tc := range testCases {
@@ -45,6 +102,16 @@ func TestRouting(t *testing.T) {
}
}
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
return PacketExpectation{
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
UDP: true,
}
}
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
t.Helper()

View File

@@ -1,5 +1,3 @@
//go:build windows && privileged
package systemops
import (

View File

@@ -11,8 +11,6 @@ import (
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
//
//nolint:unused // consumed by the privileged-tagged routing tests
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()

View File

@@ -1,4 +1,4 @@
//go:build linux && !android && privileged
//go:build linux && !android
package systemops

View File

@@ -8,14 +8,11 @@ import (
"testing"
)
//nolint:unused // consumed by the privileged-tagged routing tests
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
//
//nolint:unused // consumed by the privileged-tagged routing tests
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()

View File

@@ -1,235 +0,0 @@
//go:build privileged
package server
import (
"context"
"net"
"os/user"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
// we will use a management server started via to simulate the server and capture the number of retries
func TestConnectWithRetryRuns(t *testing.T) {
// start the signal server
_, signalAddr, err := startSignal(t)
if err != nil {
t.Fatalf("failed to start signal server: %v", err)
}
counter := 0
// start the management server
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
if err != nil {
t.Fatalf("failed to start management server: %v", err)
}
ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
defer cancel()
// create new server
ic := profilemanager.ConfigInput{
ManagementURL: "http://" + mgmtAddr,
ConfigPath: t.TempDir() + "/test-profile.json",
}
config, err := profilemanager.UpdateOrCreateConfig(ic)
if err != nil {
t.Fatalf("failed to create config: %v", err)
}
currUser, err := user.Current()
require.NoError(t, err)
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "test-profile",
Username: currUser.Username,
})
if err != nil {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "debug", "", false, false, false, false)
s.config = config
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
t.Setenv(retryInitialIntervalVar, "1s")
t.Setenv(maxRetryIntervalVar, "2s")
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
}
type mockServer struct {
mgmtProto.ManagementServiceServer
counter *int
}
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
*m.counter++
return m.ManagementServiceServer.Login(ctx, req)
}
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
t.Helper()
dataDir := t.TempDir()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Signal: &config.Host{
Proto: "http",
URI: signalAddr,
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mock := &mockServer{
ManagementServiceServer: mgmtServer,
counter: counter,
}
mgmtProto.RegisterManagementServiceServer(s, mock)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}

View File

@@ -2,22 +2,124 @@ package server
import (
"context"
"net"
"net/url"
"os/user"
"path/filepath"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
daemonProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
// we will use a management server started via to simulate the server and capture the number of retries
func TestConnectWithRetryRuns(t *testing.T) {
// start the signal server
_, signalAddr, err := startSignal(t)
if err != nil {
t.Fatalf("failed to start signal server: %v", err)
}
counter := 0
// start the management server
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
if err != nil {
t.Fatalf("failed to start management server: %v", err)
}
ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
defer cancel()
// create new server
ic := profilemanager.ConfigInput{
ManagementURL: "http://" + mgmtAddr,
ConfigPath: t.TempDir() + "/test-profile.json",
}
config, err := profilemanager.UpdateOrCreateConfig(ic)
if err != nil {
t.Fatalf("failed to create config: %v", err)
}
currUser, err := user.Current()
require.NoError(t, err)
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "test-profile",
Username: currUser.Username,
})
if err != nil {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "debug", "", false, false, false, false)
s.config = config
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
t.Setenv(retryInitialIntervalVar, "1s")
t.Setenv(maxRetryIntervalVar, "2s")
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
}
func TestServer_Up(t *testing.T) {
tempDir := t.TempDir()
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
@@ -157,3 +259,119 @@ func TestServer_SubcribeEvents(t *testing.T) {
assert.NoError(t, err)
}
type mockServer struct {
mgmtProto.ManagementServiceServer
counter *int
}
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
*m.counter++
return m.ManagementServiceServer.Login(ctx, req)
}
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
t.Helper()
dataDir := t.TempDir()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Signal: &config.Host{
Proto: "http",
URI: signalAddr,
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mock := &mockServer{
ManagementServiceServer: mgmtServer,
counter: counter,
}
mgmtProto.RegisterManagementServiceServer(s, mock)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}

View File

@@ -1,118 +0,0 @@
//go:build privileged
package client
import (
"context"
"errors"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh/testutil"
)
func TestSSHClient_CommandExecution(t *testing.T) {
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
}
server, _, client := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
t.Run("ExecuteCommand captures output", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo hello")
assert.NoError(t, err)
assert.Contains(t, string(output), "hello")
})
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
err := client.ExecuteCommandWithIO(ctx, "echo world")
assert.NoError(t, err)
})
t.Run("commands with flags work", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
assert.NoError(t, err)
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
})
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
var testCmd string
if runtime.GOOS == "windows" {
testCmd = "echo hello | Select-String notfound"
} else {
testCmd = "echo 'hello' | grep 'notfound'"
}
_, err := client.ExecuteCommand(ctx, testCmd)
assert.NoError(t, err)
})
}
func TestSSHClient_ContextCancellation(t *testing.T) {
server, serverAddr, _ := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
t.Run("connection with short timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
currentUser := testutil.GetTestUsername(t)
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
if err != nil {
// Check for actual timeout-related errors rather than string matching
assert.True(t,
errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
strings.Contains(err.Error(), "timeout"),
"Expected timeout-related error, got: %v", err)
}
})
t.Run("command execution cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Logf("client close error: %v", err)
}
}()
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cmdCancel()
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
if err != nil {
var exitMissingErr *cryptossh.ExitMissingError
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
errors.As(err, &exitMissingErr)
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
}
})
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
@@ -77,6 +78,53 @@ func TestSSHClient_DialWithKey(t *testing.T) {
assert.NotNil(t, client.client)
}
func TestSSHClient_CommandExecution(t *testing.T) {
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
}
server, _, client := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
t.Run("ExecuteCommand captures output", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo hello")
assert.NoError(t, err)
assert.Contains(t, string(output), "hello")
})
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
err := client.ExecuteCommandWithIO(ctx, "echo world")
assert.NoError(t, err)
})
t.Run("commands with flags work", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
assert.NoError(t, err)
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
})
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
var testCmd string
if runtime.GOOS == "windows" {
testCmd = "echo hello | Select-String notfound"
} else {
testCmd = "echo 'hello' | grep 'notfound'"
}
_, err := client.ExecuteCommand(ctx, testCmd)
assert.NoError(t, err)
})
}
func TestSSHClient_ConnectionHandling(t *testing.T) {
server, serverAddr, _ := setupTestSSHServerAndClient(t)
defer func() {
@@ -106,6 +154,59 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
}
}
func TestSSHClient_ContextCancellation(t *testing.T) {
server, serverAddr, _ := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
t.Run("connection with short timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
currentUser := testutil.GetTestUsername(t)
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
if err != nil {
// Check for actual timeout-related errors rather than string matching
assert.True(t,
errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
strings.Contains(err.Error(), "timeout"),
"Expected timeout-related error, got: %v", err)
}
})
t.Run("command execution cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Logf("client close error: %v", err)
}
}()
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cmdCancel()
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
if err != nil {
var exitMissingErr *cryptossh.ExitMissingError
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
errors.As(err, &exitMissingErr)
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
}
})
}
func TestSSHClient_NoAuthMode(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)

View File

@@ -1,423 +0,0 @@
//go:build privileged
package proxy
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"os"
"runtime"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func (m *mockDaemon) setJWTToken(token string) {
m.impl.jwtToken = token
}
func TestSSHProxy_Connect(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// TODO: Windows test times out - user switching and command execution tested on Linux
if runtime.GOOS == "windows" {
t.Skip("Skipping on Windows - covered by Linux tests")
}
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
// Configure SSH authorization for the test user
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0}, // Index 0 in AuthorizedUsers
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
defer func() { _ = clientConn.Close() }()
origStdin := os.Stdin
origStdout := os.Stdout
defer func() {
os.Stdin = origStdin
os.Stdout = origStdout
}()
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
go func() {
_, _ = io.Copy(stdinWriter, proxyConn)
}()
go func() {
_, _ = io.Copy(proxyConn, stdoutReader)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
connectErrCh := make(chan error, 1)
go func() {
connectErrCh <- proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 3 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err, "Should connect to proxy server")
defer func() { _ = sshClientConn.Close() }()
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
session, err := sshClient.NewSession()
require.NoError(t, err, "Should create session through full proxy to backend")
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output("echo hello-from-proxy")
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
require.NoError(t, err, "Command should execute successfully through proxy")
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
case <-time.After(3 * time.Second):
t.Fatal("Command execution timed out")
}
_ = session.Close()
_ = sshClient.Close()
_ = clientConn.Close()
cancel()
}
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
// when forwarding commands to the backend. This is critical for tools like
// Ansible that send commands such as:
//
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
//
// The single quotes must be preserved so the backend shell receives the
// subshell expression as a single argument to -c.
func TestSSHProxy_CommandQuoting(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
sshClient, cleanup := setupProxySSHClient(t)
defer cleanup()
// These commands simulate what the SSH protocol delivers as exec payloads.
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
// the local shell strips the outer single quotes, and the SSH exec request
// contains the raw string: /bin/sh -c "( echo hello )"
//
// The proxy must forward this string verbatim. Using session.Command()
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
// the command on the backend.
tests := []struct {
name string
command string
expect string
}{
{
name: "subshell_in_double_quotes",
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
expect: "from-subshell\nouter\n",
},
{
name: "printf_with_special_chars",
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
expect: "hello world\n",
},
{
name: "nested_command_substitution",
command: `/bin/sh -c "echo $(echo nested)"`,
expect: "nested\n",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
session, err := sshClient.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()
var stderrBuf bytes.Buffer
session.Stderr = &stderrBuf
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output(tc.command)
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
if stderrBuf.Len() > 0 {
t.Logf("stderr: %s", stderrBuf.String())
}
require.NoError(t, err, "command should succeed: %s", tc.command)
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
case <-time.After(5 * time.Second):
t.Fatalf("command timed out: %s", tc.command)
}
})
}
}
// setupProxySSHClient creates a full proxy test environment and returns
// an SSH client connected through the proxy to a backend NetBird SSH server.
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
t.Helper()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0},
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
mockDaemon := startMockDaemon(t)
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
origStdin := os.Stdin
origStdout := os.Stdout
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
clientConn, proxyConn := net.Pipe()
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
_ = proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err)
client := cryptossh.NewClient(sshClientConn, chans, reqs)
cleanupFn := func() {
_ = client.Close()
_ = clientConn.Close()
cancel()
os.Stdin = origStdin
os.Stdout = origStdout
_ = sshServer.Stop()
mockDaemon.stop()
jwksServer.Close()
}
return client, cleanupFn
}
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
t.Helper()
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64.RawURLEncoding.EncodeToString(n),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": user,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}

View File

@@ -1,12 +1,25 @@
package proxy
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"os"
"runtime"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
@@ -15,7 +28,11 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func TestMain(m *testing.M) {
@@ -89,6 +106,331 @@ func TestSSHProxy_verifyHostKey(t *testing.T) {
})
}
func TestSSHProxy_Connect(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// TODO: Windows test times out - user switching and command execution tested on Linux
if runtime.GOOS == "windows" {
t.Skip("Skipping on Windows - covered by Linux tests")
}
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
// Configure SSH authorization for the test user
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0}, // Index 0 in AuthorizedUsers
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
defer func() { _ = clientConn.Close() }()
origStdin := os.Stdin
origStdout := os.Stdout
defer func() {
os.Stdin = origStdin
os.Stdout = origStdout
}()
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
go func() {
_, _ = io.Copy(stdinWriter, proxyConn)
}()
go func() {
_, _ = io.Copy(proxyConn, stdoutReader)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
connectErrCh := make(chan error, 1)
go func() {
connectErrCh <- proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 3 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err, "Should connect to proxy server")
defer func() { _ = sshClientConn.Close() }()
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
session, err := sshClient.NewSession()
require.NoError(t, err, "Should create session through full proxy to backend")
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output("echo hello-from-proxy")
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
require.NoError(t, err, "Command should execute successfully through proxy")
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
case <-time.After(3 * time.Second):
t.Fatal("Command execution timed out")
}
_ = session.Close()
_ = sshClient.Close()
_ = clientConn.Close()
cancel()
}
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
// when forwarding commands to the backend. This is critical for tools like
// Ansible that send commands such as:
//
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
//
// The single quotes must be preserved so the backend shell receives the
// subshell expression as a single argument to -c.
func TestSSHProxy_CommandQuoting(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
sshClient, cleanup := setupProxySSHClient(t)
defer cleanup()
// These commands simulate what the SSH protocol delivers as exec payloads.
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
// the local shell strips the outer single quotes, and the SSH exec request
// contains the raw string: /bin/sh -c "( echo hello )"
//
// The proxy must forward this string verbatim. Using session.Command()
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
// the command on the backend.
tests := []struct {
name string
command string
expect string
}{
{
name: "subshell_in_double_quotes",
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
expect: "from-subshell\nouter\n",
},
{
name: "printf_with_special_chars",
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
expect: "hello world\n",
},
{
name: "nested_command_substitution",
command: `/bin/sh -c "echo $(echo nested)"`,
expect: "nested\n",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
session, err := sshClient.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()
var stderrBuf bytes.Buffer
session.Stderr = &stderrBuf
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output(tc.command)
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
if stderrBuf.Len() > 0 {
t.Logf("stderr: %s", stderrBuf.String())
}
require.NoError(t, err, "command should succeed: %s", tc.command)
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
case <-time.After(5 * time.Second):
t.Fatalf("command timed out: %s", tc.command)
}
})
}
}
// setupProxySSHClient creates a full proxy test environment and returns
// an SSH client connected through the proxy to a backend NetBird SSH server.
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
t.Helper()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0},
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
mockDaemon := startMockDaemon(t)
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
origStdin := os.Stdin
origStdout := os.Stdout
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
clientConn, proxyConn := net.Pipe()
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
_ = proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err)
client := cryptossh.NewClient(sshClientConn, chans, reqs)
cleanupFn := func() {
_ = client.Close()
_ = clientConn.Close()
cancel()
os.Stdin = origStdin
os.Stdout = origStdout
_ = sshServer.Stop()
mockDaemon.stop()
jwksServer.Close()
}
return client, cleanupFn
}
type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte
@@ -150,6 +492,10 @@ func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
m.impl.hostKeys[addr] = pubKey
}
func (m *mockDaemon) setJWTToken(token string) {
m.impl.jwtToken = token
}
func (m *mockDaemon) stop() {
if m.server != nil {
m.server.Stop()
@@ -162,3 +508,63 @@ func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
require.NoError(t, err)
return pubKey
}
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
t.Helper()
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64.RawURLEncoding.EncodeToString(n),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": user,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}

View File

@@ -1,66 +0,0 @@
//go:build unix && privileged
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000, 1001},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "ls -la",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify the command is calling netbird ssh exec
assert.Contains(t, cmd.Args, "ssh")
assert.Contains(t, cmd.Args, "exec")
assert.Contains(t, cmd.Args, "--uid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--gid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--groups")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "1001")
assert.Contains(t, cmd.Args, "--working-dir")
assert.Contains(t, cmd.Args, "/home/testuser")
assert.Contains(t, cmd.Args, "--shell")
assert.Contains(t, cmd.Args, "/bin/bash")
assert.Contains(t, cmd.Args, "--cmd")
assert.Contains(t, cmd.Args, "ls -la")
}
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify no command mode (command is empty so no --cmd flag)
assert.NotContains(t, cmd.Args, "--cmd")
assert.NotContains(t, cmd.Args, "--interactive")
}

View File

@@ -73,6 +73,61 @@ func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
}
}
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000, 1001},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "ls -la",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify the command is calling netbird ssh exec
assert.Contains(t, cmd.Args, "ssh")
assert.Contains(t, cmd.Args, "exec")
assert.Contains(t, cmd.Args, "--uid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--gid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--groups")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "1001")
assert.Contains(t, cmd.Args, "--working-dir")
assert.Contains(t, cmd.Args, "/home/testuser")
assert.Contains(t, cmd.Args, "--shell")
assert.Contains(t, cmd.Args, "/bin/bash")
assert.Contains(t, cmd.Args, "--cmd")
assert.Contains(t, cmd.Args, "ls -la")
}
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify no command mode (command is empty so no --cmd flag)
assert.NotContains(t, cmd.Args, "--cmd")
assert.NotContains(t, cmd.Args, "--interactive")
}
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
// This test requires root privileges and will be skipped if not running as root
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {

View File

@@ -1,196 +0,0 @@
//go:build privileged && (linux || darwin)
// Package privileged provides a self-hosting harness that runs the repo's
// privileged-tagged test suite inside a --privileged --cap-add=NET_ADMIN
// container, so developers can exercise the root/system-mutating tests on a
// non-root host with a single `go test` invocation.
package privileged
import (
"bytes"
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
"github.com/moby/moby/api/types/container"
"github.com/ory/dockertest/v4"
)
// containerImage / containerTag match the image used by the CI privileged job
// (.github/workflows/golang-test-linux.yml, test_client_on_docker).
const (
containerImage = "golang"
containerTag = "1.25-alpine"
)
const (
containerWorkdir = "/app"
containerGoCache = "/root/.cache/go-build"
containerGoModCache = "/go/pkg/mod"
)
// alpinePackages are the build/runtime deps the privileged tests need, mirroring
// the CI container setup.
const alpinePackages = "ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base"
// privilegedTestPackages is the package list the suite runs, excluding the
// server-side trees and UI/upload helpers, matching the CI Docker job's filter.
const privilegedTestPackages = `go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server`
// testWriter forwards container output to the test log line by line.
type testWriter struct{ t *testing.T }
func (w testWriter) Write(p []byte) (int, error) {
for _, line := range strings.Split(strings.TrimRight(string(p), "\n"), "\n") {
w.t.Log(line)
}
return len(p), nil
}
// TestRunPrivilegedSuiteInDocker spins up a privileged container, mounts the repo,
// and runs `go test -tags 'devcert privileged'` inside it. When already running
// inside that container (DOCKER_CI=true) it returns immediately so the real
// privileged tests in the suite execute in place instead of recursing.
func TestRunPrivilegedSuiteInDocker(t *testing.T) {
if os.Getenv("DOCKER_CI") == "true" {
t.Skip("inside privileged container, skipping container spawn; privileged tests run in place")
}
repoRoot, err := findRepoRoot()
if err != nil {
t.Fatalf("locate repo root: %v", err)
}
goCache, goModCache := hostGoCaches(t)
// dockertest reads DOCKER_HOST; point it at the active context's socket when
// the default one is absent (macOS Docker Desktop, Colima, OrbStack).
if host := dockerHost(); host != "" {
t.Setenv("DOCKER_HOST", host)
}
// NewPoolT registers container cleanup via t.Cleanup automatically.
pool := dockertest.NewPoolT(t, "", dockertest.WithMaxWait(30*time.Minute))
// Keep the container alive so the suite runs via Exec, which yields a clean
// exit code (the v4 Resource API exposes no container wait/exit-code).
resource := pool.RunT(t, containerImage,
dockertest.WithTag(containerTag),
dockertest.WithWorkingDir(containerWorkdir),
dockertest.WithMounts([]string{
repoRoot + ":" + containerWorkdir,
goCache + ":" + containerGoCache,
goModCache + ":" + containerGoModCache,
}),
dockertest.WithEnv([]string{
"CGO_ENABLED=1",
"CI=true",
"DOCKER_CI=true",
"CONTAINER=true",
"GOCACHE=" + containerGoCache,
"GOMODCACHE=" + containerGoModCache,
}),
dockertest.WithCmd([]string{"sleep", "infinity"}),
dockertest.WithHostConfig(func(hc *container.HostConfig) {
hc.Privileged = true
hc.CapAdd = []string{"NET_ADMIN"}
}),
dockertest.WithoutReuse(),
)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
result, err := resource.Exec(ctx, []string{"sh", "-c", buildTestScript()})
if err != nil {
t.Fatalf("run privileged suite in container: %v", err)
}
w := testWriter{t}
_, _ = w.Write([]byte(result.StdOut))
_, _ = w.Write([]byte(result.StdErr))
if result.ExitCode != 0 {
t.Fatalf("privileged test suite failed in container (exit code %d)", result.ExitCode)
}
}
// findRepoRoot walks up from the test's working directory to the module root.
func findRepoRoot() (string, error) {
dir, err := os.Getwd()
if err != nil {
return "", err
}
for {
if _, statErr := os.Stat(filepath.Join(dir, "go.mod")); statErr == nil {
return dir, nil
}
parent := filepath.Dir(dir)
if parent == dir {
return "", fmt.Errorf("go.mod not found above %s", dir)
}
dir = parent
}
}
// dockerHost returns a DOCKER_HOST override when the default socket is missing.
// An empty result means the caller should leave DOCKER_HOST untouched (it is
// already set, or the default unix socket exists). When neither is present
// (common on macOS Docker Desktop, Colima and OrbStack, which use a per-user
// socket), it resolves the active docker context's endpoint.
func dockerHost() string {
if os.Getenv("DOCKER_HOST") != "" {
return ""
}
if _, err := os.Stat("/var/run/docker.sock"); err == nil {
return ""
}
out, err := exec.Command("docker", "context", "inspect", "-f", "{{.Endpoints.docker.Host}}").Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(out))
}
// hostGoCaches resolves the host GOCACHE/GOMODCACHE so the container reuses the
// existing build/module cache for speed.
func hostGoCaches(t *testing.T) (string, string) {
t.Helper()
return goEnv(t, "GOCACHE"), goEnv(t, "GOMODCACHE")
}
func goEnv(t *testing.T, key string) string {
t.Helper()
var out bytes.Buffer
cmd := exec.Command("go", "env", key)
cmd.Stdout = &out
if err := cmd.Run(); err != nil {
t.Fatalf("go env %s: %v", key, err)
}
return strings.TrimSpace(out.String())
}
// buildTestScript builds the in-container command. PRIV_PKGS overrides the package
// list (default: the full filtered set); PRIV_RUN adds a -run test-name filter.
// Both empty reproduces the full privileged suite.
func buildTestScript() string {
pkgs := privilegedTestPackages + " | xargs"
if p := os.Getenv("PRIV_PKGS"); p != "" {
pkgs = "echo " + p + " | xargs"
}
runFilter := ""
if r := os.Getenv("PRIV_RUN"); r != "" {
runFilter = "-run '" + r + "' "
}
return fmt.Sprintf(
"apk update >/dev/null && apk add --no-cache %s >/dev/null && %s go test -buildvcs=false -tags 'devcert privileged' %s-v -timeout 20m -p 1",
alpinePackages, pkgs, runFilter,
)
}

View File

@@ -418,7 +418,14 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
case args.showProfiles:
s.showProfilesUI()
case args.showQuickActions:
s.showQuickActionsUI()
// Suppress the on-boot Quick Actions popup when the daemon
// reports DisableAutoConnect=true — that flag carries both the
// user's "Connect on Startup = off" preference AND any MDM-
// enforced override (applyMDMPolicy writes the policy value
// into the same Config field). See netbirdio/netbird#5744.
if !s.disableAutoConnectFromDaemon() {
s.showQuickActionsUI()
}
case args.showUpdate:
s.showUpdateProgress(ctx, args.showUpdateVersion)
}
@@ -1338,6 +1345,40 @@ func (s *serviceClient) getFeatures() (*proto.GetFeaturesResponse, error) {
return features, nil
}
// disableAutoConnectFromDaemon returns true when the daemon reports
// the active profile has DisableAutoConnect=true. Used by the
// --quick-actions startup path to suppress the on-boot popup when the
// user (or an MDM admin) opted out of auto-connecting; both cases
// converge on the same Config field because applyMDMPolicy writes the
// policy value into it. Returns false on any RPC / lookup failure so a
// daemon hiccup does not silently swallow the popup.
func (s *serviceClient) disableAutoConnectFromDaemon() bool {
activeProf, err := s.profileManager.GetActiveProfile()
if err != nil {
log.Warnf("disableAutoConnectFromDaemon: get active profile: %v", err)
return false
}
currUser, err := user.Current()
if err != nil {
log.Warnf("disableAutoConnectFromDaemon: get current user: %v", err)
return false
}
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
log.Warnf("disableAutoConnectFromDaemon: get daemon client: %v", err)
return false
}
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
})
if err != nil {
log.Warnf("disableAutoConnectFromDaemon: GetConfig RPC: %v", err)
return false
}
return srvCfg.GetDisableAutoConnect()
}
// getSrvConfig from the service to show it in the settings window.
func (s *serviceClient) getSrvConfig() {
s.managementURL = profilemanager.DefaultManagementURL

View File

@@ -0,0 +1,109 @@
# Agent Networks — overview
Single-entry point. Feature scope, the module map, and the cross-cutting
topics worth keeping in mind, with links into every per-module guide.
## TL;DR
Agent Networks introduces an **LLM-aware reverse-proxy middleware system**
plus **account-level controls** (budget rules, log collection toggles,
PII redaction). The management server synthesises a per-peer middleware
chain that the proxy executes on every LLM request; the chain enforces
quotas, injects identity, redacts PII, parses tokens/cost, and emits
access-log entries. The dashboard exposes the surface as a single **AI
Observability** page with four tabs.
- **Backend** lives in this repo, primarily under
`management/server/agentnetwork`, `proxy/internal/middleware`, and
`proxy/internal/llm`, with wire contracts in `shared/management`.
- **Dashboard** lives in the dashboard repo under
`src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`.
## Reading order
| # | Doc | Why |
|---|-----|-----|
| 1 | [01-end-to-end-flows.md](01-end-to-end-flows.md) | Get the three big diagrams in your head first. |
| 2 | [modules/10-shared-api.md](modules/10-shared-api.md) | Wire contracts — every other module either produces or consumes these. |
| 3 | [modules/21-management-agentnetwork.md](modules/21-management-agentnetwork.md) | The largest module; everything the proxy executes originates here. |
| 4 | [modules/30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md) | The generic plugin system on the proxy side. |
| 5 | [modules/31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md) | The 8 LLM middlewares that ride on the framework. |
| 6 | Everything else in any order. | |
## Module map
11 modules. Each is described in detail in its own file under
[`modules/`](modules/).
| # | Module | Risk | BC impact |
|---|--------|------|-----------|
| 10 | [shared/api](modules/10-shared-api.md) — proto + OpenAPI | Low | Additive only |
| 20 | [management/store](modules/20-management-store.md) — SQL persistence | Medium | Auto-migrate (additive) |
| 21 | [management/agentnetwork](modules/21-management-agentnetwork.md) — domain layer + synthesizer | **High** | Additive |
| 22 | [management/handlers + wiring](modules/22-management-handlers-wiring.md) — HTTP API + gRPC delivery | Medium | Additive |
| 30 | [proxy/middleware-framework](modules/30-proxy-middleware-framework.md) — generic plugin system | High | Additive |
| 31 | [proxy/middleware-builtin](modules/31-proxy-middleware-builtin.md) — 8 LLM middlewares | High | Additive |
| 32 | [proxy/llm-parsers](modules/32-proxy-llm-parsers.md) — SDK adapters + pricing | Medium | Additive |
| 33 | [proxy/runtime](modules/33-proxy-runtime.md) — translate + serve + access-log | High | Additive (touches hot path) |
| 40 | [dashboard](modules/40-dashboard.md) — UI for everything above | Medium | Sidebar reshape |
| 50 | [path-routed-providers](modules/50-path-routed-providers.md) — Vertex AI + Bedrock | Medium | Additive (new catalog entries) |
The largest and highest-risk module is `management/agentnetwork`: it is
the single writer of the middleware chain the proxy executes.
## Cross-cutting topics
These are the items most likely to bite production. Each is fully
documented in the linked module guide.
1. **Capture-pointer semantics** (`*bool` for `capture_prompt` and
`capture_completion`): nil = legacy emit, false = suppress, true =
emit. nil-vs-false must be handled at every JSON hop. See
[21-management-agentnetwork.md](modules/21-management-agentnetwork.md)
and [31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md).
2. **`ProxyMapping.Private` preservation** on per-proxy live updates.
Failure mode: `auth` skips `ValidateTunnelPeer`
`CapturedData.UserGroups` empty → `llm_router` denies. See
[33-proxy-runtime.md](modules/33-proxy-runtime.md).
3. **respInput carrying `UserEmail`/`UserGroups`/`UserGroupNames` onto
the response leg** in `reverseproxy.go`. Load-bearing wire that lets
`llm_limit_record` ship non-empty `group_ids` on `RecordLLMUsage`. See
[33-proxy-runtime.md](modules/33-proxy-runtime.md).
4. **Min-wins all-must-pass budget rule semantics**. Every matching
rule's remaining quota must be > 0 for the request to proceed; one
exhausted rule blocks the whole call. Documented in
[21-management-agentnetwork.md](modules/21-management-agentnetwork.md)
and the `llm_limit_check` middleware in
[31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md).
5. **body-tap memory bounds**: per-direction 1 MiB cap, shared 256 MiB
budget, `LimitReader(r.Body, limit+1)` for truncation detection with
`replayReadCloser` fallback so upstream still sees the full body.
`cloneInputFor` deep-copies the body up to 16 times per chain — a
perf hot-spot. See
[30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md).
6. **UpstreamRewrite.AuthHeader bypasses the header denylist**
deliberately. The runtime consumer only unpacks it via the
trusted upstream-build path. See
[30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md).
7. **`disable_access_log` default-false semantics**: the synth target
sets it true, all other targets leave it false. See
[10-shared-api.md](modules/10-shared-api.md).
8. **String-typed `decision` / `deny_code`** on
`CheckLLMPolicyLimitsResponse` — would benefit from enum pinning
before external consumers integrate. See
[10-shared-api.md](modules/10-shared-api.md).
## Explicit non-goals
- **Reaper / GC pass over stale synth services** — designed but cut from
scope.
- **URL-sync for tab state on AI Observability** — read path is wired
(`?tab=`) but write path isn't. Future work.
- **CI golden-file regen-and-diff for `types.gen.go` /
`proxy_service.pb.go`** — would catch codegen drift; not yet in place.
## Where to read the code
Per-module file scopes are listed in each module guide. Behaviour is
covered by Go tests co-located with each package (and an end-to-end
chain integration test under `proxy/internal/proxy`).

View File

@@ -0,0 +1,217 @@
# End-to-end flows
Three cross-module mermaid diagrams. Each per-module guide repeats the
slice that's relevant to its own scope — these are the canonical
top-down views.
- [Flow A — Config → runtime (synth + deliver)](#flow-a--config--runtime-synth--deliver)
- [Flow B — Request lifecycle through the LLM chain](#flow-b--request-lifecycle-through-the-llm-chain)
- [Flow C — Budget rule feedback loop](#flow-c--budget-rule-feedback-loop)
---
## Flow A — Config → runtime (synth + deliver)
How an operator's change to a Provider, Policy, Guardrail, Budget Rule,
or Settings record ends up as live middleware on a peer's proxy.
```mermaid
sequenceDiagram
autonumber
actor Op as Operator
participant UI as Dashboard
participant HTTP as management/handlers
participant Mgr as agentnetwork.Manager
participant Store as management/store (SQL)
participant Ctl as network_map.Controller
participant Synth as agentnetwork.SynthesizeServices
participant Grpc as management gRPC
participant Proxy as netbird-proxy
participant Xlate as middleware_translate
participant Chain as middleware.Chain
Op->>UI: edit provider/policy/budget/settings
UI->>HTTP: REST PUT/POST /api/agent-network/*
HTTP->>Mgr: SaveProvider / SavePolicy / SaveBudgetRule / SaveSettings
Mgr->>Store: persist (gorm)
Mgr-->>Ctl: account change event (Network-Map dirty)
loop per connected peer
Ctl->>Synth: SynthesizeServices(ctx, store, accountID)
Synth->>Store: load providers, policies, guardrails, budget rules, settings
Synth-->>Synth: build per-peer Service list
Note over Synth: each Service has a middleware<br/>chain with capture_prompt /<br/>capture_completion / redact_pii<br/>baked from account settings
Synth-->>Ctl: []rpservice.Service
Ctl->>Grpc: NetworkMap push (services + middleware configs)
end
Grpc-->>Proxy: NetworkMap stream
Proxy->>Xlate: translate proto MiddlewareConfig → runtime Spec
Xlate->>Chain: register / replace per-service chain
Note over Chain: chain replacement is live<br/>(no proxy restart, in-flight<br/>requests unaffected)
```
**Notes on the diagram**
- The `network_map.Controller` synthesises on every push, not on a
timer. A single config change costs O(connected peers × policies ×
providers) per push. See [`modules/22-management-handlers-wiring.md`](modules/22-management-handlers-wiring.md).
- `SynthesizeServices` is the single source of truth for the wire
format the proxy executes. Anything the proxy does that the
synthesiser didn't request is a bug. See
[`modules/21-management-agentnetwork.md`](modules/21-management-agentnetwork.md).
- The translate step (step 13) is the only place that knows the
middleware-ID strings on the proxy side. It must reject unknown IDs;
silently dropping middlewares would create a security gap (e.g.
missing `llm_limit_check` ⇒ unbounded spend). See
[`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md).
---
## Flow B — Request lifecycle through the LLM chain
What happens when an agent on the client peer sends a chat-completion /
messages request through the synthesised reverse-proxy.
```mermaid
sequenceDiagram
autonumber
actor Agent as Agent (local)
participant Px as netbird-proxy
participant Auth as auth middleware
participant Map as service-mapping
participant Req as llm_request_parser
participant Rt as llm_router
participant Chk as llm_limit_check
participant Inj as llm_identity_inject
participant Grd as llm_guardrail
participant Up as upstream LLM
participant Resp as llm_response_parser
participant Cost as cost_meter
participant Rec as llm_limit_record
participant Log as access-log
participant MgmtGrpc as management gRPC
Agent->>Px: POST /v1/chat/completions (OpenAI / Anthropic)
Px->>Auth: identify peer (user, groups)
Auth->>Map: resolve service from Host + path
Map-->>Req: dispatch chain in slot order
Req->>Req: parse body → provider, model, prompt, token estimate
Note over Req: capture_prompt gates raw_prompt<br/>capture (nil = legacy emit,<br/>false = drop, true = emit)
Req->>Rt: pass metadata
Rt->>Chk: route to upstream candidate
Chk->>MgmtGrpc: CheckLLMPolicyLimits(provider, model, est_tokens, groups, user)
MgmtGrpc-->>Chk: decision = allow / deny + deny_code
alt decision == deny
Chk-->>Log: emit access-log with deny_code<br/>(if EnableLogCollection)
Chk-->>Agent: 429 (or 403 per deny_code)
else decision == allow
Chk->>Inj: continue
Inj->>Inj: inject NetBird identity headers per provider config
Inj->>Grd: continue
Grd->>Grd: enforce model allowlist
Grd->>Up: forward (over WireGuard)
Up-->>Resp: response (JSON or SSE stream)
Resp->>Resp: parse usage tokens, completion
Note over Resp: capture_completion gates raw<br/>completion capture
Resp->>Cost: tokens
Cost->>Cost: lookup pricing.yaml + compute cost
Cost->>Rec: tokens + cost
Rec->>MgmtGrpc: RecordLLMUsage(provider, model, prompt_t, completion_t, cost, groups, user)
Rec-->>Log: emit access-log entry<br/>(if EnableLogCollection)
Log-->>Agent: 200 + body (streamed if SSE)
end
```
**Notes on the diagram**
- The chain runs in synth-defined order. Re-ordering middlewares
changes invariants — `llm_limit_check` must precede `llm_router` so
a denied request never hits upstream, and `llm_limit_record` must
pair with `llm_limit_check` so a successful check is always recorded
(or the rate-limit semantics break). See
[`modules/31-proxy-middleware-builtin.md`](modules/31-proxy-middleware-builtin.md).
- `llm_guardrail` is also where PII redaction happens
(`redact_pii = settings.RedactPii`). Phones, emails, credit cards,
PII names — see `redact.go` for the full set. See
[`modules/31-proxy-middleware-builtin.md`](modules/31-proxy-middleware-builtin.md).
- SSE streaming requires special handling on the response side; the
parser must handle partial chunks without buffering the whole
stream. See [`modules/32-proxy-llm-parsers.md`](modules/32-proxy-llm-parsers.md).
- Access-log emission is gated on `settings.EnableLogCollection`. With
it OFF, neither the deny nor the allow leg writes an entry — the
chain still runs (budget rules are still enforced) but no audit trail
is kept. See
[`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md).
---
## Flow C — Budget rule feedback loop
How an account's budget rules tighten ceilings on every request and how
consumption flows back into the dashboard.
```mermaid
flowchart LR
subgraph Operator
DashBud[Dashboard Budget Settings tab]
end
subgraph Mgmt[Management]
Save[POST/PUT /api/agent-network/budget-rules]
Store[(SQL store)]
Synth[SynthesizeServices]
Check[CheckLLMPolicyLimits RPC]
Rec[RecordLLMUsage RPC]
Cons[/api/agent-network/consumption]
end
subgraph Proxy[Proxy]
Chk[llm_limit_check]
RecMw[llm_limit_record]
end
subgraph DashView[Dashboard Budget Dashboard tab]
Panel[AgentConsumptionPanel]
end
DashBud -->|create / update rules| Save
Save --> Store
Store --> Synth
Synth -->|push synth-services to peer| Proxy
Chk -->|per request| Check
Check -->|aggregate matching rules<br/>min-wins all-must-pass| Store
Check -->|allow / deny| Chk
RecMw -->|post-response| Rec
Rec -->|tokens + cost + groups + user| Store
Store -->|read counters| Cons
Cons --> Panel
```
**Notes on the diagram**
- **min-wins all-must-pass** is the core semantic. A budget rule binds
to (group set, user set) with a (window, ceiling). At check time,
every rule that matches the caller is evaluated; if ANY rule has
zero remaining quota the request is denied. This is the most
surprising semantic for operators — see the invariants section of
[`modules/21-management-agentnetwork.md`](modules/21-management-agentnetwork.md).
- The proxy never makes its own budget decisions. It always asks
management via `CheckLLMPolicyLimits` and reports back via
`RecordLLMUsage`. This keeps account-wide accounting in one place
and avoids per-proxy drift.
- `RecordLLMUsage` must carry `group_ids` and `user_id` so the
decrement hits the right rule(s). The wire that carries those
fields onto the response leg is `respInput` in `reverseproxy.go`. See
[`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md).
- The dashboard's Budget Dashboard tab polls
`/api/agent-network/consumption` — not gRPC, not WebSocket. Poll
interval lives in `AgentConsumptionPanel.tsx`. See
[`modules/40-dashboard.md`](modules/40-dashboard.md).
---
## Cross-references
- Per-module guides: [`modules/`](modules/)
- Overview + module map: [`00-overview.md`](00-overview.md)

View File

@@ -0,0 +1,66 @@
# Agent Networks — architecture documentation
A self-contained set of documents describing the agent-networks feature:
an LLM-aware reverse-proxy middleware system plus account-level controls
(budget rules, log collection toggles, PII redaction). The management
server synthesises a per-peer middleware chain that the proxy executes on
every LLM request.
## What to read first
1. **[00-overview.md](00-overview.md)** — the single entry point. Feature
scope, the module map, and the cross-cutting topics worth keeping in
mind, with links to every per-module guide.
2. **[01-end-to-end-flows.md](01-end-to-end-flows.md)** — three
high-level mermaid diagrams: config-to-runtime synth/delivery,
per-request lifecycle through the LLM chain, and the budget-rule
feedback loop.
3. **Per-module guides** under `modules/` — one file per package. Each
describes the module boundary, the file-level layout, its own flow
diagrams, the public contracts, the invariants it relies on, and the
areas worth the closest attention.
## Directory layout
```
docs/agent-networks/
├── README.md # you are here
├── 00-overview.md # feature summary + module map
├── 01-end-to-end-flows.md # cross-module mermaid diagrams
└── modules/
├── 10-shared-api.md # proto + OpenAPI wire contracts
├── 20-management-store.md # SQL persistence layer
├── 21-management-agentnetwork.md # domain layer + synthesizer (largest)
├── 22-management-handlers-wiring.md # HTTP API + gRPC delivery
├── 30-proxy-middleware-framework.md # generic plugin system
├── 31-proxy-middleware-builtin.md # 8 LLM-aware middlewares
├── 32-proxy-llm-parsers.md # OpenAI/Anthropic/Bedrock SDKs + pricing
├── 33-proxy-runtime.md # translate + serve + access-log
├── 40-dashboard.md # UI for everything above (lives in the dashboard repo)
└── 50-path-routed-providers.md # Vertex AI + Bedrock (path-routed, keyfile:: creds, /bedrock prefix)
```
The `40-dashboard.md` module documents code that lives in the **dashboard
repo**, not in this repo. The guide is co-located here so backend readers
see the full picture in one place.
## How the per-module guides are structured
Every `modules/*.md` follows the same template so the docs are easy to
scan:
- **Module boundary** — what this package owns; where it sits in the stack.
- **Files** — path / role.
- **Architecture & flow** — one or more mermaid diagrams.
- **Public contracts** — function signatures, gRPC messages, JSON shapes.
- **Invariants** — semantic guarantees the module relies on or enforces.
- **Things to scrutinize** — split by correctness / security /
concurrency / backward-compat / performance / observability.
- **Test coverage** — the test files that lock down behaviour in this
module.
- **Known limitations / non-goals** — what is intentionally out of scope.
- **Cross-references** — upstream/downstream module links + the
end-to-end flow + the overview.
See [00-overview.md](00-overview.md) for the module map and the
cross-cutting topics.

View File

@@ -0,0 +1,105 @@
# shared/api — wire contracts (proto + OpenAPI)
> **Risk level:** Medium — wire-format surface that every other module pins against; backward-compat hinges on field-number discipline more than on logic correctness.
> **Backward-compat impact:** Additive only (new proto fields use unallocated numbers, new RPCs default to `Unimplemented`, new OpenAPI schemas/paths are append-only; no existing field/RPC/schema removed or renumbered).
## Module boundary
This module owns the cross-process contract surface between management, proxy, and dashboard. Two artefacts: `shared/management/proto/proxy_service.proto` (management↔proxy gRPC) and `shared/management/http/api/openapi.yml` (dashboard/CLI↔management REST). Both have generated companions checked in (`proxy_service.pb.go`, `proxy_service_grpc.pb.go`, `types.gen.go`) which must travel in lockstep with their sources. `shared/management/status/error.go` is in scope only for the four new typed `NotFound` constructors that the new HTTP handlers return.
Everything downstream — `management/agentnetwork`, `management/server/http/handlers/*`, `proxy/internal/*`, the dashboard SDK — consumes these types verbatim. The concern here is wire stability and codegen reproducibility, not behaviour: behaviour is covered in the management and proxy module guides.
`management.proto` and `signalexchange.proto` are unchanged. `status/error.go` only receives four additive constructors (lines 208-227); no existing error types are reshaped.
## Files
| Path | Role |
| ---- | ---- |
| `shared/management/proto/proxy_service.proto` | Source of truth: 2 new RPCs, 1 new message group (`MiddlewareConfig` + slot enum), additive fields on `PathTargetOptions`, `AccessLog`, `RecordLLMUsageRequest` |
| `shared/management/proto/proxy_service.pb.go` | Generated (protoc-gen-go) |
| `shared/management/proto/proxy_service_grpc.pb.go` | Generated; adds `CheckLLMPolicyLimits` + `RecordLLMUsage` client/server stubs and `UnimplementedProxyServiceServer` defaults |
| `shared/management/http/api/openapi.yml` | 15 new `AgentNetwork*` schemas, 9 new path groups under `/api/agent-network/*` |
| `shared/management/http/api/types.gen.go` | Generated (oapi-codegen; see codegen note below) |
| `shared/management/status/error.go` | Four `NotFound` constructors for the new resource kinds (lines 208-227) |
## Architecture & flow
```mermaid
sequenceDiagram
participant Dash as Dashboard / CLI
participant Mgmt as management (HTTP+gRPC)
participant Px as proxy
Note over Dash,Mgmt: REST (OpenAPI / types.gen.go)
Dash->>Mgmt: PUT /api/agent-network/providers (AgentNetworkProviderRequest)
Dash->>Mgmt: PUT /api/agent-network/settings (AgentNetworkSettingsRequest)
Dash->>Mgmt: GET /api/agent-network/consumption -> [AgentNetworkConsumption]
Note over Mgmt,Px: gRPC ProxyService (proxy_service.proto)
Mgmt-->>Px: SyncMappingsResponse{ ProxyMapping.path[*].options.middlewares,<br/>agent_network, disable_access_log, capture_* }
Px->>Mgmt: CheckLLMPolicyLimits(account, user, groups, provider, model)
Mgmt-->>Px: decision=allow|deny + selected_policy_id + attribution_group_id + window_seconds
Px->>Mgmt: RecordLLMUsage(account, user, group_id, group_ids, window_seconds, tokens, cost)
Px->>Mgmt: SendAccessLog(AccessLog{ agent_network=true })
```
The proto changes split into three independent slices: (1) **mapping enrichment**`PathTargetOptions` grows fields 8-13 so management can ship middleware configs, capture limits, and the agent-network / log-suppression flags down to the proxy without a second RPC; (2) **two new request/response RPCs** (`CheckLLMPolicyLimits`, `RecordLLMUsage`) for per-LLM-request budget arbitration; (3) **observability tag**`AccessLog.agent_network` so management can route logs to the right surface.
The OpenAPI side is a thin CRUD surface — every resource (`Provider`, `Policy`, `Guardrail`, `BudgetRule`, `Settings`) follows the same `GET-list / POST / GET / PUT / DELETE` pattern, plus a read-only `/consumption` listing and a catalog endpoint. The `*Request` variants drop server-controlled fields (id, timestamps). `AgentNetworkBudgetRule` deliberately reuses `AgentNetworkPolicyLimits` to keep wire-shape parity with policies.
## Public contracts added
- gRPC RPCs (`proxy_service.proto:52-57`): `CheckLLMPolicyLimits(CheckLLMPolicyLimitsRequest) → CheckLLMPolicyLimitsResponse`, `RecordLLMUsage(RecordLLMUsageRequest) → RecordLLMUsageResponse`. Both unary; default `UnimplementedProxyServiceServer` returns `codes.Unimplemented` (`proxy_service_grpc.pb.go:283-289`).
- New messages (`proxy_service.proto:145-175,448-502`): `MiddlewareConfig`, `MiddlewareSlot` enum, `CheckLLMPolicyLimitsRequest`/`Response`, `RecordLLMUsageRequest`/`Response`.
- New `PathTargetOptions` fields 8-13 (`proxy_service.proto:124-140`): `capture_max_request_bytes`, `capture_max_response_bytes`, `capture_content_types`, `middlewares`, `agent_network`, `disable_access_log`. All default-false / zero; pre-existing fields 1-7 byte-for-byte unchanged.
- `AccessLog.agent_network = 18` (`proxy_service.proto:258-261`).
- `RecordLLMUsageRequest.group_ids = 8` (`proxy_service.proto:496-498`) — so the record path can fan out to every applicable budget rule's window without a re-lookup.
- 15 new OpenAPI component schemas (`openapi.yml:5072-5829`): `AgentNetworkProvider[Request|Model]`, `AgentNetworkCatalog{Model,Provider,IdentityInjection,HeaderPairInjection,JSONMetadataInjection,ExtraHeader}`, `AgentNetworkPolicy[Request|TokenLimit|BudgetLimit|Limits]`, `AgentNetworkGuardrail[Checks|Request]`, `AgentNetworkConsumption`, `AgentNetworkSettings[Request]`, `AgentNetworkBudgetRule[Request]`.
- 9 new path groups (`openapi.yml:12797-13460`): `/api/agent-network/{consumption,settings,budget-rules,budget-rules/{ruleId},catalog/providers,providers,providers/{providerId},policies,policies/{policyId},guardrails,guardrails/{guardrailId}}`.
- Four typed NotFound errors (`shared/management/status/error.go:208-227`).
## Invariants
- **Field-number monotonicity.** Every new proto field uses a previously-unallocated number in its message: `PathTargetOptions` 8-13 (was 1-7), `AccessLog` 18 (was 1-17), `RecordLLMUsageRequest` 8. `SendStatusUpdateRequest.inbound_listener = 50` (pre-existing) reserves 50+ for observability extensions, so 8 on `RecordLLMUsageRequest` doesn't conflict.
- **Old proxies stay compatible.** Old management never sends `disable_access_log`/`middlewares`/`agent_network` (zero value → existing behaviour); old proxies that don't decode these fields just drop them silently (proto3 unknown-field semantics) — log emission stays on. No pre-existing field number changed: the proto change is insertions only.
- **Old management stays compatible.** The two new RPCs are registered on the same `management.ProxyService` descriptor; old proxies hitting them get `codes.Unimplemented` from the unimplemented embed (`proxy_service_grpc.pb.go:283-289`), which is the same fallback pattern `SyncMappings` already documents (`proxy_service.proto:20-21`).
- **OpenAPI shapes are append-only.** New schemas are placed at the end of `components.schemas` (line 5072+); new paths at the end of `paths` (line 12797+). No existing schema's `required` list, enum, or property type was changed.
- **`*Request` vs response asymmetry.** Read shapes (`AgentNetworkProvider`, `AgentNetworkPolicy`, `AgentNetworkGuardrail`, `AgentNetworkSettings`, `AgentNetworkBudgetRule`) require `created_at`/`updated_at`; the matching `*Request` shapes do not — server fills them. `AgentNetworkProviderRequest.api_key` is write-only (`openapi.yml:5158-5161` "never returned in responses"); reviewers should confirm the response schema (5072-5138) actually omits `api_key`.
## Things to scrutinize
### Correctness
- `RecordLLMUsageRequest` carries both `group_id` (singular, the attribution group — field 3) and `group_ids` (plural, full membership — field 8). `b22d5a181` adds field 8 to drive account-budget fan-out; double-check that consumers can't accidentally key counters on the wrong one. Field comments at `proxy_service.proto:489-491` and `496-498` distinguish them but it's the kind of subtle thing a follow-up commit might collapse.
- `PathTargetOptions.disable_access_log` is the only field whose default-false meaning **changes semantics** on the proxy side: false → log (status quo), true → suppress. Synthesizer sets `DisableAccessLog = !settings.EnableLogCollection`, so a missing/default settings row yields `EnableLogCollection=false → DisableAccessLog=true → suppressed`. Worth confirming downstream (`agentnetwork.synthesizer`) that operator-defined private services never inherit this flag — the proto field default protects them, but only if synth code is explicit.
- `CheckLLMPolicyLimitsResponse.decision` is a free-form `string` (`proxy_service.proto:471`) rather than an enum. Only documented values are "allow" / "deny". An enum would prevent typo drift; consider before this RPC ships to external consumers.
- `deny_code` (`proxy_service.proto:478-481`) is documented as "a stable label" but is also a free string. Pin the allowed set somewhere observable to the proxy.
### Security
- `AgentNetworkProvider.api_key` MUST be write-only. Schema split (request has it at line 5158; response omits it) looks correct, but a regression here leaks the upstream provider credential to every dashboard reader. Check that the handler explicitly zeros it on the response path.
- `extra_values` / `identity_header_*` headers on `AgentNetworkProvider` get stamped onto upstream requests. Description at `openapi.yml:5099` says "values not declared by the catalog are ignored at synth time" — a contract this module documents but the synthesizer must enforce. Confirm the synth module honours it.
- Cluster + subdomain on `AgentNetworkSettings` are documented immutable (`openapi.yml:5686-5694`) and the `AgentNetworkSettingsRequest` (lines 5733-5752) doesn't accept them. Verify the `PUT /api/agent-network/settings` handler can't be tricked by extra JSON keys (oapi-codegen's `additionalProperties: false` is not declared here; spec defaults to permissive).
### Backward compatibility
- The proto change is field-number additive: every previously numbered field keeps the same name + type, and the change is insertions only (no deletions in `proxy_service.proto`), so this holds at the source-text level.
- `proxy_service_grpc.pb.go` adds two RPC handlers and registers them in `ProxyService_ServiceDesc.Methods` (lines 543-552). The existing entries are unchanged and order-preserving — gRPC method dispatch is name-keyed, so order doesn't matter, but reviewing the diff (no method renamed/dropped) is still worth a glance.
- OpenAPI 3.0 doesn't have a built-in deprecation flow for paths; if any client tooling iterates `paths.*`, the additive routes shouldn't break it, but generated SDKs (especially the dashboard's) need a regen to gain access to `AgentNetwork*`.
### Codegen pinning
- `generate.sh` (`shared/management/http/api/generate.sh:14`) installs `oapi-codegen@latest` rather than a pinned version. **This is a reproducibility gap** — re-running the script later may produce a different `types.gen.go`. Either pin the version in `generate.sh` (e.g. `@v2.7.0`) or document the pin in a `tools.go`.
- proto codegen has the protoc / protoc-gen-go version stamped in the generated file header (`proxy_service.pb.go:3-4`).
- Regenerate locally and confirm zero diff against the committed `types.gen.go` / `proxy_service.pb.go`.
## Test coverage
| Test file | Locks down |
| --------- | ---------- |
| None in this scope | The proto and OpenAPI sources are tested transitively by the handler tests (`shared/management/http/handlers/agentnetwork/...`) and by the synthesizer/manager tests (`management/server/agentnetwork/...`). No round-trip serialisation test exists in the `proto/` or `api/` packages themselves. |
| `shared/management/proto/*_test.go` | (absent) |
| `shared/management/http/api/*_test.go` | (absent) |
Acceptable for codegen artefacts, but a single golden-file test that re-runs `oapi-codegen` and `protoc` in CI and diffs against the checked-in files would close the reproducibility gap noted above.
## Known limitations / explicit non-goals
- **No deprecation surface.** Old fields/RPCs are kept silently; there is no `[deprecated = true]` annotation on anything. Acceptable here because nothing is being removed.
- **No proto-side validation.** Numeric ranges (e.g. `window_seconds >= 60`, `cost_usd >= 0`, capture-byte clamps) are enforced in the OpenAPI schema via `minimum:` and inside Go code by the proxy/management, but `proto3` itself can't express them; downstream is expected to validate every message.
- **`MiddlewareConfig.config_json` is `bytes`** (`proxy_service.proto:163`) — opaque to the proto layer. Schema validity is the middleware factory's problem. This is a deliberate tradeoff (per the comment at 161-162) but worth flagging: a corrupted/malicious config_json can only fail at proxy apply time, not at the wire-decode step.
- **No catalog endpoint schema for the catalog itself** — the catalog data ships as a `GET /api/agent-network/catalog/providers` returning `[AgentNetworkCatalogProvider]` (`openapi.yml:13024`), but the catalog source-of-truth lives in `management/server/agentnetwork/catalog`, not here.
- The reaper / GC design was cut from scope; no reaper-related types appear here.
## Cross-references
- Downstream: [management/store](20-management-store.md), [management/agentnetwork](21-management-agentnetwork.md), [management/handlers + wiring](22-management-handlers-wiring.md), [proxy/runtime](33-proxy-runtime.md)
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
- Top-level: [../00-overview.md](../00-overview.md)

View File

@@ -0,0 +1,112 @@
# management/store — persistence for agent-network entities
> **Risk level:** Medium — six brand-new tables behind AutoMigrate, one upsert-counter table that runs on the request hot path, and one column carrying an encrypted secret.
> **Backward-compat impact:** Additive (six new tables created by AutoMigrate; the `Store` interface gains 23 methods, but no existing column/index is touched).
## Module boundary
This module is the persistence layer for the Agent Network feature. Everything the management server stores about LLM proxying — providers, policies, guardrails, the per-account settings row, a usage-counter table written on every proxied LLM request, and the account-budget rules — flows through the methods added to `store.Store`. The module owns six tables, six entity types from `management/server/agentnetwork/types`, and a single hot-path upsert (`IncrementAgentNetworkConsumption`) consumed by the proxy fleet.
Out of scope here: the catalog of provider definitions (compiled-in, no DB), the synthesizer/manager built on top of these CRUDs (covered in [21-management-agentnetwork.md](21-management-agentnetwork.md)), and the HTTP handlers that translate API requests into Save/Delete calls.
## Files
| Path | Role |
| ---- | ---- |
| `management/server/store/sql_store_agentnetwork.go` | gorm implementations of all 23 store methods |
| `management/server/store/sql_store_agentnetwork_budgetrule_test.go` | round-trip + account-scoping coverage against a real sqlite store |
| `management/server/store/sql_store.go` | one import, six entities appended to the `AutoMigrate` slice (sql_store.go:40, sql_store.go:141-142) |
| `management/server/store/store.go` | 23 methods added to the `Store` interface (store.go:328-354) |
| `management/server/store/store_mock_agentnetwork.go` | mockgen output for the new interface surface |
## Tables added / migrations
All six tables are created by `db.AutoMigrate` invoked from `NewSqlStore` at sql_store.go:133-143. There is no hand-rolled SQL migration script — the schema is whatever GORM derives from the struct tags.
- `agent_network_providers``Provider.TableName()` at provider.go:76. PK `id`, index on `account_id`, named index `idx_agent_network_provider` on `provider_id`. Carries an at-rest-encrypted `api_key` and ed25519 `session_private_key` (provider.go:35,56). `extra_values` and `models` are JSON blobs (`serializer:json`).
- `agent_network_policies``Policy.TableName()` at policy.go:70. PK `id`, index on `account_id`. JSON columns: `source_groups`, `destination_provider_ids`, `guardrail_ids`, `limits`.
- `agent_network_guardrails``Guardrail.TableName()` at guardrail.go:41. PK `id`, index on `account_id`. JSON `checks`.
- `agent_network_settings``Settings.TableName()` at settings.go:33. PK `account_id` (one row per account), named index `idx_agent_network_settings_cluster_subdomain` on `subdomain` only — the index name implies a composite, but only one column is tagged.
- `agent_network_consumption``Consumption.TableName()` at consumption.go:46. Composite PK across `(account_id, dim_kind, dim_id, window_seconds, window_start_utc)` — the same tuple the upsert keys on.
- `agent_network_budget_rules``AccountBudgetRule.TableName()` at budgetrule.go:35. PK `id`, index on `account_id`. JSON `target_groups`, `target_users`, `limits`.
## CRUD surface added
Provider, Policy, Guardrail, BudgetRule follow the same pattern: `Get<Kind>ByID`, `GetAccount<Kind>` (list), `Save<Kind>` (upsert), `Delete<Kind>`, with account-scoping enforced by the existing `accountAndIDQueryCondition` / `accountIDCondition` constants (sql_store.go:59-62). Provider additionally exposes `GetAllAgentNetworkProviders` (cross-account, used by the synthesizer). Settings exposes `Get`/`GetByCluster`/`Save` (no delete — one row per account, created on first save). Consumption exposes the upsert `Increment`, a point `Get`, and a cross-window `List`.
## Architecture & flow
```mermaid
flowchart LR
handlers["HTTP handlers<br/>(management/server/agentnetwork)"] -->|Save/Delete| iface["Store interface<br/>store.go:328-354"]
manager["agentnetwork.Manager"] -->|Get*| iface
synth["synthesizer<br/>(global)"] -->|GetAllAgentNetworkProviders| iface
proxy["proxy fleet<br/>(hot path)"] -->|IncrementAgentNetworkConsumption| iface
iface --> sql["SqlStore methods<br/>sql_store_agentnetwork.go"]
iface -.gomock.-> mock["MockStore<br/>store_mock_agentnetwork.go"]
sql --> gorm["gorm.DB"]
gorm --> tables[("6 tables<br/>agent_network_*")]
sql --> enc["crypt.FieldEncrypt<br/>(provider only)"]
```
Reads decrypt provider secrets in-place; writes do `provider.Copy().EncryptSensitiveData(...)` before `db.Save` so the caller's in-memory object keeps the plaintext `api_key` (sql_store_agentnetwork.go:88-102). Every list/get takes a `LockingStrength` and applies `clause.Locking{Strength: ...}` when non-`None` — matching the rest of the store. The upsert path uses `clause.OnConflict` with `gorm.Expr` server-side increments so concurrent proxy nodes converge without read-modify-write races (sql_store_agentnetwork.go:321-335).
## Invariants enforced at the store layer
- **Account scoping.** Every entity-by-ID method keys on `account_id = ? and id = ?`; no cross-tenant leak path through the API is reachable as long as callers always pass the auth'd `accountID` (sql_store_agentnetwork.go:70,141,201,429).
- **NotFound mapping.** `gorm.ErrRecordNotFound` is translated to typed `status.NewAgentNetwork*NotFoundError`; `Delete*` returns NotFound when `RowsAffected == 0` (sql_store_agentnetwork.go:111-113,171-173,231-233,461-463).
- **Provider secret encryption at rest.** `SaveAgentNetworkProvider` always encrypts before persist; `Get*` always decrypts after read. The plaintext `api_key` never reaches the DB through this layer (sql_store_agentnetwork.go:31,54,80,90).
- **Consumption monotonicity.** The upsert only ever issues `col = col + ?` for the three counter columns — no decrement path exists (sql_store_agentnetwork.go:330-332).
- **Window alignment is the caller's responsibility.** The store stamps `WindowStartUTC` as-passed; alignment to epoch happens in `types.WindowStart` at consumption.go:51-58.
- **Settings has no Delete.** Intentional — one row per account, created on first save; the row sticks around for the account lifetime.
## Things to scrutinize
### Correctness
- `SaveAgentNetworkProvider` saves the copy (sql_store_agentnetwork.go:95). The caller's in-memory pointer therefore keeps plaintext `api_key` and any `CreatedAt`/`UpdatedAt` gorm autofills land on the copy, not the original. Callers that need synced timestamps must re-fetch.
- `IncrementAgentNetworkConsumption`'s `Create` provides initial counter values (`TokensInput: tokensIn`, etc.) in the row, and on conflict the assignments add the same deltas to the existing values. The insert-vs-update arithmetic is consistent. Cross-check that no engine in use (sqlite, postgres, mysql) silently rejects the `OnConflict` clause — GORM emits engine-specific SQL but `ON DUPLICATE KEY UPDATE` (mysql) vs `ON CONFLICT (...)` (sqlite/postgres) need their unique constraint to match the composite PK on `agent_network_consumption`; it does, by construction.
- `IncrementAgentNetworkConsumption` writes `updated_at: time.Now().UTC()` literally inside the assignments map (sql_store_agentnetwork.go:333) — fine, but it's a Go-side timestamp captured at call time, not a DB-side `now()`. Acceptable for an audit field.
- `GetAgentNetworkConsumption` returns a zero-valued non-nil row on `ErrRecordNotFound` (sql_store_agentnetwork.go:364-371). Document or rename — a typed sentinel error would be more orthodox; callers must know not to error-check.
### Concurrency / transactions
- Hot-path `IncrementAgentNetworkConsumption` runs outside any explicit transaction; concurrency safety relies entirely on the DB serialising the `ON CONFLICT` upsert against the composite PK. This is correct for postgres and mysql; for sqlite it serialises behind the single writer.
- `SaveAgentNetworkSettings` is a blind upsert with no version/etag — concurrent writes from two operators last-write-wins on the collection-toggle flags (settings.go:23-25). Acceptable for admin-curated state but worth flagging.
- `Save*Provider` uses `db.Save` on a struct with a PK already set — GORM emits UPDATE or INSERT based on row existence. No upsert clause is attached, so a race between two creates with the same generated `xid` (vanishingly unlikely) would surface as a PK violation.
### Migration safety
- All six tables ride `AutoMigrate` (sql_store.go:141-142). AutoMigrate is additive: new columns get added, but it never drops columns nor narrows types. Three `bool` columns on `agent_network_settings` (`EnableLogCollection`, `EnablePromptCollection`, `RedactPii`) default to false at the GORM/DDL layer for existing rows; the test at sql_store_agentnetwork_budgetrule_test.go:83-112 locks that down on a fresh sqlite. Verify postgres/mysql produce the same default.
- The named index `idx_agent_network_settings_cluster_subdomain` on settings.go:15 is declared on only `subdomain`. Either the cluster column also needs `gorm:"index:idx_agent_network_settings_cluster_subdomain"` to make it composite, or the name is misleading.
- The named index `idx_agent_network_provider` on `Provider.ProviderID` (provider.go:30) is *not* unique and not scoped to account — two providers in the same account with the same `provider_id` are permitted at the DB layer; uniqueness, if any, must live above the store.
### Backward compatibility
- Net additive. No removed methods, no renamed columns, no schema change to existing tables. Existing deployments running a prior binary continue to work; the first boot of the new binary creates the six tables.
- The `Store` interface grows by 23 methods (store.go:330-354); any non-mock external implementer of `store.Store` will fail to compile. The repo only has `SqlStore` + `MockStore`, both updated.
### Performance (indexes, N+1)
- All by-account list queries hit the `idx_account_id` per-table index. No N+1: list methods return the full slice in one query.
- `GetAgentNetworkSettingsByCluster` (sql_store_agentnetwork.go:263-277) does a tablescan on `cluster` — no index. Tolerable for the bootstrap label generator (one-shot at provisioning) but worth noting if the call moves onto a hot path.
- `ListAgentNetworkConsumption` returns every row ever recorded for the account (sql_store_agentnetwork.go:382-400) — unbounded growth, no `LIMIT`, no time filter. With one row per (dim, window) per request burst, this table grows fastest of the six; a retention job + a paginated list method are obvious follow-ups.
## Test coverage
| Test file | Locks down |
| --------- | ---------- |
| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkBudgetRule_RealStore_RoundTrip` | full save → reload of `AccountBudgetRule` including the JSON-serialised `PolicyLimits`, target slices, double-delete returns NotFound (lines 18-59) |
| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkBudgetRule_RealStore_ScopedByAccount` | cross-account isolation for budget rules (lines 63-78) |
| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkSettings_RealStore_CollectionTogglesRoundTrip` | collection toggles default off, survive save/reload at the set values (lines 83-112) |
Gap: there is no store-level test for providers (encryption round-trip), policies, guardrails, or `IncrementAgentNetworkConsumption` (concurrent upsert, window-key uniqueness). The consumption upsert is the most performance-sensitive method in this module and the only one without a real-sqlite test.
## Known limitations / explicit non-goals
- No retention / GC for `agent_network_consumption`.
- No `Delete` for `Settings` (one row per account, cleared with the account).
- No DB-engine-specific tuning — the same struct tags drive sqlite, mysql, postgres.
- Provider `extra_values` and `models` are JSON blobs; querying inside them is not supported by design.
- `GetAgentNetworkConsumption` "not-found = zero row" contract is convenient but unconventional.
## Cross-references
- Upstream: [shared/api](10-shared-api.md), [management/agentnetwork](21-management-agentnetwork.md)
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
- Top-level: [../00-overview.md](../00-overview.md)

View File

@@ -0,0 +1,225 @@
# management/agentnetwork — domain layer + synth pipeline
> **Risk level:** High — central business logic + budget enforcement + the source of every middleware-chain change the proxy executes.
> **Backward-compat impact:** Additive within the agent-network surface; one **behavioural difference for opted-out accounts** in parser capture (the capture flag is stamped explicitly false instead of being absent — see capture-pointer semantics below). Non-agent-network proxy services are untouched (the synth chain only ships on `agent-net-svc-*` targets).
## Module boundary
`management/server/agentnetwork` owns every agent-network entity (providers, policies, guardrails, account budget rules, per-account settings, consumption rows) and **translates them into the in-memory `*rpservice.Service` that the reverse-proxy controller turns into `proto.ProxyMapping`s and pushes to clusters**. It is the *only* writer of the agent-network middleware chain.
Inside the package: `manager.go` is the CRUD + permissions-gated facade; `synthesizer.go` walks settings + providers + policies + guardrails and emits the per-account service plus every middleware's JSON config; `policyselect.go` runs per-request attribution (min-wins account ceiling, then "drain bigger pool first"); `reconcile.go` diffs successive synth outputs and emits precise Create/Update/Delete proxy-mapping updates plus a peer-map refresh. `labelgen/` mints DNS-safe subdomain labels; `catalog/` is the static provider catalogue; `types/` carries gorm entity structs. The `_realstack_test.go` files in the parent `management/server/` directory exercise the manager + network-map controller end-to-end with no mocks.
## Files
| Path | Role |
| ---- | ---- |
| `agentnetwork/manager.go` | Manager interface + CRUD + permission gates + bootstrap-settings + reconcile trigger |
| `agentnetwork/synthesizer.go` | Settings/policy → wire-format synthesis; sole writer of the proxy middleware chain |
| `agentnetwork/policyselect.go` | Per-request policy attribution + account-budget ceiling (min-wins) |
| `agentnetwork/reconcile.go` | Per-account synth diff vs in-memory cache → Create/Update/Delete |
| `agentnetwork/catalog/catalog.go` | Static provider catalogue (auth headers, identity-injection shapes) |
| `agentnetwork/labelgen/{labelgen,words}.go` | DNS-safe subdomain picker + curated wordlist |
| `agentnetwork/types/provider.go` | Provider entity + APIKey + Models + ExtraValues + SessionKeys |
| `agentnetwork/types/policy.go` | Policy entity + `PolicyLimits` (token + budget) |
| `agentnetwork/types/guardrail.go` | Guardrail entity (`ModelAllowlist`, `PromptCapture`) |
| `agentnetwork/types/budgetrule.go` | `AccountBudgetRule` (reuses `PolicyLimits`) |
| `agentnetwork/types/settings.go` | Per-account `Settings` (Cluster, Subdomain, 3 toggles) |
| `agentnetwork/types/consumption.go` | `Consumption` row + `WindowStart` aligner |
| `agentnetwork/{synthesizer,policyselect,reconcile,wire_shape}_*test.go` | See test coverage table |
| `agentnetwork/types/consumption_test.go` | `WindowStart` alignment proofs |
| `agentnetwork/labelgen/labelgen_test.go` | Deterministic picks + exhaustion + fallback |
| `management/server/agentnetwork_realstack_test.go` | No-mock provider CRUD → network-map fan-out |
| `management/server/agentnetwork_budgetrule_realstack_test.go` | No-mock budget-rule CRUD + settings preserve-immutable |
## Architecture & flow
### Synthesis (settings/policy → wire format)
```mermaid
flowchart TD
A[Mutation: provider/policy/guardrail/settings] --> B[managerImpl.reconcile accountID]
B --> C{proxyController nil?}
C -- yes --> D[accountManager.UpdateAccountPeers only]
C -- no --> E[SynthesizeServices]
E --> F[loadSettings — NotFound returns ok=false, no synth]
F --> G[filterEnabledProviders sorted by CreatedAt]
G --> H[filterEnabledPolicies]
H --> I[backfillProviderSessionKeys if missing]
I --> J[indexProviderGroups: providerID -> sorted source groups]
J --> K[buildRouterConfigJSON drops orphan providers]
J --> L[buildIdentityInjectConfigJSON per catalog entry]
H --> M[mergeGuardrails: union allowlist, OR redact]
M --> N[applyAccountCollectionControls account toggle = SOLE capture control]
N --> O[marshalGuardrailConfig]
K --> P[buildMiddlewareChain 8 middleware entries]
L --> P
O --> P
P --> Q[buildAccountService: AccessGroups=union source groups, noop.invalid target]
Q --> R[reconcile.diffMappings vs cache]
R --> S[SendServiceUpdateToCluster CREATE/MODIFY/REMOVE]
R --> T[accountManager.UpdateAccountPeers — fans synth ACLs into network map]
```
### Budget rule resolution (min-wins, group+user bound)
```mermaid
flowchart TD
A[SelectPolicyForRequest in] --> B[checkAccountBudget — runs FIRST, independent of policies]
B --> C[GetAccountAgentNetworkBudgetRules]
C --> D{for each enabled rule}
D --> E{budgetRuleApplies?}
E -- no --> D
E -- yes --> F[attrGroup = lowestIntersect TargetGroups, in.GroupIDs]
F --> G{Token cap enabled?}
G -- yes --> H[evalTokenCap user dim + group dim]
H --> I{exhausted?}
I -- yes --> J[DENY: llm_account.token_cap_exceeded - STOP]
I -- no --> K{Budget cap enabled?}
G -- no --> K
K -- yes --> L[evalBudgetCap user dim + group dim]
L --> M{exhausted?}
M -- yes --> N[DENY: llm_account.budget_cap_exceeded - STOP]
M -- no --> D
K -- no --> D
D --> O[All rules passed -> fall through to per-policy selection]
```
Key invariant: **rules are checked sequentially and ANY exhausted rule denies (all-must-pass / min-wins).** Untargeted rules (`len(TargetGroups)==0 && len(TargetUsers)==0`) apply to every caller (`policyselect.go:393`).
### Policy selection (per-peer, per-request)
```mermaid
flowchart TD
A[Account-budget gate passed] --> B[GetAccountAgentNetworkPolicies]
B --> C[filterApplicablePolicies enabled + provider match + group intersect]
C --> D{candidates empty?}
D -- yes --> E[Allow, empty SelectedPolicyID]
D -- no --> F[scoreCandidates -> scoreOne per policy]
F --> G[scoreOne: attrGroup + window]
G --> H{any cap exhausted?}
H -- yes --> I[Drop policy; record last deny code]
H -- no --> K[Keep as live candidate]
F --> L{live candidates exist?}
L -- no --> M[Deny with last exhaustion code]
L -- yes --> N[Sort: uncapped wins -> larger group token -> group budget -> user token -> user budget -> oldest CreatedAt]
N --> O[winner = scored 0]
O --> P[Allow + SelectedPolicyID + AttributionGroupID + WindowSeconds]
```
End-to-end: a mutation calls `managerImpl.reconcile(ctx, accountID)` (`manager.go:205,239,...`). Reconcile defers an `accountManager.UpdateAccountPeers` so the network-map controller re-runs and `injectAllProxyPolicies` picks up the new access groups; with a `proxyController` wired, it re-synthesizes the service, diffs against `reconcileCache[accountID]` (guarded by `reconcileMu`), and emits proto mappings to the cluster derived from the mapping's domain (`reconcile.go:120`). Synthesis is stateless and idempotent. Sole persistent side effect: `backfillProviderSessionKeys` (`synthesizer.go:249`) mints ed25519 keys on legacy provider rows and writes them back.
At request time the path is independent: the proxy calls `SelectPolicyForRequest` (`policyselect.go:56`); account-budget ceiling first, then per-policy scoring. Token + budget caps share `evalTokenCap` / `evalBudgetCap` — same primitive for account rules and policy limits, `label` differentiates the deny reason. After a served request, `RecordAccountBudgetUsage` (`policyselect.go:415`) fans deltas to every applicable rule's distinct `(dim_kind, dim_id, window)` tuple, deduplicating to prevent double-count when two rules share target+window.
## Public contracts
- **Manager interface** (`manager.go:48-80`): CRUD for `Providers/Policies/Guardrails/BudgetRules`; `GetSettings/UpdateSettings` (cluster + subdomain immutable, only the three toggles mutate); `ListConsumption/RecordConsumption(account, kind, dimID, windowSec, in, out, USD)`; `RecordAccountBudgetUsage(account, user, groups, in, out, USD)`; `SelectPolicyForRequest(ctx, PolicySelectionInput) → *PolicySelectionResult{Allow, SelectedPolicyID, AttributionGroupID, WindowSeconds, DenyCode, DenyReason}`.
- **`PolicySelectionInput`** (`manager.go:85-90`): `{AccountID, UserID, GroupIDs, ProviderID}` — populated by the proxy from CapturedData + `llm_router` resolution.
- **Synthesized middleware chain** (`synthesizer.go:576-657`), order load-bearing — response slot runs reverse-of-slice:
| Slot | Idx | ID | ConfigJSON shape | CanMutate |
| --- | --- | --- | --- | --- |
| on_request | 0 | `llm_request_parser` | `{"capture_prompt": <bool>, "redact_pii"?: true}` | |
| on_request | 1 | `llm_router` | `{"providers":[{id, models[], upstream_*, auth_header_*, allowed_group_ids[]}]}` | **true** |
| on_request | 2 | `llm_limit_check` | `{}` | |
| on_request | 3 | `llm_identity_inject` | `{"providers":[{provider_id, header_pair?, json_metadata?, extra_headers?}]}` | **true** |
| on_request | 4 | `llm_guardrail` | `{"model_allowlist"?, "prompt_capture":{enabled,redact_pii}}` | |
| on_response | 5 | `llm_limit_record` | `{}` (runs LAST at runtime) | |
| on_response | 6 | `cost_meter` | `{}` | |
| on_response | 7 | `llm_response_parser` | `{"capture_completion": <bool>, "redact_pii"?: true}` | |
- **Synthesized service shape** (`synthesizer.go:739`): `Mode=HTTP`, `Private=true`, `Domain=<subdomain>.<cluster>`, `AccessGroups=unionSourceGroups(enabledPolicies)`, one `TargetTypeCluster` target with `Host=noop.invalid:443` (router rewrites per request), `Options.{DirectUpstream,AgentNetwork}=true`, `DisableAccessLog=!settings.EnableLogCollection`, `CaptureMax{Req,Resp}Bytes=1<<20`, `CaptureContentTypes=["application/json","text/event-stream"]`.
## Invariants
- **Min-wins / all-must-pass for account budget rules** (`checkAccountBudget`, `policyselect.go:353`): every applicable enabled rule is checked; first exhausted cap denies. Untargeted rules bind every caller.
- **Account toggle is the SOLE control for capture enablement.** `applyAccountCollectionControls` (`synthesizer.go:701`) sets `merged.PromptCapture.Enabled = settings.EnablePromptCollection` *unconditionally*.
- **Capture-pointer semantics on parser configs** — see "Things to scrutinize" below.
- **`EnableLogCollection``DisableAccessLog` is the only access-log toggle** (`synthesizer.go:770`). Default off ⇒ access log suppressed.
- **`RedactPii` flows verbatim to BOTH parsers** (`synthesizer.go:584-585`) and is OR'd into the merged guardrail (`synthesizer.go:706`).
- **Cluster and Subdomain are immutable on Settings.** `UpdateSettings` reloads existing row and overlays only the three toggles (`manager.go:558-561`).
- **Orphan providers (no enabled policy authorises them) NEVER reach the router** (`synthesizer.go:351-357`); skipped from `identity_inject` for symmetry.
- **Provider creation refuses empty `api_key`** (`manager.go:175`); **deletion refuses while any policy still references it** (`manager.go:265-273`).
- **Session keypair stability across provider edits** (`manager.go:226-228`) — server-managed, copied through every `UpdateProvider`, never API-surfaced.
## Things to scrutinize
### Correctness
- **Capture-pointer semantics — `*bool` vs `bool`.** Three states, owned by separate sides:
- **Wire JSON this module emits:** `buildParserConfigJSON` (`synthesizer.go:678-693`) *always* stamps the capture field. Agent-network targets ship `"capture_prompt": false` or `"capture_prompt": true` — never absent. Same for `"capture_completion"`. The happy-path test pins `{"capture_prompt":false}` (`synthesizer_test.go:174`).
- **Proxy-side parser config (consumer):** parsers decode into `*bool`. Matrix:
- `nil` (field absent) → **legacy default = emit**. Preserved for non-agent-network callers and pre-existing tests (the backward-compat hook).
- `false` (field present, value false) → **suppress emission entirely**. The behaviour for opted-out agent-network accounts. Without this, `enable_log_collection=true` + `enable_prompt_collection=false` would leak raw user input AND raw model output to the access log.
- `true` → emit normally.
- **Why the synth always stamps a value:** an agent-network mapping omitting the field would hit legacy "always emit" and re-introduce the leak. The `json.Marshal` error fallback at `synthesizer.go:687` degrades to `{}` — comment-claimed unreachable, but if ever fired re-introduces the leak. Consider fail-closed (return literal `{"capture_prompt":false}`) instead.
- **`scoreCandidates` non-cumulative deny code.** Only the *last* exhausted policy's deny code survives (`policyselect.go:188-190`). Iteration order is store's natural order. Auth signal is `len(scored)==0`, so this is informational only — verify no UI depends on "first exhausted policy" semantics.
- **`effectiveWindowSeconds` token-wins tiebreak.** When both halves are enabled with different windows, token's window wins (`policyselect.go:482`). Verify `RecordLLMUsage` increments against the winning window only.
- **`RecordAccountBudgetUsage` dedup.** Two rules with the same `(kind, dim_id, window)` would double-count without the `tuples` map (`policyselect.go:434-449`). Key includes all three dimensions — correct.
- **Fail-closed on bad provider:** unknown catalog id (`synthesizer.go:794-796`) or empty API key (`synthesizer.go:801-803`) drops the **entire** account's synth, not just the bad provider. Confirm matches operator UX.
### Security
- **Redact OR-merge:** merged `RedactPii` = account OR guardrail (`synthesizer.go:706`). **Parser-side flag is `settings.RedactPii` only, NOT the OR** — a guardrail-only opt-in does not propagate to parsers. Correct because the account toggle gates capture, but worth noting on the proxy side.
- **Group resolution must not leak across accounts.** Every store call carries `accountID` (`policyselect.go:73, 286, 298, 322, 334, 354`); `lowestIntersect` uses caller's claimed groups only (`policyselect.go:494`). Risk surface is upstream (handler populates `in.GroupIDs`).
- **`UpdateSettings` preserves immutable Cluster + Subdomain** (`manager.go:558`). A client can't rebind the cluster.
- **Provider session keypair backfill writes through `SaveAgentNetworkProvider`** (`synthesizer.go:256`) from a read-shaped call. Idempotent → worst case is a wasted write under concurrent reconcile + snapshot.
### Concurrency
- **`reconcileMu`** guards `reconcileCache`. Lock window is narrow — compute diff inside, send outside (`reconcile.go:56-68`).
- **`labelRngMu`** guards `labelRng` because `math/rand.Source` is unsafe for concurrent use (`manager.go:638-640`).
- **Real-store tests** use `store.NewTestStoreFromSQL` with `t.TempDir()` per test — no shared state, no `t.Parallel()`.
- **`RecordAccountBudgetUsage` dedup `tuples` map is per-call;** concurrent calls fan out fully — correct (each request's tokens book once per applicable rule).
- **Deferred `UpdateAccountPeers` runs inline after the proxy push** (`reconcile.go:28-35`); a slow call stretches CRUD response time.
### Backward compatibility
- **Capture-pointer semantics (restated):** non-agent-network callers see no field → legacy nil-default emit, identical to pre-PR. Agent-network targets always carry an explicit `capture_*` value.
- **`TestSynthesizeServices_HappyPath` was updated:** request-parser config moved from `{}` to `{"capture_prompt":false}` (`synthesizer_test.go:174`). External snapshot tests against synth output need updating.
- **`MergedGuardrails` retains zeroed `TokenLimits`/`Budget`/`Retention`** even though `Policy.Limits` carries the real values now; `llm_limit_check` is the authoritative enforcement. Comment at `synthesizer.go:940-948` calls this out.
### Performance
- **`SynthesizeServices` runs on every controller tick / mutation reconcile.** Cost: 4 store reads + optional per-provider keypair backfill. Sort + index + merge are O(N log N) / O(P × G); dominant cost is JSON marshalling. No nested loops escape these dimensions.
- **`reconcile.diffMappings` is O(N + M)** with N=M=1 per account today — effectively constant.
- **`SynthesizeServicesForCluster`** (`synthesizer.go:71`) walks every account on a cluster; per-account failures are **swallowed** (`synthesizer.go:91-93`) so a single misconfigured account doesn't drop the cluster. Runs per proxy reconnect.
### Observability
- **Activity codes:** `AgentNetwork{Provider,Policy,Guardrail,BudgetRule}{Created,Updated,Deleted}`; `AgentNetworkSettingsUpdated` with `log_collection/prompt_collection/redact_pii` payload (`manager.go:567-571`). **No activity code for `SelectPolicyForRequest` denies** — surfaced via proxy access log only (likely intentional given volume).
- **Deny codes** namespaced: `llm_policy.{token,budget}_cap_exceeded`, `llm_account.{token,budget}_cap_exceeded` (`policyselect.go:18-26`).
- **Reconcile failures are logged at warn and swallowed** (`reconcile.go:42-44`). Persistent synth failures (e.g. unknown catalog id) silently keep the proxy out of sync — consider a manager-level synth-health surface if this becomes a support burden.
## Test coverage
| Test file | Locks down |
| --------- | ---------- |
| `synthesizer_test.go` | Mock-store: `HappyPath` (8-mw chain ordering, `{"capture_prompt":false}` baseline); `No{Settings,Providers}`; `Disabled{Provider,Policy}_NoService`; `RouterConfigOrdering`; `PolicyCheckConfig_UnionsSourceGroups`; `OrphanProvider_HasEmptyAllowedGroups`; identity-inject for LiteLLM / Bifrost (overrides + partial disable) / Cloudflare / Portkey / Vercel / OpenRouter / generic non-customizable; `GuardrailMerge_AllowlistUnion_LimitsRestrictive`; `BackfillsMissingSessionKeys`; `HTTPUpstream_KeepsExplicitPort`; `UpstreamURLPath_FlowsToRouter`; `UnknownProviderID_FailsClosed`; `EmptyAPIKey_FailsClosed`. |
| `synthesizer_realstore_test.go` | Real-sqlite: `SurvivesStatusToggle` reproduces the disable/re-enable 403 regression; `Reconcile_RealStore_PushesPrivateAfterStatusToggle` extends through reconcile push. |
| `synthesizer_guardrail_realstore_test.go` | `PromptCaptureAccountIsSoleControl`; `PromptCaptureFlowsWhenAccountOptsIn`; `AccountRedactWithoutGuardrailRedact`; `NoGuardrail_CaptureOff`. |
| `synthesizer_log_collection_realstore_test.go` | `LogCollection{Off_SuppressesAccessLog,On_PermitsAccessLog}` — verifies `DisableAccessLog` propagation through `ToProtoMapping`. |
| `synthesizer_parser_redact_realstore_test.go` | **Capture-pointer regression suite:** `ParserConfigsCarryRedactPii`; `ParserConfigsSuppressCaptureWhenLogCollectionOnly` (log=on/prompt=off ⇒ both capture flags false); `ParserConfigsOmitRedactPiiWhenOff`. |
| `policyselect_test.go` | Mock-store: `NoApplicablePolicies`; `AllowWithLowestGroupAttribution`; `LargerPoolWinsAcrossUsageLevels`; `StaysOnLargerPoolAfterPartialDrain`; `FallsThroughToSmallerPoolWhenLargerExhausted`; `TiebreakBy{LargerGroupPool,CreatedAt}`; `DeniesWhenAllExhausted`; `UncappedPolicyAlwaysWinsAgainstCapped`; `DisabledPolicyIgnored`; `StoreErrorPropagates`; `RejectsEmptyAccount`; `SharesGroupCounterAcrossPolicies`; `AntiFallThroughOnLowestGroup`; `BudgetOnlyExhaustionDenies`; `BudgetTighterThanTokenWins`. |
| `policyselect_realstore_test.go` | Real-sqlite regression guard: `NoApplicablePolicies`; `AllowAndLowestGroupAttribution`; `LargerPoolWins_FallsThroughWhenExhausted`; `BudgetCapDenies`; `GroupCounterSharedAcrossPolicies`; `DisabledPolicyIgnored`. |
| `policyselect_account_realstore_test.go` | Account budget rules: `AccountCeilingBindsEvenWithUncappedPolicy` (min-wins); `AccountGroupCeiling`; `AccountTargetUsersBindsOnlyThatUser`; `AccountRuleRecordsToOwnWindow`. |
| `reconcile_test.go` | `FirstSynth_EmitsCreate`; `NoChange_EmitsNothingExtra` (re-push as Modified — verify desired); `PolicyRemoved_EmitsDelete`; `NilProxyController_NoOp`; `EmptyAccountID_NoOp`; `ClusterFromMapping`. |
| `wire_shape_test.go` | `TestSynthesizedService_WireShape` — proto-shape lockdown via `ToProtoMapping`. Catches "service not matching" (mapping reaches proxy but no SNI/HTTP route). Asserts ID, Domain, Mode, AuthToken, `Private`, `Auth.Oidc=false`, one path `/` + `https://noop.invalid/`, 8 middlewares with correct slot enums, router config `auth_header_value="Bearer sk-test-key"`. |
| `labelgen/labelgen_test.go` | `PickUnique_{DeterministicWithSeededRng,AvoidsTakenWordsWhenMostAreReserved,FallsBackWhenAllReserved}`; `UniqueWords_DropsDuplicates`. |
| `types/consumption_test.go` | `WindowStart_{AlignedToUnixEpoch,WithinWindowConverges,AcrossWindowsDiverges,DifferentWindowsHaveDifferentBuckets,SubMinuteAndMinuteAlignment,ZeroWindowReturnsInputUTC}`. Bucket alignment so multi-node reads converge. |
| `agentnetwork_realstack_test.go` | `ProviderCRUD_FansOutToProxyAndClientPeers` — no-mock end-to-end through real account manager + network-map + agentnetwork: provider create propagates the updated map to both proxy peer and client peer with the synth DNS surface. |
| `agentnetwork_budgetrule_realstack_test.go` | `BudgetRuleCRUD_RealManager`; `UpdateSettings_PreservesImmutableAndTogglesCollection`. |
## Known limitations / explicit non-goals
- **`MergedGuardrails.TokenLimits/Budget/Retention` emit at zero** (`synthesizer.go:940-948`); real enforcement is `Policy.Limits` via `llm_limit_check`. Future cleanup implied.
- **Session keys picked from first enabled provider by created_at** (`pickServiceSessionKeys`, `synthesizer.go:270`). Existing session cookies survive provider edits only while the first-by-CreatedAt provider stays in place. Document for operators.
- **Reconcile failures silently swallowed** (`reconcile.go:42-44`). Persistent failures keep the proxy out of sync until the next reconcile.
- **`scoreCandidates` exposes only the LAST exhaustion's deny code** when multiple policies are exhausted.
- **`bootstrapSettingsIfNeeded` failure is non-fatal to provider create** (`manager.go:200`): provider lands, synth is no-op until the next provider create retries the bootstrap.
- **Budget rules do not trigger a reconcile** (`manager.go:476-477`). Request-time evaluation only; new rules take effect on the next request without a proxy push.
## Cross-references
- **Upstream:** [shared/api](10-shared-api.md), [management/store](20-management-store.md), reverseproxy `service`/`proxy`/`sessionkey` packages, `management/server/permissions` + `activity`.
- **Downstream:** [management/handlers (HTTP wiring)](22-management-handlers-wiring.md), [proxy/middleware-builtin](31-proxy-middleware-builtin.md), network-map controller (`injectAllProxyPolicies` fan-out).
- **End-to-end flow:** [../01-end-to-end-flows.md](../01-end-to-end-flows.md) — "Provider create → reconcile → proxy push → peer map refresh" and "request → policy select → record" diagrams.
- **Top-level:** [../00-overview.md](../00-overview.md)

View File

@@ -0,0 +1,203 @@
# management/handlers + wiring — HTTP API + gRPC delivery
> **Risk level:** Medium — the surface is mostly additive, but two changes are load-bearing: `injectAllProxyPolicies` runs on every per-peer compute, and `shallowCloneMapping` must round-trip `Private` (a missed field silently breaks every MODIFIED).
> **Backward-compat impact:** Additive on the wire (new routes, new RPCs, new proto fields, new gorm column on `AccessLogEntry`). One management-internal break: `nbhttp.NewAPIHandler` gains a trailing `agentNetworkManager` parameter; `nil` is tolerated and silently skips route registration.
## Module boundary
This module is the seam between the public Agent Network HTTP API and the proxy fleet that serves agent traffic. North side: a `/api/agent-network/*` surface (providers, policies, guardrails, budget rules, settings, consumption) on the existing gorilla router, delegating to `agentnetwork.Manager`. Handlers are thin — they translate `api.*``types.*`, validate shape, forward. RBAC and event emission stay inside the manager (`manager.go:680-682`).
South side: `ProxyServiceServer` (`proxy.go`) learns to (a) ship synth services to a proxy on initial snapshot, (b) resolve agent-network domains in `getServiceByDomain` for OIDC/session/tunnel-peer flows, (c) gate LLM requests via `CheckLLMPolicyLimits` + `RecordLLMUsage`, (d) preserve `Private` through `shallowCloneMapping` so per-proxy live updates don't silently flip services public. The network_map controller prepends synth services to `account.Services` on every per-peer compute; `accesslogentry.go` gains an indexed `AgentNetwork` column so the dashboard can filter cheaply.
## Files
| Path | Role |
| ---- | ---- |
| `handlers/agentnetwork/providers_handler.go` | Catalog + provider CRUD + central `AddEndpoints` |
| `handlers/agentnetwork/policies_handler.go` | Policy CRUD + shared `validatePolicy*` |
| `handlers/agentnetwork/guardrails_handler.go` | Guardrail CRUD |
| `handlers/agentnetwork/budget_handler.go` | Account-level budget rule CRUD |
| `handlers/agentnetwork/settings_handler.go` | GET (200+`null` if unbootstrapped) + PUT toggles |
| `handlers/agentnetwork/consumption_handler.go` | Read-only consumption rows |
| `handlers/agentnetwork/handlers_test.go` | Real-store fixture; wire round-trip + validation |
| `handlers/agentnetwork/budget_handler_test.go` | Budget-rule + settings toggles |
| `server/http/handler.go` | New `agentNetworkManager` arg; conditional `AddEndpoints` |
| `server/permissions/modules/module.go` | New `AgentNetwork` module key |
| `internals/server/boot.go` | Wires synthesiser adapter + limits service into proxy server |
| `internals/server/modules.go` | `AgentNetworkManager()` lazy-create node |
| `internals/controllers/network_map/controller/controller.go` | `injectAllProxyPolicies` replaces 4 `InjectProxyPolicies` calls |
| `internals/controllers/network_map/controller/repository.go` | `SynthesizeAgentNetworkServices` repo method |
| `internals/modules/reverseproxy/service/service.go` | `MiddlewareConfig`, capture limits, `AgentNetwork`, `DisableAccessLog` + proto |
| `internals/modules/reverseproxy/accesslogs/accesslogentry.go` | Indexed `AgentNetwork bool` from proto |
| `internals/shared/grpc/proxy.go` | Synth wiring, 2 RPCs, domain fallback, `Private` in clone |
| `internals/shared/grpc/proxy_clone_test.go` | Locks every `ProxyMapping` field minus `AuthToken` |
| `server/activity/codes.go` | 13 new activity codes (125-137) |
## HTTP routes added
All routes inherit the platform's auth middleware. Perms enforced inside `agentnetwork.Manager.requirePermission` (`manager.go:680-682`) on `modules.AgentNetwork`. Permission column shows the `op` passed to `requirePermission` — read = `Read`, etc.
| Method | Path | Perm | Handler |
| ------ | ---- | ---- | ------- |
| GET | `/agent-network/catalog/providers` | authn only | `providers_handler.go:43` |
| GET | `/agent-network/providers` | read | `providers_handler.go:57` |
| POST | `/agent-network/providers` | create | `providers_handler.go:97` |
| GET | `/agent-network/providers/{providerId}` | read | `providers_handler.go:77` |
| PUT | `/agent-network/providers/{providerId}` | update | `providers_handler.go:132` |
| DELETE | `/agent-network/providers/{providerId}` | delete | `providers_handler.go:172` |
| GET | `/agent-network/policies` | read | `policies_handler.go:32` |
| POST | `/agent-network/policies` | create | `policies_handler.go:72` |
| GET | `/agent-network/policies/{policyId}` | read | `policies_handler.go:52` |
| PUT | `/agent-network/policies/{policyId}` | update | `policies_handler.go:102` |
| DELETE | `/agent-network/policies/{policyId}` | delete | `policies_handler.go:142` |
| GET | `/agent-network/guardrails` | read | `guardrails_handler.go:25` |
| POST | `/agent-network/guardrails` | create | `guardrails_handler.go:65` |
| GET | `/agent-network/guardrails/{guardrailId}` | read | `guardrails_handler.go:45` |
| PUT | `/agent-network/guardrails/{guardrailId}` | update | `guardrails_handler.go:95` |
| DELETE | `/agent-network/guardrails/{guardrailId}` | delete | `guardrails_handler.go:135` |
| GET | `/agent-network/budget-rules` | read | `budget_handler.go:24` |
| POST | `/agent-network/budget-rules` | create | `budget_handler.go:64` |
| GET | `/agent-network/budget-rules/{ruleId}` | read | `budget_handler.go:44` |
| PUT | `/agent-network/budget-rules/{ruleId}` | update | `budget_handler.go:95` |
| DELETE | `/agent-network/budget-rules/{ruleId}` | delete | `budget_handler.go:135` |
| GET | `/agent-network/settings` | read | `settings_handler.go:53` (200+`null` if no row) |
| PUT | `/agent-network/settings` | update | `settings_handler.go:27` |
| GET | `/agent-network/consumption` | read | `consumption_handler.go:21` |
## gRPC RPCs added (or modified)
| RPC | Direction | Trigger |
| --- | --------- | ------- |
| `CheckLLMPolicyLimits` | proxy→mgmt unary | Pre-flight gate; returns allow/deny, selected policy, attribution group, window, deny code+reason (`proxy.go:259-301`). `Unimplemented` when limits service is nil. |
| `RecordLLMUsage` | proxy→mgmt unary | Post-flight write of tokens+cost against policy-window dimensions + every applicable account budget rule (`proxy.go:303-349`). `window_seconds==0` ⇒ no policy cap, only account fan-out runs. |
| `GetMappingUpdate`/`SendServiceUpdate` (stream) | mgmt→proxy | Snapshot (`proxy.go:752-780`) now appends `SynthesizeServicesForCluster`. Live updates use `SendServiceUpdateToCluster` + `shallowCloneMapping`. |
## Architecture & flow
### HTTP request lifecycle
```mermaid
sequenceDiagram
participant DB as Dashboard
participant R as gorilla.Router (/api)
participant H as handler (agentnetwork)
participant M as agentnetwork.Manager
participant S as store.Store
participant AM as accountManager (StoreEvent)
DB->>R: POST /api/agent-network/providers
R->>H: createProvider (auth mw sets UserAuth)
H->>H: GetUserAuthFromContext + validate(req)
H->>M: CreateProvider(userID, provider, bootstrapCluster)
M->>M: requirePermission(AgentNetwork, Create)
M->>S: SaveAgentNetworkProvider
M->>AM: StoreEvent(AgentNetworkProviderCreated)
M-->>H: created provider
H-->>DB: 200 + api.AgentNetworkProvider JSON
```
### Synth-service delivery via gRPC
```mermaid
sequenceDiagram
participant P as Proxy
participant G as ProxyServiceServer
participant SM as service.Manager (persisted)
participant SA as synthesizerAdapter
participant AN as SynthesizeServicesForCluster
participant ST as store.Store
Note over P,G: Initial snapshot
P->>G: GetMappingUpdate (stream open)
G->>SM: GetServicesForCluster(conn.address)
SM-->>G: persisted []*Service
G->>SA: SynthesizeServicesForCluster(conn.address)
SA->>AN: SynthesizeServicesForCluster(store, clusterAddr)
AN->>ST: walk every account; read providers/policies/settings
AN-->>SA: in-memory []*Service
SA-->>G: []*Service
G->>P: response (persisted + synth)
Note over G,P: Per-request live update
G->>G: SendServiceUpdateToCluster(update, clusterAddr)
G->>G: shallowCloneMapping(update) %% Private MUST survive
G->>P: response with single mapping
```
End-to-end: HTTP write persists rows and emits an activity event; the manager then triggers `proxyController.SendServiceUpdate` so proxies re-render. **The snapshot path is the only one that calls into the synthesiser** — on stream open it pulls persisted services then appends synth services for the cluster. Synth services are never persisted. For OIDC/session/tunnel-peer flows, `getServiceByDomain` falls back to `SynthesizeServicesForCluster(clusterFromDomain(domain))` when persisted lookup misses (`proxy.go:1763-1793`). The network_map contribution is orthogonal: per-peer compute prepends the same synth services to `account.Services` before `InjectProxyPolicies`.
## Permissions model added
- `permissions/modules/module.go:22` adds `AgentNetwork Module = "agent_network"`, registered in `All` (`module.go:42`). Standard `operations.{Read,Create,Update,Delete}` matrix.
- Handlers don't call `permissionsManager` directly — they extract `UserAuth` and delegate to `agentnetwork.Manager`, which gates every mutation through `requirePermission` (`manager.go:168, 308, 549`, etc.). Confirm your role-set provider has `agent_network` rows for owner/admin/user/billing-admin before merging.
- `getCatalogProviders` (`providers_handler.go:43`) intentionally skips RBAC — catalog is global static data.
## Activity codes added
`activity/codes.go:244-274` adds Activities 125-137 + string/code mappings (`codes.go:428-444`), following `<domain>.<resource>.<action>` (e.g., `agent_network.provider.create`). Audit-log exporters / SIEM forwarders need to know the new codes.
## Invariants
- **Synth services are never persisted.** Snapshot appends after `serviceManager.GetServicesForCluster` (`proxy.go:761-770`); network_map prepends before `InjectProxyPolicies` (`controller.go:117-126`).
- **`shallowCloneMapping` must round-trip every `ProxyMapping` field except `AuthToken`** — `proxy_clone_test.go:50-58` enforces via `gproto.Equal`. The bug it guards: a missing `Private` made every MODIFIED arrive `private=false`, the proxy skipped `ValidateTunnelPeer`, `UserGroups` stayed empty, `llm_router` denied `no_authorised_provider`; a restart "fixed" it because the snapshot uses the original mapping.
- **Limit-window floor is 60s** (`policies_handler.go:189-220`); enabled cap with both per-group and per-user at zero is rejected. Budget rules reuse the same validator (`budget_handler.go:170`).
- **Manager is optional at boot.** `NewAPIHandler` registers routes only when non-nil (`handler.go:129`); `ProxyServiceServer` returns `Unimplemented` from both RPCs when limits service is unwired (`proxy.go:262-265, 306-309`).
- **Settings GET on an unbootstrapped account returns 200 + `null`** (`settings_handler.go:65-72`) — not 404.
## Things to scrutinize
### Correctness
- **`injectAllProxyPolicies` runs on every per-peer compute**: `controller.go:163, 309, 415, 681`. `sendUpdateAccountPeers` is the target of the buffered fan-out — synth runs once per debounced account-update tick **and** once per direct `UpdateAccountPeer`. Cost is O(providers + policies × users-per-group) per account under `LockingStrengthNone`. No per-account synth cache — verify it fits the buffer interval for your largest tenant.
- **`clusterFromDomain` strips at the first `.`** (`proxy.go:1784-1792`). A zero-dot domain returns `""` and the synth call walks every account. Confirm no path reaches this with a malformed/internal domain.
- **Account-budget `RecordConsumption` fans out even when `window_seconds == 0`** (`proxy.go:341-348`) — intentional. Verify the proxy never sends `RecordLLMUsage` for a request that wasn't actually allowed.
### Security
- Every handler extracts `UserAuth` via `nbcontext.GetUserAuthFromContext` before any work. Routes live behind the standard `/api` mux; bypass list is not extended.
- `CheckLLMPolicyLimits` / `RecordLLMUsage` ride the existing **proxy → mgmt** gRPC connection auth. No additional token check inside the RPCs — they trust the connection. Confirm the proxy-side token-verification interceptor in this package gates both.
- `RecordLLMUsage` only validates `account_id != ""` (`proxy.go:317-319`). A compromised proxy can attribute cost to any account in its cluster — was already true for prior RPCs but is louder now that data drives denials.
### Concurrency
- `SetAgentNetworkSynthesizer` / `SetAgentNetworkLimitsService` write under `s.mu.Lock`; read paths copy the interface under read lock (`proxy.go:236-247, 260-263, 304-307`). Same pattern as existing `serviceManager`/`proxyController` setters.
- Manager writes use `LockingStrengthUpdate`; synth reads use `LockingStrengthNone` — read-after-write via the proxy snapshot can observe a stale view by up to one fan-out tick.
- Network_map controller is single-threaded per account; cross-account is parallel.
### Backward compatibility
- `proxy_clone_test.go` is the regression net; any new `ProxyMapping` field must be cloned or explicitly nulled in the test.
- `AccessLogEntry` adds indexed `AgentNetwork bool` — implicit AutoMigrate; deploy story must handle table-rewrite cost on high-volume access-log tables.
- `TargetOptions` gains seven `omitempty` JSON fields (`service.go:69-94`); on-wire shape stays compatible. `targetOptionsToProto` tests all fields when deciding nil (`service.go:551-556`).
- `NewAPIHandler` signature changes — every caller must pass `agentNetworkManager`; `nil` is supported.
### Observability
- 13 new activity codes via `accountManager.StoreEvent` in the manager — confirm dashboard's audit-log UI maps them.
- `AccessLogEntry.AgentNetwork` is indexed for the dashboard's agent-network log filter.
- New RPCs log at error level on store/selector failures (`proxy.go:284, 327, 332, 348`). Snapshot synth failures degrade to warnings — stream is not aborted (`proxy.go:765`).
## Test coverage
| Test | Locks down |
| ---- | ---------- |
| `handlers_test.go::TestPolicyHandler_WindowSecondsRoundTrip` | GET carries `window_seconds`; legacy `window_hours`/`window_days` absent. |
| `handlers_test.go::TestPolicyHandler_RejectsSubMinuteWindow` | POST `<60s` returns 4xx. |
| `handlers_test.go::TestConsumptionHandler_EmptyAccountReturnsArray` | `/consumption` returns `[]` — never null. |
| `handlers_test.go::TestConsumptionHandler_PopulatedAccountListsRows` | RecordConsumption×2 surfaces both with correct tokens/cost/window. |
| `budget_handler_test.go::TestBudgetRuleHandler_RoundTrip` | Targets + PolicyLimits shape round-trip. |
| `budget_handler_test.go::TestBudgetRuleHandler_ListReturnsArray` | Empty-list shape. |
| `budget_handler_test.go::TestBudgetRuleHandler_{RejectsMissingName,RejectsSubMinuteWindow}` | Validation rejections are 4xx. |
| `budget_handler_test.go::TestSettingsHandler_GetExposesCollectionToggles` | All four toggles + computed `Endpoint`. |
| `proxy_clone_test.go::TestShallowCloneMapping_PreservesAllFieldsExceptAuthToken` | Future-proofs clone; every field round-trips, `AuthToken` dropped. |
Handler tests use a real sqlite store + real manager + always-allow permissions mock (`handlers_test.go:53-75`). Create/update/delete success paths flow through `accountManager.StoreEvent` which the fixture doesn't wire — covered by manager-level no-mock tests outside this module.
## Known limitations / explicit non-goals
- No pagination on any list endpoint; no bulk endpoints.
- Synth result is not cached — every snapshot and every per-peer compute repeats the store walk.
- `getSettings` returning `200 + null` is a deliberate dashboard concession.
- No rate-limiting beyond the global `/api` rate limiter.
## Cross-references
- Upstream: [shared/api](10-shared-api.md), [management/agentnetwork](21-management-agentnetwork.md), [management/store](20-management-store.md)
- Downstream: [proxy/runtime](33-proxy-runtime.md)
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
- Top-level: [../00-overview.md](../00-overview.md)

View File

@@ -0,0 +1,215 @@
# proxy/middleware-framework — generic plugin system
> **Risk level:** **High** — every proxied request transits this chain. Budget exhaustion, panic recovery, or chain-close bugs hit the hot path for all targets, not just agent-network ones.
> **Backward-compat impact:** Additive at the proxy. The `middleware` and `bodytap` packages are new (`proxy/internal/middleware/middleware.go:1`, `proxy/internal/middleware/bodytap/request.go:13`); existing proxy targets keep working until a chain is bound to them via `Manager.Rebuild`.
This module is the **framework only** — no LLM/agent-network domain knowledge is required, since every example built into it is generic.
## Module boundary
This module is the **framework only**: slots, chains, registry, dispatcher, accumulator, body-tap, output filters. No middleware *implementation* lives here — those land in `proxy/internal/middleware/builtin/*` (covered in module 31). The package contract is:
1. The proxy hands a `Manager` to its config-apply path. The synth pushes per-path `PathTargetBinding` lists (`proxy/internal/middleware/manager.go:26`) into `Manager.Rebuild`, which resolves each spec via the `Registry`/`Resolver` (`proxy/internal/middleware/registry.go:81-121`) and produces an immutable `Chain` keyed by `serviceID|pathID` (`proxy/internal/middleware/manager.go:410-412`).
2. The reverse-proxy handler captures the request body via `bodytap.CaptureRequest`, calls `Chain.RunRequest`, applies returned mutations (already filtered by `chain.applyMutations`), forwards to the upstream behind a `bodytap.CapturingResponseWriter`, then calls `Chain.RunResponse` and `Chain.RunTerminal`.
3. Middlewares are inert plugins that receive a deep-cloned `Input` and return an `Output` whose decision/mutations are clamped by the dispatcher's `filterOutput` (`proxy/internal/middleware/dispatcher.go:149-172`).
Everything that crosses the framework boundary in either direction is value-typed and deep-copied — middlewares cannot mutate the live request directly, and the framework cannot inadvertently leak middleware-owned slices into the request hot path.
## Files
| Path | Role |
| ---- | ---- |
| `proxy/internal/middleware/middleware.go` | `Middleware` + `Factory` interfaces. |
| `proxy/internal/middleware/types.go` | `Slot`, `FailMode`, `Decision`, all limit constants, `Input`/`Output`/`Mutations`/`UpstreamRewrite`/`AuthHeader` value types. |
| `proxy/internal/middleware/spec.go` | Apply-time `Spec` (validated wire shape + runtime-injected fields) and `Clone`. |
| `proxy/internal/middleware/registry.go` | `Registry` (factory map, RWMutex) and `Resolver` (Spec → bound `Middleware`). |
| `proxy/internal/middleware/manager.go` | `Manager`, `chainTable` reverse index, `Rebuild`/`Invalidate*`, async chain close. |
| `proxy/internal/middleware/chain.go` | `Chain.RunRequest`/`RunResponse`/`RunTerminal`, mutation gating, `cloneInputFor`. |
| `proxy/internal/middleware/chain_test.go` | Metadata threading, LIFO response order, rewrite gating, UserGroups propagation, terminal accumulation. |
| `proxy/internal/middleware/dispatcher.go` | Timeout/panic recovery, fail-mode, error classification, `filterOutput`. |
| `proxy/internal/middleware/decision.go` | `RenderDenyResponse`, deny-code regex, status clamp. |
| `proxy/internal/middleware/headerpolicy.go` | Compile-in header denylist + `FilterHeaderMutations`. |
| `proxy/internal/middleware/bodypolicy.go` | `ValidateBodyReplace` / `ApplyBodyReplace` smuggling guards. |
| `proxy/internal/middleware/keys.go` | Metadata key namespace constants. |
| `proxy/internal/middleware/metadata.go` | `Accumulator` — allowlist, per-mw/per-request byte caps, redaction. |
| `proxy/internal/middleware/metrics.go` | OTel instrument bundle (`proxy.middleware.*`). |
| `proxy/internal/middleware/redaction.go` | `Scan` — PEM/JWT/AWS/bearer/Luhn-validated CC patterns. |
| `proxy/internal/middleware/bodytap/request.go` | Capture + replay reader, `Budget` semaphore, bypass reason codes. |
| `proxy/internal/middleware/bodytap/response.go` | `CapturingResponseWriter` (tee with `PassthroughWriter` for Flusher/Hijacker preservation). |
## Slot model
Three slots, declared per-middleware exactly once (`proxy/internal/middleware/types.go:27-41`):
- **`SlotOnRequest`** (`Slot=1`) — runs **before** the upstream call, in registration order. May `DecisionDeny`, may emit `Mutations` (header add/remove, body replace, `UpstreamRewrite`) when both `Spec.CanMutate` and `Middleware.MutationsSupported()` are true. May emit metadata. Each middleware in the slot sees metadata that earlier ones in the same slot just emitted (`proxy/internal/middleware/chain.go:144-178`) — this is how the framework gives middlewares an intra-slot side channel without a global bag.
- **`SlotOnResponse`** (`Slot=2`) — runs **after** the upstream returns, in **reverse** registration order. Cannot deny (clamped in `dispatcher.filterOutput`, `proxy/internal/middleware/dispatcher.go:153-157`). May still mutate response headers in principle, but the current chain only forwards `RewriteUpstream` from on_request, so on_response mutations are observe-only in practice. Threads the same per-slot metadata view as on_request.
- **`SlotTerminal`** (`Slot=3`) — runs **after** every on_response middleware has emitted, in registration order. Sees the full accumulated bag plus prior terminal emissions (`chain.go:221-245`). Cannot deny, cannot mutate (`dispatcher.go:168-170`). Designed for sinks (access log, metrics push, audit emitter).
Splitting a feature across slots (e.g. "parse on the way out, ship on terminal") is the explicit architectural choice — `types.go:7-15` and `types.go:22-25` make it clear no middleware participates in more than one slot.
## Architecture & flow
### Chain dispatch
```mermaid
sequenceDiagram
autonumber
participant H as proxy HTTP handler
participant BT as bodytap.CaptureRequest
participant CH as Chain
participant DI as Dispatcher
participant MW as Middleware (per slot)
participant US as Upstream
participant CW as CapturingResponseWriter
H->>BT: CaptureRequest(r, cfg, budget)
BT-->>H: body[], truncated, release()
H->>CH: RunRequest(ctx, r, Input, Accumulator)
loop on_request, registration order
CH->>CH: cloneInputFor(in, OnRequest)
CH->>DI: Invoke(ctx, spec, mw, call)
DI->>MW: mw.Invoke(callCtx, in)
MW-->>DI: Output{decision, metadata, mutations?}
DI->>DI: filterOutput (clamp deny, gate mutations)
DI-->>CH: filtered Output
CH->>CH: Accumulator.Emit (allowlist + caps + redact)
alt DecisionDeny
CH-->>H: denied, merged, rewrite
else allow
CH->>CH: applyMutations(r, m) and capture rewrite
end
end
CH-->>H: nil, merged, rewrite
H->>US: ProxyRequest (with rewrite/mutations applied)
US-->>CW: bytes (streamed, tee'd into cap-bounded buf)
CW-->>H: passthrough complete
H->>CH: RunResponse(ctx, Input{RespBody:CW.Body(),...}, acc)
loop on_response, REVERSE order (LIFO)
CH->>DI: Invoke (same wrappers)
end
H->>CH: RunTerminal(ctx, Input{Metadata:full bag}, acc)
H->>BT: release() + CW.Release()
```
### Body-tap mechanics (request + response)
```mermaid
flowchart LR
subgraph req[Request capture — bodytap.CaptureRequest]
R0[r.Body] --> R1{cfg.MaxRequestBytes > 0?\nUpgrade absent?\nContent-Type allowed?\nCL <= cap?}
R1 -- no --> R2[bypass = reason\nbody = nil\nr.Body untouched]
R1 -- yes --> R3[Budget.Acquire(cap)]
R3 -- denied --> R4[bypass=BypassBudget]
R3 -- ok --> R5[io.LimitReader(r.Body, cap+1)\nio.ReadAll]
R5 --> R6{len > cap?}
R6 -- truncated --> R7[viewable = buf[:cap]\nr.Body = replayReadCloser{buf, tail}]
R6 -- whole --> R8[r.Body = NopCloser(bytes.Reader(buf))\nclose original]
R7 --> R9[(release captured\nbudget on req end)]
R8 --> R9
end
subgraph resp[Response capture — CapturingResponseWriter]
W0[client] -.-> CW[Write(p)]
CW --> P1[PassthroughWriter.Write(p)\n— bytes leave to client first]
P1 --> P2{!stopped?}
P2 -- yes --> P3{remaining = cap - buf.Len()}
P3 --> P4[buf.Write(p[:take])\nset truncated if take<n]
P2 -- no --> P5[silent drop into the tee\n(client write already done)]
end
```
The body-tap is the highest-leak-risk surface in this module; three details matter:
1. **Request capture is "read-and-replay", not "read-and-forward".** `CaptureRequest` always swaps `r.Body` for either a `bytes.Reader` (whole body fit) or a `replayReadCloser` that replays the captured prefix then drains the remaining stream from the original body (`bodytap/request.go:178-201`). This means the **upstream still sees the full body even when the tap truncates**. The original `r.Body` is **not** closed in the truncated branch — `replayReadCloser.Close()` only closes the tail (`bodytap/request.go:199-201`), which is the same reader, so close once on request end is correct, but reviewers should confirm the upstream proxy always reads to EOF (otherwise the tail is leaked).
2. **Response capture is a write-through tee.** `CapturingResponseWriter.Write` forwards to the underlying writer **first** (`bodytap/response.go:116-117`), then tees into `buf` under its own mutex. Client never blocks on the tee. `Flusher`/`Hijacker` are preserved via the embedded `responsewriter.PassthroughWriter`. SSE/chunked streams flow through untouched; middlewares only see the bounded prefix.
3. **Budget is a single shared semaphore.** `Manager` constructs one `bodytap.Budget` at startup (`manager.go:138-144`, default `256 MiB` from `bodytap/request.go:39`). Every capture pre-acquires its full `MaxRequestBytes` / `MaxResponseBytes` from the budget regardless of actual body size; that prevents a flood of small captures from collectively exceeding the cap, but it also means a misconfigured `MaxRequestBytes = 1 MiB` with 256 concurrent requests already exhausts the default budget. Reviewers should sanity-check the operator-facing defaults that ship with synth-service.
The framework explicitly aborts capture (and increments `proxy.middleware.capture_bypass_total`) before reading the first byte when `Upgrade`/`Connection: upgrade` is set (`bodytap/request.go:120-125`), when the content-type isn't in the allowlist (`bodytap/request.go:126-128`), or when the advertised `Content-Length` already exceeds the cap (`bodytap/request.go:131-133`). This is the right place to make sure WebSocket upgrades and large file uploads never reach the buffer.
## Public contracts
- **`Middleware` interface** (`middleware.go:14-36`): `ID()`, `Version()`, `Slot()`, `AcceptedContentTypes()`, `MetadataKeys()`, `MutationsSupported()`, `Invoke(ctx, *Input) (*Output, error)`, `Close()`. `MetadataKeys()` is the **closed set** the middleware is allowed to emit — the accumulator drops anything outside it (`metadata.go:71-75`). `Close` must be idempotent (called even when `Invoke` was never reached).
- **`Factory` interface** (`middleware.go:44-47`): `ID()`, `New(rawConfig []byte) (Middleware, error)`. `RawConfig` is opaque JSON bytes on the wire (`spec.go:6-12`); each factory owns its own typed config.
- **`Decision` type** (`types.go:59-69`): `Allow=0`, `Deny=1`, `Passthrough=2`. Default-zero is permissive — important because every middleware that omits `Decision` gets `Allow`. Dispatcher clamps `Deny` to `Passthrough` outside `SlotOnRequest` (`dispatcher.go:153-157`).
- **`Mutations`** (`types.go:196-201`): `HeadersAdd`/`HeadersRemove` (filtered through `headerpolicy.go`), `BodyReplace` (gated through `bodypolicy.go`), and `RewriteUpstream`. `RewriteUpstream` is **last-write-wins** within the on_request slot (`chain.go:170-172`, locked down by `TestChain_RunRequest_LatestRewriteWins`).
- **Metadata propagation keys** (`keys.go`): all keys live in a single file and follow `^[a-z][a-z0-9_-]*(\.[a-z0-9_-]*)+$` (`metadata.go:8`). Framework-injected error tagging uses `mw.<id>.error_kind` (`keys.go:81`) so operators can distinguish framework-emitted entries from middleware-emitted ones.
## Invariants
- **Per-request context isolation.** `cloneInputFor` deep-copies every mutable field (`Headers`, `RespHeaders`, `Metadata`, `Body`, `RespBody`, `UserGroups`, `UserGroupNames`) before each invocation (`chain.go:286-308`). A misbehaving middleware that mutates `in.Headers` only corrupts its own copy.
- **Body-tap bounded by capture limit.** Request side uses `io.LimitReader(r.Body, limit+1)` (`bodytap/request.go:152`) — the `+1` is how the code detects truncation (`bodytap/request.go:160`); the surfaced buffer is sliced back down to `limit`. Response side stops teeing once `buf.Len() >= cap` (`bodytap/response.go:121-133`). Neither side can grow the buffer past the configured cap.
- **Headers/body redaction order.** Accumulator runs `Scan(value)` **before** counting cost (`metadata.go:81-82`), so the byte budgets are computed against post-redaction sizes. `Scan` order is PEM → JWT → AWS key → bearer → Luhn-validated CC (`redaction.go:25-51`) — the comment block in `redaction.go:8-13` is explicit that this is best-effort, not DLP.
- **No middleware can starve the chain.** Every invocation runs inside `context.WithTimeout(ctx, clampTimeout(spec.Timeout))` in a separate goroutine (`dispatcher.go:51-94`), with the deadline race-`select`ed against the result channel. A blocked middleware fires the timeout path, gets fail-mode'd, and `IncError(kind=timeout)`. Timeouts are clamped to `[10ms, 5s]` (`types.go:80-86`, `dispatcher.go:174-185`).
- **Panic recovery.** `recover()` captures the panic, logs only the type + a 4 KiB stack prefix (no panic value — avoids leaking secrets the middleware was processing), and produces a `panicError` that flows through fail-mode (`dispatcher.go:64-76`).
- **Chain immutability + atomic swap.** `chainTable` is cloned on every `Rebuild`/`Invalidate*` and swapped via `atomic.Pointer` (`manager.go:44-69`, `manager.go:221-300`). Readers (`ChainFor`) are lock-free; writers serialise on `writeMu`. The retired chain is `Close`-d in a background goroutine bounded by `chainCloseTimeout = 2 * MaxTimeout` (`manager.go:21-22`, `manager.go:326-346`), so in-flight invocations finish on the old chain after the swap.
## Things to scrutinize
### Correctness
- **Chain ordering deterministic from synth output?** `Manager.buildChain` iterates `b.Specs` in slice order and appends to `bound` (`manager.go:366-391`); `NewChain` then partitions by slot but **preserves slice order within each slot** (`chain.go:50-60`). So order on the wire = order observed at runtime. Synth must therefore emit specs in the intended execution order — there is no per-spec `Priority` field. Worth flagging.
- **Decision short-circuit semantics.** `RunRequest` returns immediately on `DecisionDeny` (`chain.go:164-167`) **with the metadata accumulated so far** plus the `denied.Metadata`. Callers that ignore `merged` on deny will lose framework-injected `mw.<id>.error_kind` entries. The proxy runtime is the only caller; confirm it always feeds `merged` into the access log on the deny path as well.
- **`UpstreamRewrite` `AuthHeader` bypass** (`types.go:218-235`). The `AuthHeader`/`StripHeaders` fields *intentionally* bypass the header denylist on the basis that the proxy itself rewrites auth. The denylist still blocks middleware-emitted `HeadersAdd: Authorization=...`. This is a delicate carve-out — review the runtime consumer to confirm only the trusted upstream-build path unpacks `AuthHeader`, never the generic `applyMutations` loop.
- **`replayReadCloser.Close` only closes the tail** (`bodytap/request.go:199-201`). The replay buffer doesn't own a resource, so this is correct, but it conflates "replay finished" with "underlying body closed". If a caller `Close()`s without reading to EOF, the original body is closed but the captured prefix is lost; harmless for the proxy path (upstream always reads to EOF) but worth a doc-comment.
### Security
- **Body-tap memory bounds.** Discussed above — bounded by `MaxBodyCapBytes = 1 MiB` per direction (`types.go:77`) and the shared `Budget` (default 256 MiB). The concerning case is the **deep-copy in `cloneInputFor`** (`chain.go:300-306`): every middleware invocation gets its **own copy** of `Body` and `RespBody`. A chain of N middlewares with a 1 MiB body allocates N MiB of transient bytes per request. With `MaxMiddlewaresPerChain = 16` (`types.go:103`) that's up to 16 MiB extra per in-flight request. Worth pricing into the budget model.
- **Header redaction completeness.** `denyHeaders` (`headerpolicy.go:5-17`) covers the auth/forwarding family and framing (`Content-Length`, `Transfer-Encoding`, `Trailer`). `denyHeaderPrefixes` covers `X-Authenticated-*`, `X-Forwarded-*`, `X-Remote-*`, `X-NetBird-*`. Notably absent: `Range`, `If-Match`/`If-None-Match` (mutation could cause cache poisoning), `Origin`/`Referer`. Not necessarily wrong, but worth a deliberate decision.
- **Metadata key collisions across middlewares.** The accumulator has no cross-middleware uniqueness check; two middlewares with the same key in their allowlist can both emit it, and both copies land in `merged` (`metadata.go:51-99`). Downstream consumers must tolerate duplicates. Worth documenting.
- **Deny rendering.** `RenderDenyResponse` only allows codes matching `^[a-z][a-z0-9._-]{0,63}$` (`decision.go:9`), redacts/truncates message + detail values, caps `Details` at 8 entries (`decision.go:42-50`), clamps status to `[400,499]\{401}` (`decision.go:65-73`). The deny body type is fixed; middlewares cannot inject arbitrary JSON.
### Concurrency
- **Per-request state vs shared state in factories.** Each `Factory.New` is called once per chain build; the returned `Middleware` instance is **shared across all requests** for that chain. `Invoke` must be reentrant. The framework does not enforce this — a buggy middleware that holds per-call state on the struct will silently race. Suggest a `// Invoke must be safe for concurrent use` doc on the interface.
- **`chainTable` clone-on-write** is correct, but `addChain`/`removeChain` mutate the *cloned* table before the swap (`manager.go:71-108`), and they're called under `writeMu`. Readers only ever see the post-swap pointer. Good.
- **`Chain.inflight` WaitGroup**. `Run*` does `Add(1)`/`Done()` (`chain.go:142-143`, `chain.go:194-195`, `chain.go:225-226`); `Close` waits on it bounded by ctx (`chain.go:75-85`). One concern: a *new* `RunRequest` can `Add(1)` *after* `Close` started waiting if the caller still holds a stale chain pointer. `WaitGroup` does not panic on this if the count was already > 0 at `Wait` time, but it does panic if `Add` happens after `Wait` returns and another `Wait` runs. `Close` is documented one-shot, so single-`Wait` is fine, but callers must drop the chain reference before calling `Close`. Worth a code comment near `Close`.
- **Goroutine leaks.** `Dispatcher.Invoke` spawns one goroutine per call and *always* writes to a buffered (cap=1) channel (`dispatcher.go:62-76`), so even if the timeout fires the goroutine completes its send and exits. No leak.
- **`closeChainsAsync`** detaches retired chains into a goroutine (`manager.go:326-346`). If `Manager` is never GC'd this is fine, but there's no shutdown hook to wait on outstanding closes. Reviewers should confirm the proxy shutdown path explicitly drains in-flight requests before tearing down `Manager`, or accept that the last chain-close round may be cut short on exit.
### Performance
- **Allocations per request.** `cloneInputFor` allocates new slices for `Headers`, `RespHeaders`, `Metadata`, `Body`, `RespBody`, `UserGroups`, `UserGroupNames` — once per middleware per request. For a typical 5-middleware chain on a 1 KiB body that's ~10 small slice allocs plus one `Body` copy each. Not a hot-path crisis, but `sync.Pool` for the per-call `Input` would be a natural follow-up.
- **Accumulator allocates a fresh `allowSet` per `Emit` call** (`metadata.go:55-58`). One per middleware per slot pass = up to 48 per request. Cheap, but worth noting.
- **Regex cost.** `Scan` runs five regex passes on every accepted metadata value (`redaction.go:25-51`). Bounded by `MaxMetadataValueBytes = 4 KiB` so worst case is small.
### Observability
- **Per-middleware metrics.** `proxy.middleware.requests_total{middleware,target_id,outcome}` (`metrics.go:34-41`), `duration_ms`, `invocations_total`, `errors_total{kind}`, `metadata_rejected_total{reason}`, `header_mutation_blocked_total{header}`, `capture_bypass_total{reason}`. Comprehensive surface; operators can alert on `errors_total{kind=panic}` and `errors_total{kind=timeout}` separately. **Latency histogram is in milliseconds with default OTel buckets** — for a 10ms5s timeout range default buckets cover OK, but a custom bucket set centred on 1500ms would resolve the agent-network response-parser tail better.
- **Decision logs.** Panic logs (`dispatcher.go:69`) include `request_id`, type, and stack but not the panic value (safe). `Chain.Close` logs middleware-close errors at debug (`chain.go:91`). `applyMutations` logs body-replace rejections at warn (`chain.go:278`). No log on the deny path itself — by design, since the access-log terminal middleware is expected to record outcomes.
## Test coverage
| Test file | Locks down |
| --------- | ---------- |
| `proxy/internal/middleware/chain_test.go:77` | `RunRequest` threads metadata across on_request middlewares (regression for the "later mw can't see earlier mw's emissions" bug). |
| `chain_test.go:110` | `RunResponse` reverse-order threading. |
| `chain_test.go:142` | `cost_meter`-shaped scenario: response_parser registered after cost_meter still emits *before* cost_meter sees the bag (guards the `cost.skipped=missing_tokens` regression). |
| `chain_test.go:178` | `UpstreamRewrite` last-write-wins. |
| `chain_test.go:206` | No middleware emits → nil rewrite. |
| `chain_test.go:224` | Rewrite filtered when `CanMutate=false`. |
| `chain_test.go:245` | `Input.UserGroups` propagates verbatim through `cloneInputFor`. |
| `chain_test.go:304` | Terminal middlewares see the full accumulated bag + prior terminal emissions. |
**Gaps** worth raising with the author:
- No direct test for `Dispatcher.Invoke` timeout / panic / fail-mode behaviour at the framework level (covered indirectly by built-in tests, but a unit test pinning `errors_total{kind=...}` labels would be cheap insurance).
- No test for `bodytap.CaptureRequest` truncated replay (the upstream-sees-full-body invariant is exactly the kind of thing a regression would silently break).
- No test for `Budget` exhaustion behaviour under concurrency.
- No test for `Manager.InvalidateMiddleware` + `LiveServiceCheck` race (the auth-revocation race the comment at `manager.go:33-38` calls out is the load-bearing reason for `LiveServiceCheck`).
## Known limitations / explicit non-goals
- **No middleware-to-middleware RPC.** Side-channel is metadata only.
- **No streaming body inspection.** Middlewares see a bounded prefix; SSE / chunked parsing happens against that prefix in the response middleware.
- **No per-spec priority.** Order is registration order in the spec slice.
- **No retry / circuit-breaker** on middleware errors. Fail-mode is binary (open/closed) and per-spec.
- **Mutations cannot rewrite the request URL path or query** — only `RewriteUpstream` can change scheme/host (+ optional path replacement, see `types.go:218-235`).
- **Redaction is best-effort.** Explicitly documented in `redaction.go:8-13`. Not a DLP solution.
## Cross-references
- Upstream wire shape: [../modules/10-shared-api.md](10-shared-api.md) (Spec/RawConfig encoding from management).
- Built-in middlewares using this framework: [../modules/31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md).
- Runtime wiring (where `Manager`, `Chain`, and `bodytap` are consumed by the HTTP handler): [../modules/33-proxy-runtime.md](33-proxy-runtime.md).
- End-to-end request flow including capture + chain dispatch: [../01-end-to-end-flows.md](../01-end-to-end-flows.md).
- Top-level architecture: [../00-overview.md](../00-overview.md).

View File

@@ -0,0 +1,365 @@
# proxy/middleware-builtin — the LLM chain
The registry-mounted middleware set the proxy executes on every agent-network
LLM request. The two highest-blast-radius areas are the **capture-pointer
semantics** and the **limit_check ⇒ limit_record** record-once invariant.
Sibling module: [32-proxy-llm-parsers.md](./32-proxy-llm-parsers.md) — the SDK
adapters + pricing catalog this chain delegates to.
---
## Module boundary
This module is the registry-mounted middleware set the proxy executes on
every agent-network LLM request. Each sub-package registers itself via
`init()`
([builtin.go:3234](../../../proxy/internal/middleware/builtin/builtin.go));
the proxy server anonymous-imports the set
([all_test.go:1119](../../../proxy/internal/middleware/builtin/all_test.go))
so the registry is populated at boot. The chain is wired by the management
synthesiser and executed by the framework
(`proxy/internal/middleware/{chain,dispatcher,accumulator}.go` — both out
of scope). Everything here reads from / writes to one envelope: the
`middleware.KV` metadata bag plus `middleware.Mutations` for header/body
rewrites.
## The 8 middlewares
| Name | Slot | Inputs (metadata read) | Outputs (metadata written) | Side effects |
|---|---|---|---|---|
| `llm_request_parser` | OnRequest | `Input.{URL,Body,BodyTruncated}` | `llm.{provider,model,stream,request_prompt_raw,capture_truncated}` | none |
| `llm_router` | OnRequest | `llm.model`, `Input.{URL,UserGroups}` | `llm.{resolved_provider_id,authorising_groups}`, `llm_policy.{decision,reason}` | upstream rewrite + auth strip/inject |
| `llm_limit_check` | OnRequest | `llm.{resolved_provider_id,model}`, `Input.{AccountID,UserID,UserGroups}` | `llm.{selected_policy_id,attribution_group_id,attribution_window_seconds}`, `llm_policy.{decision,reason}` | gRPC `CheckLLMPolicyLimits` |
| `llm_identity_inject` | OnRequest | `llm.{resolved_provider_id,authorising_groups}`, `Input.{UserEmail,UserID,UserGroups,UserGroupNames}` | none | header strip/inject + optional body rewrite |
| `llm_guardrail` | OnRequest | `llm.{model,request_prompt_raw}` | `llm_policy.{decision,reason}`, `llm.request_prompt` | none (model allowlist deny) |
| `llm_response_parser` | OnResponse | `llm.provider`, `Input.{RespHeaders,RespBody,Status}` | `llm.{input,output,total,cached_input,cache_creation}_tokens`, `llm.response_completion` | none |
| `cost_meter` | OnResponse | `llm.{provider,model}`, token buckets | `cost.usd_total` or `cost.skipped` | pricing lookup |
| `llm_limit_record` | OnResponse | `llm.{attribution_group_id,attribution_window_seconds,input_tokens,output_tokens}`, `cost.usd_total` | none | gRPC `RecordLLMUsage` |
[all_test.go:2640](../../../proxy/internal/middleware/builtin/all_test.go)
locks the ID set; adding or removing one is a conscious extension.
## Files
| File | LOC | Notes |
|---|---:|---|
| `builtin.go` | 86 | Registry + `FactoryContext` (ctx, data dir, meter, logger, mgmt client) |
| `all_test.go` | 41 | Locks the 8-ID registry surface |
| `agentnetwork_chain_integration_test.go` | 319 | Live sqlite + real gRPC bufconn; gate→recorder wire path |
| `llm_request_parser/*` | 162 / 66 / 356 | Provider detection, body parse, prompt extraction with capture-pointer gating |
| `llm_router/*` | 385 / 84 / 586 | Three-pass route selection (model → groups → path-prefix) |
| `llm_limit_check/*` | 196 / 38 / 182 | Pre-flight `CheckLLMPolicyLimits` (2s, fail-open) |
| `llm_identity_inject/*` | 440 / 108 / 666 | HeaderPair (LiteLLM) + JSONMetadata (Portkey) + ExtraHeaders |
| `llm_guardrail/*` | 176 / 82 / 75 / 219 / 217 | Model allowlist + optional prompt capture with PII redaction |
| `llm_response_parser/*` | 258 / 222 / 43 / 433 / 169 / 111 | Buffered + SSE accumulation; AWS event-stream accumulator (`streaming_bedrock.go`) for Bedrock; capture-pointer gates completion emit |
| `cost_meter/*` | 181 / 84 / 439 | Token → USD via `proxy/internal/llm/pricing` |
| `llm_limit_record/*` | 144 / 35 / 191 | Post-flight `RecordLLMUsage` (5s, debug-on-error) |
## Per-middleware
### llm_request_parser
Detects the LLM provider via `llm.DetectParser` (URL sniff) or by name via
`llm.ParserByName` when synthesiser stamps `provider_id`
([middleware.go:9699](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)).
**Path-routed providers short-circuit first:** `parseVertexPath` and
`parseBedrockPath` ([middleware.go:8594](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go))
pull the model + vendor out of the URL before parser selection runs — Vertex
from `/v1/projects/.../publishers/{pub}/models/{model}:{action}` (publisher →
vendor via `vertexPublisherVendor`), Bedrock from `/model/{id}/{action}` with
`normalizeBedrockModel` stripping the region prefix + version suffix. See
[50-path-routed-providers.md](./50-path-routed-providers.md) for the full path
grammar. For body-routed providers it decodes the body into `RequestFacts`
(model + stream) and extracts the prompt. On
`capture_prompt=true` (or absent — see capture-pointer semantics below) the
prompt is run through `llm_guardrail.RedactPII` when `redact_pii=true` and
truncated rune-safely to 3500 bytes
([middleware.go:109122](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)).
**Key invariant:** redaction is parser-side, not guardrail-side — access-log
reads `llm.request_prompt_raw` directly.
### llm_router
Three-pass route selection in `matchRoute`
([middleware.go:241300](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)):
filter by `Models` claim → vendor-pin (a vendor-tagged request never crosses to
another vendor's route) → filter by `AllowedGroupIDs` intersection → model
precedence over path → tie-break by longest `UpstreamPath` prefix match.
Model-miss returns `llm_policy.model_not_routable`; known-but-unauthorised
returns `llm_policy.no_authorised_provider`. **Key invariant:** auth-header
strip+inject rides on `UpstreamRewrite.{StripHeaders,AuthHeader}`
([middleware.go:606646](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
— NOT `HeadersAdd/HeadersRemove` — because the framework's mutation gate
blocks `Authorization` on the generic header path.
**Path-routed providers route before the model table.** `Invoke` checks
`isVertexPath` / `isBedrockPath`
([middleware.go:138216](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
ahead of the model lookup, so a path-carried model can't be claimed by a
same-vendor body-routed provider. `matchPathRoute` enforces the route's `Models`
allowlist (empty = catch-all) even though the model came from the URL.
Two path-only behaviours:
- **Vertex unmeterable publisher** — when `llm_request_parser` emits no
`llm.provider` (e.g. Gemini/`google`), the router denies with
`llm_policy.unmeterable_publisher` (403) rather than forward it uncounted.
- **GCP token minting** — when the route carries `GCPServiceAccountKeyB64`
(set from a `keyfile::` api_key), `gcpBearer` mints + caches a short-lived
OAuth2 token per request instead of injecting a static value; a bad key or
unreachable token endpoint denies with `llm_policy.upstream_auth_failed`
(502). Bedrock uses its static bearer token directly (no minting).
- **`/bedrock` prefix** — an optional `/bedrock` gateway-namespace prefix is
accepted and stripped via `RewriteUpstream.StripPathPrefix` so the native
`/model/...` path reaches the upstream.
Full treatment in [50-path-routed-providers.md](./50-path-routed-providers.md).
### llm_limit_check
Pre-flight gate. Reads `llm.resolved_provider_id`, calls
`CheckLLMPolicyLimits` with a 2s context timeout
([middleware.go:24, 97106](../../../proxy/internal/middleware/builtin/llm_limit_check/middleware.go)),
on allow stamps `llm.selected_policy_id`, `llm.attribution_group_id`,
`llm.attribution_window_seconds`. **Key invariant:** fail-open. Nil
`MgmtClient`, empty provider id, or RPC error returns `allowNoAttribution()`
— management outage doesn't take down every LLM request. Operators audit via
the access-log; a future flag may switch this to fail-closed.
### llm_identity_inject
Dispatches per-rule between LiteLLM-shaped `HeaderPair`
([middleware.go:169](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go))
and Portkey-shaped `JSONMetadata`
([middleware.go:292](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go)).
Identity is the peer's email (or `UserID` fallback); tags are the
**authorising-groups intersection** emitted by `llm_router`, not the full
`UserGroups` — a peer in 5 groups authorised under 1 only tags as that 1.
**Anti-spoof:** every `HeadersAdd` is preceded by a `HeadersRemove` of the
same name; the framework runs `Remove` before `Add` so client-supplied
identity never reaches the upstream. Body-level inject (`tags_in_body`,
`end_user_id_in_body`) is skipped on empty / truncated / non-JSON bodies so
header attribution stays intact.
### llm_guardrail
Model allowlist deny + optional prompt-capture-with-redaction. Allowlist
match is case-insensitive via `normaliseModel`; empty allowlist disables the
check. Prompt capture reads `llm.request_prompt_raw` and emits
`llm.request_prompt` only when `prompt_capture.enabled`
([middleware.go:149165](../../../proxy/internal/middleware/builtin/llm_guardrail/middleware.go)).
**Key invariant:** `RedactPII` is the exported function the parsers call —
single PII contract across all three keys.
### llm_response_parser
Buffered and SSE paths share one `Invoke`
([middleware.go:102127](../../../proxy/internal/middleware/builtin/llm_response_parser/middleware.go)):
content-type sniffing dispatches to `invokeBuffered` (JSON, status<400) or
`invokeStreaming` (text/event-stream, partial bodies tolerated). Streaming
delegates to `accumulateStream`
([streaming.go:2130](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go))
using `llm.NewScanner`. A third path, `accumulateBedrockStream`
([streaming_bedrock.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go)),
decodes the AWS binary event-stream (`application/vnd.amazon.eventstream`)
returned by Bedrock's `-stream` actions — InvokeModel `chunk` frames wrap a
base64 Anthropic event, Converse frames carry text + a trailing usage block.
Cached / cache-creation buckets emit only when non-zero, preserving the existing
token schema.
### cost_meter
Reads `llm.provider` + `llm.model` + token buckets, looks up per-1k rate via
`pricing.Loader`, emits `cost.usd_total` or a closed-set `cost.skipped`
reason (`missing_provider/model/tokens`, `unparseable_tokens`, `zero_tokens`,
`unknown_model`). Loader's hot-reload goroutine is bound to proxy-lifetime
context via `startReloader`. **Key invariant:** provider-shape switch lives
in `pricing.Table.Cost` (sibling doc) — `cost_meter` stays provider-agnostic.
### llm_limit_record
Post-flight write. Always returns `DecisionAllow`; response has already been
served so RPC errors mustn't surface (logged at `Debugf`). Skip-on-no-signal
at line 81 (zero tokens + zero cost). **Key invariant:** the
skip-on-missing-attribution guard at line 98 is a safety net independent of
the framework's deny short-circuit — if the gate denied and the framework
still runs the recorder, the recorder skips on absent
`UserID`+`groupID`+`UserGroups` and no phantom counter materialises.
## Full-chain diagram (canonical order)
```mermaid
flowchart TD
A[HTTP request] --> B[llm_request_parser<br/>OnRequest]
B -->|llm.provider, llm.model,<br/>llm.stream, llm.request_prompt_raw| C[llm_router<br/>OnRequest]
C -->|llm.resolved_provider_id,<br/>llm.authorising_groups,<br/>upstream rewrite + auth| D[llm_limit_check<br/>OnRequest]
D -->|deny path| Z1[403 llm_policy.*]
D -->|allow + llm.selected_policy_id,<br/>llm.attribution_group_id,<br/>llm.attribution_window_seconds| E[llm_identity_inject<br/>OnRequest]
E -->|header strip+inject<br/>+ optional body rewrite| F[llm_guardrail<br/>OnRequest]
F -->|deny: model_blocked| Z2[403 llm_policy.model_blocked]
F -->|allow + llm.request_prompt| G[upstream LLM call]
G --> H[llm_response_parser<br/>OnResponse]
H -->|llm.{input,output,total,cached_input,cache_creation}_tokens,<br/>llm.response_completion| I[cost_meter<br/>OnResponse]
I -->|cost.usd_total or cost.skipped| J[llm_limit_record<br/>OnResponse]
J --> K[response to client]
```
## limit_check ⇒ limit_record record-once invariant
```mermaid
sequenceDiagram
participant LC as llm_limit_check
participant M as management gRPC
participant U as upstream LLM
participant LR as llm_limit_record
participant DB as sqlite consumption table
LC->>M: CheckLLMPolicyLimits (2s)
alt allow
M-->>LC: selected_policy_id, attribution_group_id, window_s
LC->>U: stamps attribution metadata
U-->>LR: response + tokens (via llm_response_parser + cost_meter)
LR->>M: RecordLLMUsage (5s, debug-on-error)
M->>DB: increment (user, group, window) row
else deny
M-->>LC: llm_policy.token_cap_exceeded
Note over LR: framework short-circuits; even if invoked,<br/>recorder skips on absent UserID+groupID+UserGroups
else mgmt nil / rpc error
LC-->>LC: allowNoAttribution() — fail open
Note over LR: no window_s ⇒ recorder books only account-level<br/>budget rules (which run independently)
end
```
The integration test
[agentnetwork_chain_integration_test.go](../../../proxy/internal/middleware/builtin/agentnetwork_chain_integration_test.go)
exercises all three branches against a real sqlite store + bufconn gRPC —
no mocks. Tests: `TestChain_AllowPath_StampsAttributionAndRecordsCounter`
(line 130), `TestChain_DenyPath_GateRejectsAndNoConsumptionWritten` (line
207), `TestChain_CapExhaustTransition` (line 265).
## Public contracts (per-middleware JSON config)
| Middleware | Config shape |
|---|---|
| `llm_request_parser` | `{provider_id?, redact_pii?, capture_prompt?: *bool}` ([factory.go:1937](../../../proxy/internal/middleware/builtin/llm_request_parser/factory.go)) |
| `llm_router` | `{providers: [{id, models, upstream_scheme, upstream_host, upstream_path?, auth_header_name, auth_header_value, allowed_group_ids}]}` |
| `llm_limit_check` | `{}` — pulls `MgmtClient` from `FactoryContext` |
| `llm_identity_inject` | `{providers: [{provider_id, header_pair?|json_metadata?, extra_headers?}]}` |
| `llm_guardrail` | `{model_allowlist: []string, prompt_capture: {enabled, redact_pii}}` |
| `llm_response_parser` | `{redact_pii?, capture_completion?: *bool}` |
| `cost_meter` | `{pricing_path?}` (basename inside data-dir; defaults `pricing.yaml`) |
| `llm_limit_record` | `{}` — same pattern as `llm_limit_check` |
All factories accept empty / null / `{}` / whitespace as zero-value config;
only structurally invalid JSON is rejected so misconfig surfaces at chain
build time.
## Invariants
1. **limit_check ↔ limit_record paired.** They MUST appear together. Gate
stamps attribution metadata on the request leg; recorder reads it on the
response leg. If a chain contains only the recorder, the
skip-on-missing-attribution guard at
[llm_limit_record/middleware.go:8187, 98103](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go)
keeps counters consistent but no enforcement runs. Only-gate means
counters never tick and headroom appears infinite.
2. **`capture_prompt` / `capture_completion` pointer semantics.** Both are
`*bool`. `nil` = "preserve legacy emit" (back-compat default for
non-agent-network callers and pre-toggle tests). `false` = suppress the
key entirely (access-log row carries zero prompt / completion content).
`true` = emit. The synthesiser sets the pointer explicitly to the
account's `EnablePromptCollection` toggle. The handling lives
in [llm_request_parser/factory.go:5561](../../../proxy/internal/middleware/builtin/llm_request_parser/factory.go)
and the symmetric [llm_response_parser/middleware.go:6268](../../../proxy/internal/middleware/builtin/llm_response_parser/middleware.go);
a missing pointer must not be treated as `false` (that would suppress
capture for legacy non-agent-network callers).
`redact_pii` is an orthogonal `bool` controlling **form** of emitted
content, not whether it's emitted.
3. **`redact_pii` is parser-side.** Both parsers import
`llm_guardrail.RedactPII` and run it BEFORE stamping the metadata bag.
Load-bearing because the access-log sink reads `llm.request_prompt_raw`
and `llm.response_completion` directly — by the time `llm_guardrail`
runs its own pass on `llm.request_prompt`, the raw key has already been
stamped. Tests: `TestInvoke_RedactPii_RedactsBeforeEmittingRawPrompt`,
`TestInvoke_RedactPii_RedactsCompletionBeforeEmit`.
4. **Metadata allowlist enforcement.** Every middleware declares
`MetadataKeys()`. The framework accumulator drops any KV outside that
allowlist. When adding a new key, also extend the docstring in
`middleware/keys.go`.
5. **Closed deny-code set.** All deny paths emit one of:
`llm_policy.model_not_routable`, `llm_policy.no_authorised_provider`,
`llm_policy.model_blocked`, `llm_policy.token_cap_exceeded`,
`llm_policy.unmeterable_publisher` (path-routed Vertex publisher with no
parser → 403), `llm_policy.upstream_auth_failed` (GCP token mint failure →
502), or the management-supplied code on `llm_limit_check`. These surface
verbatim; arbitrary middleware text never reaches the wire.
## Things to scrutinise
**Correctness.** `llm_router` model match treats an empty `Models` slice as
"claim every model"
([middleware.go:238248](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
for gateway-style providers — confirm no real provider record ships with an
empty `Models` by accident. Path-prefix tie-break falls back to declaration
order when no candidate prefix-matches, so the synthesiser must emit a
deterministic order. `llm_limit_record` discards `strconv.ParseInt` errors
([middleware.go:7880](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go))
— relies on `llm_response_parser` always emitting parseable values; spot-check
the streaming partial path on truncated bodies.
**Security.** Auth headers must NEVER appear on `Mutations.HeadersAdd/Remove`
for the router — a direct headers path would bypass the framework gate. The
capture-pointer handling is the kind of place a bug ships PII to logs
silently; every synthesiser config path must set the pointer explicitly.
`llm_identity_inject` body inject silently skips on a
non-object `metadata` field
([middleware.go:262270](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go))
— header path still attributes, but body-level tag-budget enforcement
doesn't run for that request.
**Concurrency.** `cost_meter` shares a `pricing.Loader` via
`atomic.Pointer[Table]`; readers always see a consistent table. Every
middleware is a stateless value receiver. Integration test uses real bufconn
gRPC — race detector is the meaningful bar.
**Perf.** Hot path is `lookupKV` linear scan over <10 KVs; `cost_meter.Cost`
is O(1); SSE accumulation is single-pass. No map allocation per call.
**Observability.** Every deny stamps `llm_policy.decision=deny` and a
matching `llm_policy.reason` — access-log can pivot on either.
`llm_limit_record` only logs at `Debugf` on RPC failure
([middleware.go:125130](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go));
operators need an alternate signal (metric on `RecordLLMUsage` failures) for
counter accuracy.
## Test coverage
| File | Tests | Notes |
|---|---:|---|
| `all_test.go` | 1 | Registry surface lock |
| `agentnetwork_chain_integration_test.go` | 3 | Allow/deny/cap-exhaust vs live sqlite + bufconn gRPC |
| `llm_request_parser/middleware_test.go` | 18 | `provider_id` bypass, redaction, capture-pointer, rune-safe truncation |
| `llm_router/middleware_test.go` | 19 | Three-pass match, deny codes, path-prefix tie-break, header strip+inject |
| `llm_limit_check/middleware_test.go` | 6 | Allow/deny, fail-open on nil mgmt / RPC error, attribution stamping |
| `llm_identity_inject/middleware_test.go` | 28 | HeaderPair, JSONMetadata, ExtraHeaders, body inject, anti-spoof |
| `llm_guardrail/middleware_test.go` | 15 | Allowlist case-insensitivity, prompt capture toggle, deny shape |
| `llm_guardrail/redact_test.go` | 15 | Email, SSN, phone (E.164 + NA), bearer, IPv4; fixture-driven |
| `llm_response_parser/middleware_test.go` | 18 | Buffered OAI+Anthro, capture-pointer, redact, truncation |
| `llm_response_parser/streaming_test.go` | 7 | OAI usage frame, Anthro message_delta, truncated body best-effort |
| `cost_meter/middleware_test.go` | 17 | Each skip reason, provider-shape, pricing loader integration |
| `llm_limit_record/middleware_test.go` | 7 | Skip-on-no-signal, skip-on-missing-attribution, RPC failure swallowed |
## Cross-references
- Sibling: [32-proxy-llm-parsers.md](./32-proxy-llm-parsers.md) — SDK adapters
+ SSE framer + pricing loader.
- Path-routed providers (Vertex AI + Bedrock), `keyfile::` credential, GCP
token minting, `/bedrock` prefix:
[50-path-routed-providers.md](./50-path-routed-providers.md).
- Upstream config: `management/server/agentnetwork/synthesizer` (out of scope).
- Framework: `proxy/internal/middleware/{chain,dispatcher,accumulator,registry}.go`.
- Metadata key registry: `proxy/internal/middleware/keys.go`.
- gRPC surface: `proto.ProxyServiceClient.{CheckLLMPolicyLimits,RecordLLMUsage}`.

View File

@@ -0,0 +1,392 @@
# proxy/llm-parsers — SDK adapters + pricing + SSE
The runtime-agnostic LLM library: the OpenAI Responses API (`/v1/responses`)
and the older Chat Completions API (`/v1/chat/completions`), the Anthropic
Messages API (`/v1/messages`), the SSE wire format (`event:` / `data:` lines,
`\n\n` framing, CRLF tolerance), and per-provider token accounting (OpenAI's
cached-prompt **subset** vs Anthropic's cache_read **additive** model). The
pricing table's per-provider cost formula is the highest-leverage place a
small bug would silently mis-bill operators.
Sibling module: [31-proxy-middleware-builtin.md](./31-proxy-middleware-builtin.md)
— the 8 middlewares that consume this package's parsers + pricing loader.
---
## Module boundary
`proxy/internal/llm` is the runtime-agnostic LLM library shared by every
middleware that needs to understand provider-specific shapes. Zero
proxy-framework dependencies:
- `parser.go``Parser` interface, `Provider` enum, public factories
(`Parsers`, `DetectParser`, `ParserByName`).
- `openai.go` / `anthropic.go` / `bedrock.go` — per-provider `Parser` impls.
- `sse.go` — SSE scanner (`Scanner`, `Event`, `NewScanner`).
- `errors.go` — sentinels callers branch on with `errors.Is`.
- `pricing/` — embedded-default + hot-reload override table with
symlink-safe Unix loader (build-tagged stub elsewhere).
- `fixtures/` — captured request/response/stream bodies the tests replay.
The package carries zero proxy-framework dependencies so the same parsers can
be reused later by a WASM adapter
([parser.go:16](../../../proxy/internal/llm/parser.go)).
## Files
| File | LOC | Notes |
|---|---:|---|
| `parser.go` | 104 | Interface + factories + `Provider{Unknown,OpenAI,Anthropic}` enum |
| `openai.go` | 347 | Chat Completions + Completions + Responses API; cached_tokens subset |
| `openai_test.go` | 222 | 11 tests; fixture replay + cached/Responses-API matrix |
| `anthropic.go` | 172 | Messages + legacy `/v1/complete`; cache_read + cache_creation additive |
| `anthropic_test.go` | 154 | 7 tests including streaming-extraction-skipped contract |
| `bedrock.go` | 190 | AWS Bedrock InvokeModel (snake_case) + Converse (camelCase) response shapes; model lives in URL path |
| `bedrock_test.go` | — | InvokeModel + Converse usage shapes; AWS event-stream content-type → `ErrStreamingUnsupported` on buffered `ParseResponse` |
| `sse.go` | 117 | `bufio`-backed scanner; CRLF normalised; trailing-event handling |
| `sse_test.go` | 175 | 12 tests; fixture replay + multiline + size limits |
| `parser_test.go` | 53 | `Parsers()`, `DetectParser`, provider enum values |
| `errors.go` | 31 | 6 sentinels: `Err{Unknown,Unsupported}Provider/Model`, `Err{NotLLM,Malformed}Response`, `ErrStreamingUnsupported`, `ErrMalformedRequest` |
| `pricing/pricing.go` | 421 | `Loader`, `Table`, `Entry`; embedded defaults + atomic swap + mtime reload |
| `pricing/pricing_unix.go` | 69 | `O_NOFOLLOW` + fstat-from-FD + 1 MiB cap |
| `pricing/pricing_other.go` | 21 | Stub returning "not supported on this platform" |
| `pricing/pricing_test.go` | 432 | 21 tests — symlink rejection, reload race, path traversal, oversize |
| `pricing/defaults_pricing.yaml` | 85 | go:embed source of truth |
| `fixtures/*` | 2159 | OAI chat/responses/stream + Anthro messages/stream + pricing starter |
## Request body → parser dispatch
```mermaid
flowchart TD
A[HTTP request<br/>URL + JSON body] --> B{ParserByName?<br/>provider_id config set}
B -- yes --> P[matched Parser]
B -- no --> C[DetectParser]
C --> D{loop Parsers<br/>OpenAIParser, AnthropicParser}
D -- DetectFromURL match --> P
D -- no match --> X[ok=false<br/>middleware skips]
P --> E[ParseRequest body]
E -->|err: ErrMalformedRequest| Y[middleware emits provider only]
E --> F[RequestFacts<br/>model + stream]
P --> G[ExtractPrompt body]
G --> H[joinMessages<br/>extractContentParts<br/>decodeStringOrJoin]
H --> I[prompt text<br/>or empty]
F --> J[stamps llm.model + llm.stream]
I --> K[stamps llm.request_prompt_raw<br/>subject to capture_prompt gate]
```
OpenAI's URL hints
([openai.go:2733](../../../proxy/internal/llm/openai.go)) include
both `/v1/chat/completions` and the bare `/chat/completions` — the latter
covers Cloudflare AI Gateway, which rewrites the canonical version segment.
Anthropic's hints are `/v1/messages` and `/v1/complete`
([anthropic.go:1417](../../../proxy/internal/llm/anthropic.go)).
Both implementations use case-insensitive substring matching so a proxy prefix
strip / rewrite doesn't defeat detection.
`ParserByName` ([parser.go:93103](../../../proxy/internal/llm/parser.go))
is the **agent-network bypass**: the synthesiser knows which parser to use
because it built the synth service from the catalog, so it stamps
`provider_id` on the parser config and the middleware skips URL sniffing
entirely. This is what makes the same parser set work whether the request
flows to OpenAI direct, to LiteLLM, to Portkey, or to any gateway with a
non-canonical URL shape.
**Path-routed providers (Vertex AI, Bedrock) bypass both `ParserByName` and
`DetectParser`.** The model and the parser surface live in the URL path, so the
request middleware extracts them directly (`parseVertexPath` /
`parseBedrockPath`) before the parser-selection step. For Vertex the publisher
segment picks the parser (`anthropic` → Anthropic parser; `google`/Gemini →
none, request denied as unmeterable). For Bedrock the dedicated `BedrockParser`
handles the response. Full treatment in
[50-path-routed-providers.md](./50-path-routed-providers.md).
## Streaming response → SSE chunker → response parser → completion + token count
```mermaid
sequenceDiagram
participant U as upstream LLM
participant LR as llm_response_parser<br/>(OnResponse)
participant S as llm.NewScanner<br/>(SSE framer)
participant P as Parser-specific accumulator<br/>(accumulateOpenAIStream<br/>or accumulateAnthropicStream)
U-->>LR: text/event-stream<br/>(buffered prefix in RespBody)
LR->>S: NewScanner(bytes.NewReader(body))
loop until EOF or [DONE]
S-->>LR: Event{Type, Data}
LR->>P: dispatch per event.Type<br/>(OpenAI: data-only<br/>Anthropic: named events)
P-->>P: accumulate completion text<br/>track usage from final frame
end
P-->>LR: llm.Usage + completion string
LR->>LR: appendUsage stamps<br/>llm.{input,output,total,cached_input,cache_creation}_tokens
LR->>LR: truncateCompletion(3500 bytes, rune-safe)
LR->>LR: redactPII if redact_pii && captureCompletion
```
`Scanner.Next`
([sse.go:4487](../../../proxy/internal/llm/sse.go)) returns one
event per `\n\n` boundary; multiple `data:` lines join with `\n`; comment lines
(starting with `:`) are skipped per the SSE spec; a trailing event without a
closing blank line is still returned before `io.EOF` so a server that closes
the connection cleanly doesn't lose the last frame
([sse.go:5558](../../../proxy/internal/llm/sse.go)). CRLF is
normalised in `trimEOL` so fixtures captured from live servers replay
unchanged.
## Per-provider
### OpenAI
[openai.go:5467](../../../proxy/internal/llm/openai.go) defines
`openAIRequest` with three prompt fields: `messages` (Chat Completions),
`prompt` (legacy), `input` (Responses API). The decoder uses
`json.RawMessage` so each shape is parsed lazily.
`ParseResponse`
([openai.go:117146](../../../proxy/internal/llm/openai.go))
accepts both naming conventions: Chat Completions returns
`prompt_tokens`/`completion_tokens`, Responses API returns
`input_tokens`/`output_tokens`. `pickInt64` prefers Responses-API names and
falls back — same parser handles both endpoints without per-route config.
`openAICachedTokens` mirrors the fallback for
`input_tokens_details.cached_tokens` vs `prompt_tokens_details.cached_tokens`.
**Key invariant:** `CachedInputTokens` for OpenAI is a SUBSET of
`InputTokens`. The cost meter clamps to guard against malformed upstream
responses where `cached > total`.
### Anthropic
[anthropic.go:3749](../../../proxy/internal/llm/anthropic.go)
defines `anthropicRequest` covering Messages API (`system` + `messages[]`)
and legacy `/v1/complete` (`prompt` string). `ExtractPrompt` emits
`system: <text>` first when present, then per-message `role: content`.
`ParseResponse`
([anthropic.go:82104](../../../proxy/internal/llm/anthropic.go))
fills three independent token buckets: `InputTokens`, `CacheReadInputTokens`,
`CacheCreationInputTokens`. Latter two are **additive** (not subset).
`TotalTokens` sums all four so downstream dashboards render one "tokens"
number without double-counting.
`ExtractCompletion` walks `content[]` `{type, text}` parts and concatenates
non-empty text with newlines, falling back to legacy `completion`.
### Bedrock
[bedrock.go](../../../proxy/internal/llm/bedrock.go) implements the
`Parser` interface for the AWS Bedrock runtime. Bedrock is **path-routed**: the
model lives in the URL (`/model/{id}/{action}`), so the request middleware
extracts it (see [50-path-routed-providers.md](./50-path-routed-providers.md))
and `ParseRequest` is a deliberate no-op. The parser's real work is on the
response leg, covering both Bedrock body shapes:
- **InvokeModel** — vendor-native. Anthropic-on-Bedrock returns snake_case usage
(`input_tokens`, `output_tokens`, `cache_read_input_tokens`,
`cache_creation_input_tokens`) with the same additive cache buckets as
first-party Anthropic.
- **Converse** — unified camelCase (`inputTokens`, `outputTokens`,
`totalTokens`). `firstNonZero` folds the two naming conventions into one
`Usage`; when Converse omits `totalTokens` the parser sums the buckets.
`ProviderName()` returns `"bedrock"` — its own `defaults_pricing.yaml` block,
keyed by the **normalised** model id (region prefix + version suffix stripped by
the request parser). `ParseResponse` returns `ErrStreamingUnsupported` for an
AWS binary event-stream content-type (`application/vnd.amazon.eventstream`,
`isAWSEventStream`) so the caller routes to the streaming accumulator instead.
### SSE framing
`Scanner` is `bufio`-backed, 64 KiB read buffer, 1 MiB max line so a
malicious upstream can't blow process memory
([sse.go:3338, 97100](../../../proxy/internal/llm/sse.go)).
`splitField` strips one space after the `:` per the SSE spec. Documented
`not safe for concurrent use`; every consumer creates a fresh scanner per
response body. Streaming accumulators live in the middleware package
([llm_response_parser/streaming.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go))
but use `llm.NewScanner` so the framing contract stays here.
### Pricing catalog
`Table.Cost`
([pricing.go:129174](../../../proxy/internal/llm/pricing/pricing.go))
is the cost formula — most security-relevant math in this module:
| Provider | Formula |
|---|---|
| `openai` | `(inTokens clamped) × InputPer1K + clamped × CachedInputPer1K + outTokens × OutputPer1K` where `clamped = min(cachedInput, inTokens)` |
| `anthropic`, `bedrock` | `inTokens × InputPer1K + cachedInput × CacheReadPer1K + cacheCreation × CacheCreationPer1K + outTokens × OutputPer1K` |
| default | `inTokens × InputPer1K + outTokens × OutputPer1K` |
`bedrock` shares the Anthropic additive-cache formula
([pricing.go:172-174](../../../proxy/internal/llm/pricing/pricing.go)):
Anthropic-on-Bedrock reports the same additive cache buckets, while non-Anthropic
Bedrock models (Nova, Llama) simply report zero in those buckets so cost reduces
to `input + output`.
Each per-bucket rate falls back to `InputPer1K` when zero — operators opt in
to discounts by setting the field.
`Loader`
([pricing.go:212268](../../../proxy/internal/llm/pricing/pricing.go))
overlays an optional `pricing.yaml` from data-dir on top of the go:embed
defaults. Atomic pointer swap means readers never observe a partial update.
The mtime-poll reloader (30s default cadence) keeps the previous table on
parse failure so cost annotation never goes blank during a botched edit.
`defaults_pricing.yaml` is the source of truth for built-in pricing.
Operator overrides only carry the entries they want to change.
## Public contracts
**`Parser` interface**
([parser.go:5066](../../../proxy/internal/llm/parser.go)):
```go
type Parser interface {
Provider() Provider
ProviderName() string
DetectFromURL(path string) bool
ParseRequest(body []byte) (RequestFacts, error)
ParseResponse(status int, contentType string, body []byte) (Usage, error)
ExtractPrompt(body []byte) string
ExtractCompletion(status int, contentType string, body []byte) string
}
```
Adding a provider means implementing this interface and appending to the
slice returned by `Parsers()` ([parser.go:7884](../../../proxy/internal/llm/parser.go)).
Order matters: `DetectFromURL` ties resolve by registration order.
`Parsers()` today returns `{OpenAIParser, AnthropicParser, BedrockParser}`.
**`Provider` enum**
([parser.go:818](../../../proxy/internal/llm/parser.go)):
`ProviderUnknown = 0`, `ProviderOpenAI = 1`, `ProviderAnthropic = 2`,
`ProviderBedrock = 3`. Numeric values are persisted in nothing today but treat
them as wire-stable — new providers must take fresh numbers.
**`Pricing` lookup**
([pricing.go:129](../../../proxy/internal/llm/pricing/pricing.go)):
```go
func (t *Table) Cost(provider, model string, inTokens, outTokens, cachedInput, cacheCreation int64) (float64, bool)
```
Nil-safe: `t.Cost` on a nil receiver returns `(0, false)`
([pricing.go:130132](../../../proxy/internal/llm/pricing/pricing.go)).
`ok=false` means provider or model is absent from the loaded table; the caller
emits `cost.skipped=unknown_model`.
## Invariants
1. **Cross-platform pricing build.** `pricing_unix.go` carries the only
functional `loadPricing` (uses `syscall.O_NOFOLLOW` and `f.Stat()` on an
open descriptor — both Unix-only). `pricing_other.go` is a build-tag
fallback that returns `"not supported on this platform"`
([pricing_other.go:1416](../../../proxy/internal/llm/pricing/pricing_other.go)).
The proxy is Linux-only in production today; a Windows port needs an
equivalent path-as-handle implementation. Reviewers building on Windows
should expect this surface to return an error at startup if an override
file is configured.
2. **SSE scanner handles partial chunks.** A buffered prefix that doesn't end
in `\n\n` still yields its accumulated event before `io.EOF`
([sse.go:5558](../../../proxy/internal/llm/sse.go)). Tests:
`TestSSEScanner_OpenAIFixture`, `TestSSEScanner_AnthropicFixture`,
`TestSSEScanner_MultilineData`, `TestSSEScanner_CRLF`. The streaming
accumulators ride on this: `accumulateAnthropicStream` and
`accumulateOpenAIStream` `break` on any scanner error to return partial
usage rather than aborting
([streaming.go:6873, 144150](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go)).
3. **`defaults_pricing.yaml` is the source of truth.** Compiled into the
binary via `//go:embed`
([pricing.go:2930](../../../proxy/internal/llm/pricing/pricing.go)).
`DefaultTable()` parses once and panics on parse failure
([pricing.go:4249](../../../proxy/internal/llm/pricing/pricing.go))
— by design: a broken embedded YAML must not ship to production.
4. **Loader path validation.** `resolveMiddlewareDataPath`
([pricing.go:370394](../../../proxy/internal/llm/pricing/pricing.go))
rejects absolute paths, traversal segments, and basenames that fail
`basenameRegex = ^[a-zA-Z0-9._-]+$`. The resolved path must remain
inside `baseDir` even after `filepath.Clean`. Tests:
`TestNewLoader_PathValidation`, `TestNewLoader_PathValidation_Extended`,
`TestNewLoader_SymlinkOutsideBaseDirRejected`, `TestNewLoader_SymlinkRejected`.
5. **Unix loader symlink safety.** `O_NOFOLLOW` on open, `f.Stat()` on the
open descriptor (never re-stat by path), `info.Mode().IsRegular()` check,
`io.LimitReader(f, maxPricingBytes+1)` with a final size assertion
([pricing_unix.go:2557](../../../proxy/internal/llm/pricing/pricing_unix.go)).
A mid-read symlink swap is detected because the fstat is on the original
fd. Test: `TestNewLoader_RejectsOversizedFile_FixesM4`.
6. **`yaml.NewDecoder(...).KnownFields(true)`**
([pricing.go:397398](../../../proxy/internal/llm/pricing/pricing.go))
rejects YAML files that carry fields not in the schema. A typo in an
operator override file fails loud instead of silently zeroing rates.
## Things to scrutinise
**Correctness.** Verify OpenAI cached-prompt clamp at
[pricing.go:147149](../../../proxy/internal/llm/pricing/pricing.go)
short-circuits before subtraction. `Anthropic.TotalTokens` sums all four
buckets (in + out + cache_read + cache_creation) — downstream dashboards
need to know this differs from `input + output`.
`OpenAIParser.ExtractPrompt` falls through `messages → input → prompt`; a
request sending all three reports only `messages` (uncommon but worth
noting).
**Security.** `Scanner.maxLine = 1 MiB`; a 2 MiB single-line `data:` event
errors from `Scanner.Next` and both accumulators stop with partial usage.
Pricing file 1 MiB cap is orders of magnitude larger than realistic. Confirm
new schema additions are mirrored in both `pricingFile` and `Entry`;
`KnownFields(true)` will reject silently-typo'd operator overrides
otherwise.
**Concurrency.** `Loader.table` is `atomic.Pointer[Table]`; readers never
block or see a torn table. `Loader.Reload` is one goroutine, cancelled via
context (`TestLoader_ReloadBackgroundLoopCancellation`). `DefaultTable()`
uses `sync.Once`. Per-call `Scanner` instances mean no shared state across
concurrent response-parser calls.
**Perf.** `Table.Cost` is two map lookups + multiplications, O(1).
`Scanner.Next` is one `ReadString('\n')` per line. Pricing reload poll 30s.
**Observability.** Reload failures count via `metric.Int64Counter` keyed
`plugin`; warning log rate-limited at 5 min so a broken file doesn't flood.
Parser errors return sentinels — middleware uses `errors.Is` to map to the
right `cost.skipped` reason.
## Test coverage
| File | Tests | Coverage highlights |
|---|---:|---|
| `parser_test.go` | 3 | `Parsers()` shape lock, `DetectParser` URL matrix, provider enum stability |
| `openai_test.go` | 11 | Chat Completions + Responses API + legacy `prompt`; cached-tokens subset for both naming conventions; fixture replays |
| `anthropic_test.go` | 7 | Messages + legacy `/v1/complete`; streaming REJECTED on `ParseResponse` (must use scanner); fixture replays |
| `sse_test.go` | 12 | Fixture replay both providers; multiline `data:`; CRLF; comment skip; trailing-event-without-blank-line; oversize rejection |
| `pricing/pricing_test.go` | 21 | Provider-shape switch; cached-rate fallback; cached-clamp; symlink rejection (target outside basedir + symlink to file); path validation matrix; oversize rejection; reload-keeps-previous-on-parse-error; mtime change detection; goroutine cancellation |
**Fixtures** ([proxy/internal/llm/fixtures/](../../../proxy/internal/llm/fixtures/)):
`openai_chat_completion.json` (chat.completions with usage),
`openai_responses.json` (Responses API shape),
`openai_stream.txt` (3 deltas + usage + `[DONE]`),
`anthropic_messages.json` (Messages API non-streaming),
`anthropic_stream.txt` (full 7-event sequence: message_start →
content_block_{start,delta×2,stop} → message_delta (usage) → message_stop),
`pricing.yaml` (realistic-pricing starter for operator overrides).
## Cross-references
- Sibling: [31-proxy-middleware-builtin.md](./31-proxy-middleware-builtin.md)
— the chain that calls `llm.Parsers()`, `llm.ParserByName`,
`llm.NewScanner`, `pricing.NewLoader`.
- Path-routed providers (Vertex AI + Bedrock), credential syntax, and the
Bedrock AWS event-stream accumulator:
[50-path-routed-providers.md](./50-path-routed-providers.md).
- Direct callers: `llm_request_parser/middleware.go:8294`,
`llm_response_parser/middleware.go:113123`,
`llm_response_parser/streaming.go:65, 142`, `cost_meter/factory.go:4957`.
- Related elsewhere: the agent-network synthesiser stamping `provider_id`
is covered in the management-side module guide; proxy server boot +
`FactoryContext` construction is covered in the proxy-framework guide.

View File

@@ -0,0 +1,194 @@
# proxy/runtime — translate + serve + log
> **Risk level:** High — every config push from management is translated here, and the chain runs on every HTTP request to a synth target.
> **Backward-compat impact:** Additive at the wire (`PathTargetOptions.middlewares`, `agent_network`, `disable_access_log`, capture caps) and on the proxy `Server` struct (`MiddlewareDataDir`, `MiddlewareCaptureBudgetBytes`). Non-agent-network targets stay on the no-middleware fast path.
## Module boundary
Turns the synth-service wire format from `ProxyService.SyncMappings`/`GetMappingUpdate` into in-process middleware chains and runs them on top of the existing `httputil.ReverseProxy`. Four concerns: (a) **translate**`proto.MiddlewareConfig` → validated `middleware.Spec` (proxy/middleware_translate.go) + self-register the eight built-ins (proxy/middleware_register.go); (b) **boot + rebuild** — construct the `middleware.Manager`, share the OTel meter, install the live-service check, rebuild per-path chains on every `addMapping`/`modifyMapping` (proxy/server.go); (c) **serve** — resolve chain at request time, capture bodies under a global budget, invoke `RunRequest`/`RunResponse`/`RunTerminal`, render deny responses, apply `UpstreamRewrite` (proxy/internal/proxy/reverseproxy.go); (d) **log + tag** — emit access-log entries with the new `agent_network` flag, gate emission on `EnableLogCollection` via `DisableAccessLog` (proxy/internal/accesslog).
**Inert for non-agent-network targets**: nil or empty chain → existing fast path (reverseproxy.go:127-139); `SuppressAccessLog` defaults false so the access-log middleware emits unchanged.
## Files
| Path | Role |
| ---- | ---- |
| proxy/middleware_translate.go | proto→Spec translation; slot/failmode/timeout mapping; caps |
| proxy/middleware_translate_test.go | translator unit tests |
| proxy/middleware_register.go | blank-imports the eight builtins for `init()` registration |
| proxy/server.go | `initMiddlewareManager`, `rebuildMiddlewareChains`, `isLiveService`, `buildMiddlewareBindings`, new Server fields, `protoToMapping` stamps AgentNetwork/DisableAccessLog/CaptureConfig/Middlewares |
| proxy/internal/proxy/reverseproxy.go | `WithMiddlewareManager`, chain dispatch, body capture, `applyUpstreamRewrite`/`Headers`, `buildRequestInput`, response-leg respInput identity fields |
| proxy/internal/proxy/reverseproxy_test.go | `TestBuildRequestInput_PropagatesIdentityAndGroups` |
| proxy/internal/proxy/context.go | `agentNetwork`, `suppressAccessLog`, `userGroupNames` on `CapturedData` |
| proxy/internal/proxy/servicemapping.go | new `PathTarget` fields |
| proxy/internal/proxy/agent_network_chain_realstack_test.go | end-to-end self-contained chain test |
| proxy/internal/accesslog/logger.go | `logEntry.AgentNetwork``proto.AccessLog` |
| proxy/internal/accesslog/middleware.go | reads `GetAgentNetwork()`; gates `l.log` on `!GetSuppressAccessLog()` |
| proxy/internal/accesslog/middleware_test.go | suppress/default/preserves-usage assertions |
| proxy/internal/auth/middleware_test.go | tunnel-peer group propagation contract |
| proxy/internal/metrics/metrics.go | `Meter()` getter for the middleware manager |
## Architecture & flow
### Synth-service ingestion → translate → register → serve
```mermaid
flowchart TD
A[Management SyncMappings/GetMappingUpdate] --> B["processMappings\nserver.go:1492"]
B --> C{Mapping type}
C -->|CREATED| D["addMapping → setupHTTPMapping → updateMapping"]
C -->|MODIFIED| E["modifyMapping → cleanupMappingRoutes → setupHTTPMapping → updateMapping"]
C -->|REMOVED| F["removeMapping → cleanupMappingRoutes → invalidateMiddlewareChains"]
D --> G["protoToMapping\nserver.go:2181"]
E --> G
G --> H["translateMiddlewareConfigs\nmiddleware_translate.go:55"]
G --> I["translateMiddlewareCaptureConfig\nmiddleware_translate.go:18"]
H --> J["[]middleware.Spec on PathTarget"]
I --> K["*bodytap.Config on PathTarget"]
J --> L["proxy.AddMapping\nservicemapping.go:118"]
K --> L
L --> M["rebuildMiddlewareChains\nserver.go:2017 → Manager.Rebuild"]
F --> N["Manager.Invalidate(serviceID)"]
```
### Per-request lifecycle through the chain + accesslog
```mermaid
sequenceDiagram
autonumber
participant C as Client
participant M as accesslog.Middleware
participant A as auth.Middleware (Protect)
participant RP as ReverseProxy.ServeHTTP
participant CH as middleware.Chain
participant U as Upstream
C->>M: HTTP request
M->>M: NewCapturedData(requestID), WithCapturedData(ctx)
M->>A: next.ServeHTTP
A->>A: Private → ValidateTunnelPeer → stamp UserID/Email/Groups/GroupNames/AuthMethod
A->>RP: next.ServeHTTP
RP->>RP: findTargetForRequest → targetResult
RP->>RP: stamp ServiceID/AccountID/AgentNetwork/SuppressAccessLog on CapturedData
RP->>RP: resolveChain via Manager.ChainFor
alt chain == nil or Empty
RP->>U: httputil.ReverseProxy.ServeHTTP (fast path)
else chain non-empty
RP->>RP: bodytap.CaptureRequest (global budget)
RP->>CH: RunRequest
CH-->>RP: denyOutput? requestMeta + upstreamRewrite
alt deny
RP->>C: RenderDenyResponse
else allow
RP->>RP: capturingWriter + applyUpstreamRewrite/Headers
RP->>U: httputil.ReverseProxy.ServeHTTP(respWriter)
U-->>RP: response
RP->>CH: RunResponse (respInput carries UserGroups)
RP->>CH: RunTerminal (merged request+response metadata)
end
end
RP-->>M: handler returns
M->>M: build logEntry incl. AgentNetwork
alt SuppressAccessLog == true
M->>M: skip l.log; still trackUsage
else default
M->>M: l.log → goroutine SendAccessLog
end
```
### EnableLogCollection suppression path
```mermaid
flowchart LR
S["agentnetwork.Settings.EnableLogCollection"] --> B["synthesizer: target.DisableAccessLog = !EnableLogCollection"]
B --> P["proto PathTargetOptions.disable_access_log (field 13)"]
P --> T["protoToMapping reads GetDisableAccessLog()\nserver.go:2211"]
T --> M["PathTarget.DisableAccessLog\nservicemapping.go:47"]
M --> R["ServeHTTP: cd.SetSuppressAccessLog\nreverseproxy.go:106"]
R --> G["accesslog middleware: if !GetSuppressAccessLog l.log\nmiddleware.go:95"]
R --> U["trackUsage unconditional — bandwidth telemetry preserved"]
```
**Ingestion** lands as a `ProxyMapping` batch on `handleSyncMappingsStream`/`handleMappingStream`. `processMappings` dispatches to `addMapping`/`modifyMapping`/`removeMapping`; HTTP goes `setupHTTPMapping → updateMapping → protoToMapping`. `protoToMapping` (server.go:2181) is the single translation surface that materialises `[]middleware.Spec`, `*bodytap.Config`, `AgentNetwork`, `DisableAccessLog` onto each `PathTarget`; `updateMapping` finishes with `s.proxy.AddMapping(m)` (atomic swap under `mappingsMux`) and `s.rebuildMiddlewareChains(svcID, m)`.
At **request time** the access-log middleware stamps `CapturedData`; the auth chain runs (Private services lift `peer_group_ids` from `ValidateTunnelPeer` — auth/middleware_test.go:322). `ReverseProxy.ServeHTTP` resolves the chain; nil or empty → original `httputil.ReverseProxy`, no body capture. When a chain matches, body is captured under the global budget, `RunRequest` produces an `UpstreamRewrite` (`llm_router` selects a provider, rewrites scheme/host/path, injects `Authorization`), and `RunResponse`+`RunTerminal` run after the upstream returns. The terminal slot sees the merged metadata bag — that's how `llm_limit_record` ships the consumption sample. The **access-log** addition: `logEntry.AgentNetwork` from `GetAgentNetwork()` onto `proto.AccessLog.AgentNetwork`; the gate at middleware.go:95 honors `EnableLogCollection`, skipping `l.log` but keeping `trackUsage` so bandwidth telemetry survives.
## Public contracts touched
- `proxy.Server.MiddlewareDataDir` (string) — base dir for file-backed middleware config (server.go:238-241).
- `proxy.Server.MiddlewareCaptureBudgetBytes` (int64) — process-wide capture cap; defaults to 256 MiB (server.go:248-250).
- `proxy/internal/proxy.WithMiddlewareManager(*middleware.Manager) Option` — new option on `NewReverseProxy`; nil keeps the fast path (reverseproxy.go:48-56).
- `proxy/internal/proxy.PathTarget` adds `Middlewares`, `CaptureConfig`, `AgentNetwork`, `DisableAccessLog` (servicemapping.go:27-51), all zero-default.
- `proxy/internal/proxy.CapturedData` adds `agentNetwork`, `suppressAccessLog`, `userGroupNames` behind `sync.RWMutex`; slices deep-copied (context.go:47-66, 183-258).
- `accesslog.logEntry.AgentNetwork` + `proto.AccessLog.AgentNetwork` (logger.go:131, 268).
- `metrics.Metrics.Meter()` exposes the OTel meter for the middleware manager (metrics.go:53-58).
## Invariants
- **Synth-service updates are live (no proxy restart).** Every `MODIFIED` flows through `modifyMapping → cleanupMappingRoutes` (invalidates chains) `→ setupHTTPMapping → updateMapping → rebuildMiddlewareChains`. **ProxyMapping.Private preservation:** the relevant logic lives in `management/internals/shared/grpc/proxy.go:shallowCloneMapping`, not this module, but it surfaces here — if a `MODIFIED` synth service arrives `private=false`, auth skips `ValidateTunnelPeer`, `CapturedData.UserGroups` stays empty, and `llm_router` denies with `llm_policy.no_authorised_provider` until a management restart re-pushes the snapshot. This module assumes `mapping.GetPrivate()` is correct on every batch.
- **`EnableLogCollection=false` suppresses access-log writes but middleware still runs.** Gate is one `if !cd.GetSuppressAccessLog()` immediately around `l.log(entry)` (middleware.go:95); `trackUsage` runs below the gate. Locked by `TestMiddleware_SuppressAccessLog_PreservesUsageTracking` (middleware_test.go:139).
- **`agent_network` flag on access-log entries is set when the chain processed the request.** Source `target.AgentNetwork`, stamped at reverseproxy.go:105, read at accesslog/middleware.go:86.
- **auth → builtin group propagation.** `Protect` writes `UserGroups`/`UserGroupNames`; `buildRequestInput` (reverseproxy.go:333) copies them into `middleware.Input`. The response-leg `respInput` (reverseproxy.go:196-223) also carries `UserEmail`/`UserGroups`/`UserGroupNames``llm_limit_record` needs `UserGroups` to ship `group_ids` so management's group-targeted budget rules match (comment at reverseproxy.go:211-215).
- **Empty chains stay on the fast path.** `ServeHTTP` skips body capture and the run sequence when `chain == nil || chain.Empty()` (reverseproxy.go:127).
- **Self-registration is the only way a builtin reaches the registry.** `middleware_register.go` blank-imports each builtin; `init()` adds the factory to `mwbuiltin.DefaultRegistry()`. Missing it → translator drops the entry with a warn (translate.go:97).
## Things to scrutinize
### Correctness
- **Translate edge cases** — drops on nil cfg, empty ID, unknown ID, UNSPECIFIED slot; each logs one warn; volume bounded by `MaxMiddlewaresPerChain`.
- **Re-translate without dropping in-flight requests** — `Manager.Rebuild` is the only call from `rebuildMiddlewareChains`. Reverse proxy reads `ChainFor` once per request (reverseproxy.go:327) and runs the captured `*Chain` for the whole request. Verify in module 30 that `Rebuild` swaps atomically.
- **ProxyMapping.Private preservation** — enforced management-side in `shallowCloneMapping`. Proxy-side regression catches: `TestProtect_PrivateService_TunnelPeerGroupsPropagate` + the integration test.
- **Body-capture cleanup** — `defer releaseBudget()` (reverseproxy.go:145) and `defer capturingWriter.Release()` (reverseproxy.go:180) must run on every return; confirm no future `return` lands between acquisition and defer.
- **`applyUpstreamRewrite` clones the URL** — `cloned := *orig` value-copies `*url.URL`; safe because overwritten fields are strings, not slices/maps (reverseproxy.go:285-292).
### Security
- **Translate validates every config** — registry membership rejects unknown IDs; UNSPECIFIED slot drops; ID-less drops; raw config copied (not aliased) at translate.go:109.
- **`AuthHeader`/`StripHeaders` only reachable via `UpstreamRewrite`** — regular mutation surface goes through the framework denylist (`Authorization`/`Cookie` blocked); only the router middleware can replace `Authorization` (reverseproxy.go:296-304). Confirm in module 30 nothing outside the proxy-trusted path populates `UpstreamRewrite.AuthHeader`.
- **`stampNetBirdIdentity` strips client-sent values first** (reverseproxy.go:742-743) — anti-spoof for `X-NetBird-User`/`X-NetBird-Groups`; control chars filtered; comma-bearing labels dropped (reverseproxy_test.go:1217/:1243/:1193).
- **Auth → group propagation** — `auth/middleware_test.go:322` and `:366` cover the contract. If auth ever stops calling `ValidateTunnelPeer` for Private services, every agent-network request silently denies.
### Concurrency
- **Chain replacement under in-flight requests** — `findTargetForRequest` takes `mappingsMux.RLock`; `AddMapping` writes. `resolveChain` calls `ChainFor` once; even if `Rebuild` swaps mid-request, in-flight requests keep running on the captured pointer.
- **`CapturedData` mutation across slots** — accessors take `sync.RWMutex`; slices deep-copied on both Set and Get. Verify no caller mutates the returned slice expecting it to land back.
- **`Manager.Invalidate` race** — `removeMapping` invalidates after `cleanupMappingRoutes`; mapping read happens before chain resolution, so requests before invalidate run captured chains; later ones fail `findTargetForRequest`.
- **`Logger.log` goroutine** — `logSem` caps at `maxLogWorkers = 4096`; overflow → `dropped.Add(1)` + debug log. Middleware test uses a buffered channel and 150ms negative-assertion window — review whether 150ms holds on slow CI.
### Backward compatibility
- **Non-agent-network services unaffected** — `protoToMapping` reads new fields only when `opts != nil`; defaults leave `Middlewares`/`CaptureConfig` nil → chain resolves nil → fast path. Existing `reverseproxy_test.go` (non-chain) still passes.
- **`disable_access_log` is proto field 13, default false** — every existing target unset; gate is no-op. Locked by `TestMiddleware_SuppressAccessLog_DefaultEmitsLog` (middleware_test.go:104).
- **`Server` additions optional** — 256 MiB default when `MiddlewareCaptureBudgetBytes ≤ 0` (server.go:1997-2000).
### Performance
- **Translate cost per push** — O(n) with per-entry registry lookup and `config_json` copy; negligible vs. the upstream gRPC unmarshal.
- **Empty-chain hot path** — one `ChainFor` map lookup + one `chain.Empty()` check; no allocation delta vs. pre-PR.
- **Body capture buffer churn** — `bodytap.CaptureRequest` allocates `MaxRequestBytes` per chain-hitting request; `releaseBudget` ties allocation to the 256 MiB proxy-wide budget. Confirm in module 30 the budget is a hard cap.
### Observability
- **Metrics** — `Metrics.Meter()` shared with `middleware.NewMetrics` (server.go:1990-1993) so middleware instruments land in the same prometheus exporter. No new metrics defined here.
- **Access-log accuracy** — every entry carries `AgentNetwork`; terminal-slot metadata merged into `CapturedData.Metadata` (reverseproxy.go:238-241).
- **Deny logs at `Infof`** (reverseproxy.go:170) — review whether `Info` is too noisy at high deny rates; consider Debug or rate-limit.
## Test coverage
| Test file | Locks down |
| --------- | ---------- |
| proxy/middleware_translate_test.go | Empty/nil → nil; field preservation; unknown ID skip; nil registry permissive; timeout clamping; fail-mode + slot incl. UNSPECIFIED-drop; empty-ID drop; truncation above + at `MaxMiddlewaresPerChain` |
| proxy/internal/proxy/reverseproxy_test.go | Rewrite host/headers/cookies/query; trusted proxy; path forwarding; classifyProxyError; X-NetBird-User/Groups anti-spoof + CSV-join + control-char/comma rejection + fallback-to-ID; `TestBuildRequestInput_PropagatesIdentityAndGroups` (UserGroups/Email/GroupNames/AgentNetwork reach `middleware.Input`) |
| proxy/internal/proxy/agent_network_chain_realstack_test.go | **The end-to-end integration test.** Drives a real agent-network request through `ReverseProxy.ServeHTTP` with the chain the synthesizer produces, against an in-process management gRPC (bufconn) backed by a real sqlite store + real `agentnetwork.Manager`, plus an `httptest` upstream — no external infrastructure or real LLM. Guarantees: (1) response-leg `respInput` carries `UserGroups` so `llm_limit_record` ships non-empty `group_ids` and the admin-group consumption row increments; (2) `RedactPii=true` redacts both prompt and completion on captured metadata; (3) the full chain runs against a real management stack. **Line 189-211 inlines the proto→Spec mapping** instead of calling the proxy's private `translateMiddlewareConfig` — keep that inline mirror in sync with `proxy/middleware_translate.go` or the test silently diverges from production. |
| proxy/internal/accesslog/middleware_test.go | `SuppressAccessLog=true` skips `SendAccessLog` (150ms negative wait); default emits one send (2s positive); usage tracking runs under suppression |
| proxy/internal/auth/middleware_test.go | `TestProtect_PrivateService_TunnelPeerGroupsPropagate` proves `peer_group_ids` reach `CapturedData.UserGroups`; `TestProtect_PrivateService_TunnelPeerDenied` proves rejected peers 403 without reaching the handler |
The integration test runs in a few seconds with no external infrastructure — exercising the real synthesizer, `Manager.Rebuild`, `ServeHTTP` dispatch, and `llm_limit_record` writing a real consumption row through the real `agentnetwork.Manager` over real gRPC.
## Known limitations / explicit non-goals
- **Translator does not validate `RawConfig` JSON** — factory's job at `New([]byte)`. Confirm in module 30 that a per-binding factory failure doesn't poison the rest of the chain.
- **No throttle on management push rate** — every `MODIFIED` triggers `Manager.Rebuild`. Mitigation upstream.
- **Streaming responses (SSE)** — body capture is streaming-aware, but response-leg middleware runs only after the response completes; long SSE streams delay `llm_limit_record` until close.
- **OIDC-only path doesn't carry tunnel-peer groups** — agent-network synth services rely on the Private tunnel-peer path; JWT groups claim is the only carrier for non-Private OIDC.
- **`agent_network` flag on L4 entries** not added; HTTP-only.
- **`mw.capture.bypass_reason` metadata key** documented at reverseproxy.go:151,184; namespace this in module 30/31 to avoid collisions.
## Cross-references
- Upstream: [shared/api](10-shared-api.md), [proxy/middleware-framework](30-proxy-middleware-framework.md), [proxy/middleware-builtin](31-proxy-middleware-builtin.md), [proxy/llm-parsers](32-proxy-llm-parsers.md)
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
- Top-level: [../00-overview.md](../00-overview.md)

View File

@@ -0,0 +1,228 @@
# dashboard — UI for agent-networks
This module documents code that lives in the **dashboard repo** (under
`src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`), not
in this repo. It is co-located here so backend readers see the full picture.
> **Risk level:** Medium. The new surface is isolated under `src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`, but it also reshapes the sidebar, splits `/peers`, renames `reverse-proxy/clusters` → `self-hosted-proxies`, and overlays the Control Center graph. Regressions here would be cross-cutting.
> **Backward-compat impact:** Additive on the API side. Breaking on URL/navigation: `/peers` redirects to `/peers/devices` (src/app/(dashboard)/peers/page.tsx:7-15), `/reverse-proxy/clusters` was renamed to `/reverse-proxy/self-hosted-proxies`, the sidebar lost Access Control / Networks / Reverse Proxy / DNS / standalone Guardrails / Consumption / Activity (Navigation.tsx:165-171 — routes still resolve via URL), and the standalone `/agent-network/{access-log,consumption,global-controls}` routes are gone in favor of `/agent-network/observability`.
## Module boundary
The dashboard is the only place an operator interacts with agent-networks: provider catalog, configured providers, policies, guardrails, account-level budget rules, account settings (collection / redaction toggles), per-request access log, and consumption rollups all render, paginate, and edit here. Data flows in via SWR (`useFetchApi`) keyed by REST URL. One big context provider (`src/modules/agent-network/AIProvidersProvider.tsx`) aggregates five resources (providers, policies, guardrails, budget rules, settings) plus the proxy access-log stream filtered to `agent_network=true`, and exposes `add* / update* / toggle* / delete*` mutators that call through `useApiCall` and re-`mutate()` SWR. Pages mount the provider once at the top and compose presentational tables and modals beneath. The control-center page additionally fetches `/agent-network/{providers,policies}` directly (control-center/page.tsx:123-130) to overlay graph nodes.
## What the UI delivers
- **AI Observability** page with four tabs: Access Logs, Budget Dashboard,
Budget Settings, Log Settings (replaces the standalone access-log,
consumption, and global-controls routes).
- **Providers** page: provider catalog + connect/edit wizard with per-vendor
copy (LiteLLM, Portkey, Bifrost, Cloudflare, Vercel, OpenRouter, custom).
- **Policies** page: group → provider authorization with per-policy Limits
(minute-granular windows) + guardrail attach.
- **Guardrails** page: reusable model-allowlist + prompt-capture sets.
- **Account controls**: Log Collection / Prompt Collection / Redact PII toggles.
- **Budget rules**: account-level rules reusing the policy Limits UI.
- **Control Center overlay**: provider + agent-policy nodes on the graph.
- **Navigation + peers reshaping**: peers split into Devices / Agents,
`reverse-proxy/clusters` renamed to `self-hosted-proxies`, sidebar
repackaged for agent-network focus.
## Surface added
### New pages
| Route | Purpose | Backing module(s) |
| ----- | ------- | ----------------- |
| `/agent-network` | Redirect to `/agent-network/providers` | page.tsx:7-15 |
| `/agent-network/providers` | List + connect providers; header surfaces per-account base URL | providers/page.tsx + AgentProvidersTable + AIProviderModal |
| `/agent-network/policies` | Group → Provider authorization with per-policy Limits + Guardrail attach | policies/page.tsx + AgentPoliciesTable + AgentPolicyModal |
| `/agent-network/guardrails` | Reusable guardrail sets (model allowlist + prompt capture) | guardrails/page.tsx + AgentGuardrailsTable + AgentGuardrailModal |
| `/agent-network/observability` | Tabs: Access Logs / Budget Dashboard / Budget Settings / Log Settings | observability/page.tsx |
| `/peers/devices`, `/peers/agents` | Split of `/peers`, shared via `PeersListView` keyed by `kind` | peers/{devices,agents}/page.tsx |
| `/reverse-proxy/self-hosted-proxies` | Renamed from `clusters` | self-hosted-proxies/page.tsx |
Removed in favor of `/agent-network/observability`: `/agent-network/access-log`, `/agent-network/consumption`, `/agent-network/global-controls`.
### New modules under src/modules/agent-network
| File | Role |
| ---- | ---- |
| AIProvidersProvider.tsx (~1158 LOC) | Aggregates every agent-network resource via SWR; normalises snake↔camel; exposes mutators; holds wizard-open state |
| AIProviderModal.tsx (~1268 LOC) | Connect / edit provider wizard with per-vendor copy (Bifrost, Portkey, LiteLLM, Cloudflare, Vercel, OpenRouter, custom) |
| AIProviderLogo + useProviderCatalog | Catalog-driven brand swatch + SWR hook over `/agent-network/catalog/providers` |
| AgentPoliciesTable + AgentPolicyModal + AgentPolicyGuardrailsTab + AgentPolicyLimitsTab | Policies; modal has 3 tabs (Rule, Limits, Guardrails) |
| AgentGuardrailsTable + AgentGuardrailModal + AgentGuardrailBrowseModal + AgentGuardrailChecksCell | Guardrails CRUD + attach-from-policy |
| AgentBudgetRulesTable + AgentBudgetRuleModal | Account-level budget rules; modal reuses AgentPolicyLimitsTab verbatim |
| AgentAccountControlsCard | Three account-wide toggles (Log Collection / Prompt Collection / Redact PII) |
| AgentAccessLogTable + AgentAccessLogExpandedRow | Access log on `/events/proxy?agent_network=true` |
| AgentConsumptionPanel + AgentConsumptionTable | Token + cost panel: charts + counter table |
| table/AgentProvidersTable + AgentProviderActionCell | Providers table + per-row actions |
| data/mockData.ts | Domain types and a few residual `MOCK_*` constants (see scrutinize) |
### Touched non-agent-network areas
- **control-center**: agent-network overlay (provider + agent-policy nodes); removed the All Networks dropdown; hid the Networks tab in FlowSelector (FlowSelector.tsx:9-14 — enum value kept so `?tab=networks` still type-checks); wrapped `ControlCenterView` in `AIProvidersProvider` (page.tsx:73-83); `agentPolicyNode` clicks routed to a separate state slot (page.tsx:1871-1874). New node renderers: nodes/ProviderNode.tsx, nodes/AgentPolicyNode.tsx (registered at utils/nodes.ts:21-22).
- **peers**: Split into Devices and Agents sub-routes; shared via `PeersListView` keyed by `kind` (PeersListView.tsx:24-95). New compact-toolbar `UserFilterSelector` (users/UserFilterSelector.tsx).
- **reverse-proxy**: Folder rename `clusters/``self-hosted-proxies/`; deleted `ClustersFeaturesCell.tsx`, `ClusterTypeIndicator.tsx`; new ReverseProxyClusterTargetSelector for cluster target type; Private toggle on target modal; body-capture knobs removed; new ReverseProxyEventExpandedRow.
- **events**: `ReverseProxyEventsUserCell` rewritten with user + peer fallback (ReverseProxyEventsUserCell.tsx:14-21), shared with the access-log table.
- **navigation**: Full repackaging in Navigation.tsx — Agent Network items flattened (no collapsible parent), distinct icons per item; Access Control, Networks, Reverse Proxy, DNS, standalone Guardrails, Consumption, Activity removed (still URL-reachable, per lines 165-171).
## Architecture & flow
### Page → Provider → Table/Modal hierarchy
```mermaid
graph TD
Nav[Navigation.tsx]
Nav --> ProvidersPage[/agent-network/providers/]
Nav --> PoliciesPage[/agent-network/policies/]
Nav --> GuardrailsPage[/agent-network/guardrails/]
Nav --> ObsPage[/agent-network/observability/]
ProvidersPage --> AIPP1[AIProvidersProvider]
PoliciesPage --> AIPP2[AIProvidersProvider]
GuardrailsPage --> AIPP3[AIProvidersProvider]
ObsPage --> AIPP4[AIProvidersProvider]
ObsPage -.wraps.-> GroupsProvider
ObsPage -.wraps.-> PeersProvider
AIPP1 --> ProvTable[AgentProvidersTable]
ProvTable --> ProvModal[AIProviderModal]
AIPP2 --> PolTable[AgentPoliciesTable]
PolTable --> PolModal[AgentPolicyModal]
PolModal --> PolGuardTab[AgentPolicyGuardrailsTab]
PolModal --> PolLimitsTab[AgentPolicyLimitsTab]
PolGuardTab --> GuardBrowse[AgentGuardrailBrowseModal]
PolGuardTab --> GuardModal[AgentGuardrailModal]
AIPP3 --> GuardTable[AgentGuardrailsTable]
GuardTable --> GuardModal
AIPP4 --> Tabs[Tabs]
Tabs --> AccessLog[AgentAccessLogTable]
Tabs --> Consumption[AgentConsumptionPanel]
Tabs --> BudgetRules[AgentBudgetRulesTable]
Tabs --> AccountCtl[AgentAccountControlsCard]
BudgetRules --> BudgetModal[AgentBudgetRuleModal]
BudgetModal -.reuses.-> PolLimitsTab
```
### AI Observability tab page
```mermaid
graph LR
Page[AIObservabilityPage] --> RA[RestrictedAccess<br/>permission.services.read]
RA --> GP[GroupsProvider]
GP --> PP[PeersProvider]
PP --> AIP[AIProvidersProvider]
AIP --> Tabs[Tabs / TabsList]
Tabs --> T1[Access Logs<br/>AgentAccessLogTable]
Tabs --> T2[Budget Dashboard<br/>AgentConsumptionPanel]
Tabs --> T3[Budget Settings<br/>AgentBudgetRulesTable]
Tabs --> T4[Log Settings<br/>AgentAccountControlsCard]
T1 -.GET.-> EP[/events/proxy?agent_network=true/]
T2 -.GET poll 5s.-> CONS[/agent-network/consumption/]
T3 -.GET/PUT.-> BR[/agent-network/budget-rules/]
T4 -.GET/PUT.-> ST[/agent-network/settings/]
```
### Data fetch path
```mermaid
graph TD
Page[Page component] --> Prov[AIProvidersProvider]
Prov -->|useFetchApi| SWR[(SWR cache<br/>key = URL)]
SWR -.GET.-> P[/agent-network/providers/]
SWR -.GET.-> POL[/agent-network/policies/]
SWR -.GET.-> G[/agent-network/guardrails/]
SWR -.GET.-> BR[/agent-network/budget-rules/]
SWR -.GET ignoreError.-> ST[/agent-network/settings/]
SWR -.GET.-> CAT[/agent-network/catalog/providers/]
SWR -.GET pageSize=100.-> EVT[/events/proxy agent_network=true/]
Prov --> Mut[useApiCall.post/put/del]
Mut -.on success.-> MutateSWR[SWR mutate keys]
Prov --> Children[Tables / Modals via useAIProviders]
```
Every list view reaches management through SWR over `/api/agent-network/*`. The provider context maps snake-case payloads to camelCase domain types (`fromAPI`, `policyFromAPI`, `guardrailFromAPI`, `budgetRuleFromAPI`, `settingsFromAPI`, `accessLogFromAPI` — AIProvidersProvider.tsx:138-562) and back via matching `*ToRequest` adaptors. The access log piggy-backs on `/events/proxy` with `agent_network=true&page_size=100` (line 707-709) and decodes LLM-specific fields from per-event `metadata`. Group IDs on events are resolved to current names through the surrounding GroupsProvider catalog (lines 515-521, 717-731) — no extra round trip. Mutators run `*ToRequest`, await `useApiCall.post/put/del`, call SWR `mutate()`, then `notify`. Errors caught and surfaced via `notify` — no exceptions escape into render. The Connect Provider modal's open state lives in the provider itself (`isWizardOpen` at lines 732-735) so the providers-page empty-state CTA and the table's + button share one modal. Control-center re-fetches `/agent-network/{providers,policies}` directly on top of `AIProvidersProvider` — SWR de-dupes but the code path is harder to reason about.
## Public contracts consumed
- `GET/POST /api/agent-network/providers`, `PUT/DELETE /:id`
- `GET/POST /api/agent-network/policies`, `PUT/DELETE /:id`
- `GET/POST /api/agent-network/guardrails`, `PUT/DELETE /:id`
- `GET/POST /api/agent-network/budget-rules`, `PUT/DELETE /:id`
- `GET/PUT /api/agent-network/settings` (ignoreError-tolerant; 404 = not yet bootstrapped — auto-bootstrap on first provider create via `bootstrap_cluster` field — AIProvidersProvider.tsx:737-760)
- `GET /api/agent-network/catalog/providers` (read-only declarative; backend owns vendor list, IDs, brand colors, models, extra_headers, identity_injection — useProviderCatalog.ts:6-95)
- `GET /api/agent-network/consumption` (polled every 5s on Budget Dashboard — ConsumptionPanel.tsx:53,65-71)
- `GET /api/events/proxy?agent_network=true&page_size=100` (shared with Proxy Events)
- `permission?.services?.read` gates every agent-network route via RestrictedAccess.
`AIProviderId` is a closed union in dashboard types (data/mockData.ts:8-21) but the converter tolerates anything the backend ships — unknown ids fall through to `"custom"` (AIProvidersProvider.tsx:497-506). Catalog values are pure read-through: anything declared in `extra_headers` renders in the modal automatically, copy keyed by header name (`EXTRA_HEADER_UI` in AIProviderModal.tsx:61-89), labeled-fallback for unknown ones.
## Invariants
- Provider context wrap order on user-attribution pages: `GroupsProvider > PeersProvider > AIProvidersProvider` (observability/page.tsx:87-89). Reverse it and access-log group resolution silently drops names.
- Every agent-network route checks `permission?.services?.read` via `RestrictedAccess` (observability/page.tsx:85, providers/page.tsx:184, policies/page.tsx:53, guardrails/page.tsx:55).
- Modal `key={open ? 1 : 0}` pattern is used to force unmount/remount on close so internal `useState` resets between edits (AgentBudgetRuleModal.tsx:60, AgentPolicyModal.tsx:66). Removing this would leak prior-row state into a new-row session.
- `mockData.ts` is the canonical home for ALL agent-network domain types; `MOCK_*` constants must never reach a production code path. One leak remains (below).
## Things to scrutinize
### Correctness
- **Tab-state URL hand-off is one-way.** observability/page.tsx:53-58 reads `?tab=` on mount (despite the file comment at line 28 saying URL hand-off is future) but `setTab` does NOT push back, so reload preserves the chosen tab only if it came in via the link. Inconsistent with control-center (page.tsx:1817-1831).
- **Provider overlay runs only in `applySingleGroupView` / `applyPeerView`** (control-center/page.tsx:557, 1159-1166). User view does NOT show providers — if agent-network is a primary lens, that's a gap.
- **Two useEffects race to invalidate the control-center layout.** page.tsx:1655-1657 drops `layoutInitialized` when `agentPolicies` / `agentProviders` arrive; the main effect (1786-1799) also lists them as deps. Functional but fragile — watch for flash-of-empty-graph.
- **`updateProvider` / `updatePolicy` / `updateBudgetRule` use `??` on `enabled`** (AIProvidersProvider.tsx:784, 859, 1018). Toggle paths are safe; any caller sending `enabled: false` thinking "leave it off" gets `existing.enabled` instead. Audit modal callers.
- **Form validation in modals is minimal.** Window-seconds picker — mockData.ts:209-215 documents "minimum 60 — one minute" but there is no matching UI guard in PolicyLimitsTab; the backend validator is the enforcement point.
### Security
- **No client-side enforcement claims** — every cap, allowlist, and toggle is display + edit; proxy is the source of truth for deny decisions (AccessLogTable.tsx:177-191 renders backend-emitted `denyReason` as-is).
- **Prompt display is gated by what the backend stamps.** When `enable_prompt_collection` is OFF the proxy must not put prompt/completion into event metadata; the dashboard renders whatever it gets verbatim (AccessLogTable lines 532-534, AccessLogExpandedRow.tsx:42-57). No UI filter on top of backend collection switches.
- Account Controls disables `Redact PII` when `Prompt Collection` is off (AgentAccountControlsCard.tsx:122) and clears it on off-transition (line 100), but relies on backend to enforce the same gate at write — confirm PUT handler rejects `redact_pii=true && enable_prompt_collection=false`.
- **Bifrost identity-header overrides**: empty-string vs nil semantics documented in AIProvidersProvider.tsx:772-781 ("omitted = preserve, empty = explicit clear"). Mishandling could leak group attribution to a header the operator thought disabled. Focused read of Bifrost code path in AIProviderModal.tsx recommended.
### Accessibility
- Observability TabsList (observability/page.tsx:96-113) uses the shared Tabs component — should inherit Radix roving-tabindex. All four TabsTriggers carry only icon + text, no `aria-label`; fine because text is visible.
- Modal focus traps are inherited from the shared Modal; agent-network modals don't override them. Quick keyboard pass recommended.
- `EndpointBadge` Copy button (providers/page.tsx:66-76) has an `aria-label`, good.
### Performance
- `AgentConsumptionPanel` polls `/agent-network/consumption` every 5s (ConsumptionPanel.tsx:53,70). Tab switches unmount the panel, so the poll stops — verify in network panel.
- `AgentAccessLogTable` is hard-capped at 100 rows via `page_size=100` (AIProvidersProvider.tsx:707-709). Server-side pagination is future work; high-traffic tenants miss everything past row 100 — known limitation.
- Observability page mounts providers ONCE at page level (observability/page.tsx:87-89); tab switches keep SWR cache hot. Moving the provider mount inside `TabsContent` would re-fetch the access log on every switch.
### Visual consistency
- The observability tab style mirrors peers/page.tsx. Outer Tabs `pt-4 pb-0 mb-0`, TabsList `px-8` (observability/page.tsx:94-96) — confirm chrome height matches so the page doesn't visually jump.
- Sidebar: `Boxes` for Providers, `AccessControlIcon` for Policies, `TelescopeIcon` for AI Observability (Navigation.tsx:113,120,133). Reusing `AccessControlIcon` makes Policies look identical to the (now hidden) Access Control item — if Access Control ever comes back, they collide.
- `AgentNetworkIcon` is used in breadcrumbs on every agent-network page but NOT in the sidebar (per-page icons instead). Deliberate departure — record so it doesn't get reverted.
## Test coverage
- **Cypress**: One file (`cypress/e2e/test.cy.ts`) covering only the install-page copy-to-clipboard flow. NOTHING covers agent-network UI.
- **Component / unit tests**: `src/utils/version.test.ts` is the only `.test.*` file in the repo. The agent-network modules ship without component tests.
- Data-cy hooks exist on key controls: `save-account-controls` (AgentAccountControlsCard.tsx:71), `enable-log-collection`, `enable-prompt-collection`, `redact-pii`, plus existing `data-cy={policy.name}` / `data-cy={provider.name}` on ActiveInactiveRow. Sufficient hooks for Cypress flows; none written yet.
- **Tooling gap (pre-existing):** `npm run lint` (`next lint`) is broken in Next 16 — the `lint` subcommand was removed from the Next CLI in 16.x, so the dashboard effectively has no working lint gate. The fix is to add either a flat-config `eslint .` script or wire ESLint via an explicit `eslint-config-next` invocation.
## Known limitations / explicit non-goals
- **`data/mockData.ts` still contains `MOCK_GROUPS`, `MOCK_PROVIDERS`, `MOCK_PEERS`.** Only `MOCK_GROUPS` is referenced from production — AgentPoliciesTable.tsx:45,76 uses it as a name-lookup fallback when a policy references a group ID the real GroupsProvider doesn't know about. `MOCK_PROVIDERS` / `MOCK_PEERS` are unreferenced; safe to delete. The file is `/* eslint-disable */` so dead-code warnings don't flag them.
- **Tab-state URL hand-off on observability page is one-way** (read-only).
- **Access log hard-capped at 100 rows**; no server-side pagination.
- **No optimistic updates.** All mutations are round-trip; failures rollback via SWR revalidation.
- **`FlowView.NETWORKS` retained but hidden** from FlowSelector (FlowSelector.tsx:9-14). Old `?tab=networks` links still route to the hidden view because `applyNetworksView` still runs.
- **Redirects are not query-preserving** — `router.replace("/peers/devices")` (peers/page.tsx:13) strips any incoming filter params.
- **Control-center cross-fetches** `/agent-network/{providers,policies}` directly on top of `AIProvidersProvider`. Could be collapsed.
- **Sidebar permanently hides Access Control, Networks, Reverse Proxy, standalone Guardrails, DNS, Activity, Consumption.** Routes still resolve via URL (Navigation.tsx:165-171); intentional.
## Cross-references
- Upstream API contracts: [shared/api](10-shared-api.md)
- Backend persistence: [management/store](20-management-store.md)
- Backend handler wiring: [management/handlers + wiring](22-management-handlers-wiring.md)
- End-to-end flow narrative: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
- Top-level overview: [../00-overview.md](../00-overview.md)

View File

@@ -0,0 +1,251 @@
# path-routed providers — Vertex AI + Bedrock
This guide pulls the **path-routed** provider story together in one place
because it crosses the catalog, the synthesiser, the request parser, and the
router. The relevant building blocks are the `llm_router` /
`llm_request_parser` middlewares
([31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md)), the
per-provider parser surface ([32-proxy-llm-parsers.md](32-proxy-llm-parsers.md)),
and the synthesiser's catalog → `ProviderRoute` mapping
([21-management-agentnetwork.md](21-management-agentnetwork.md)).
Sibling modules: [31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md)
(router + request parser) and [32-proxy-llm-parsers.md](32-proxy-llm-parsers.md)
(Bedrock parser + pricing).
---
## What "path-routed" means
Most catalog providers carry the model in the request **body** (`{"model": …}`),
so `llm_router` selects an upstream by matching the model name against each
provider's `Models` claim. Two providers instead carry the model in the **URL
path**, so they are routed by path before the model/vendor table is consulted:
| Catalog id | Style flag | Request path shape |
|---|---|---|
| `vertex_ai_api` | `IsVertexPathStyle``ProviderRoute.Vertex` | `/v1/projects/{project}/locations/{region}/publishers/{publisher}/models/{model}:{action}` |
| `bedrock_api` | `IsBedrockPathStyle``ProviderRoute.Bedrock` | `/model/{modelId}/{action}` (optionally behind `/bedrock`) |
The catalog declares the style with
[`catalog.IsVertexPathStyle` / `catalog.IsBedrockPathStyle`](../../../management/server/agentnetwork/catalog/catalog.go)
and the synthesiser copies the result onto the router route as the `Vertex` /
`Bedrock` booleans
([synthesizer.go:450-451](../../../management/server/agentnetwork/synthesizer.go)).
On the request leg `llm_router.Invoke` dispatches `isVertexPath` / `isBedrockPath`
**before** the model lookup
([llm_router/middleware.go:138-216](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
so a model the parser extracted from the path can't be claimed by a same-vendor
*body-routed* provider (e.g. `claude-*` on `api.anthropic.com`).
## Google Vertex AI (`vertex_ai_api`)
### Catalog entry
`KindProvider`, parser surface left unset on the catalog entry — the request
parser picks the parser from the URL **publisher** segment, not from
`ParserID`. Upstream host is `<region>-aiplatform.googleapis.com`
(`https://aiplatform.googleapis.com` for the `global` location). The catalog
lists the Claude-on-Vertex lineup (`claude-opus-4-*`, `claude-sonnet-4-*`,
`claude-haiku-4-5`, `claude-fable-5`) at the same per-token rates as the
first-party Anthropic entry
([catalog.go:333-363](../../../management/server/agentnetwork/catalog/catalog.go)).
### Credential — service-account OAuth (`keyfile::`)
Vertex does **not** accept a static API key. The operator sets the provider
`api_key` to:
```
keyfile::<base64 of the GCP service-account JSON key>
```
The synthesiser recognises the `keyfile::` prefix in `providerAuthHeader`
([synthesizer.go:897-903](../../../management/server/agentnetwork/synthesizer.go)),
emits **no** static auth value, and carries the base64 key material on the
route as `GCPServiceAccountKeyB64`
([factory.go:56-61](../../../proxy/internal/middleware/builtin/llm_router/factory.go)).
At request time the router mints a short-lived OAuth2 access token from the key
(cloud-platform scope) and injects `Authorization: Bearer <access-token>`
never the key itself
([llm_router/middleware.go:621-692](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)):
- One auto-refreshing `oauth2.TokenSource` is cached per key (keyed by a
SHA-256 of the base64 material), so token minting happens once and refreshes
amortise across requests.
- Mint / refresh is bounded by a 10s timeout HTTP client (`gcpTokenTimeout`) so
a slow Google token endpoint can't hang the request.
- A malformed key or an unreachable token endpoint fails the request with
`llm_policy.upstream_auth_failed` at HTTP **502** (an upstream problem, not a
policy denial) — see `denyUpstreamAuth`.
### Metering — Anthropic-on-Vertex only
The request parser extracts `{publisher, model, action}` from the path
(`parseVertexPath`, [llm_request_parser/middleware.go:237-263](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)),
strips the `@version` suffix from the model, and maps the publisher to a parser
surface via `vertexPublisherVendor`:
- `anthropic``llm.provider="anthropic"` → metered through the Anthropic
parser, priced under the **`anthropic`** block in `defaults_pricing.yaml`
(the parser emits the standard Anthropic provider label, so Vertex Claude
reuses first-party Anthropic prices).
- `openai``llm.provider="openai"` (reserved; not in the catalog lineup
today).
- anything else (notably `google` / Gemini) → empty vendor → **no parser**.
**Gemini is intentionally denied as unmeterable.** When the parser emits no
`llm.provider` for a Vertex publisher, `llm_router` returns
`llm_policy.unmeterable_publisher` (403) rather than forwarding the request
uncounted — serving it would bypass token / budget metering
([llm_router/middleware.go:144-162, 712-728](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)).
A Gemini parser would lift this restriction; until then the `google` publisher
is omitted from the catalog.
> Caveat: cross-region inference profiles in `eu` / `apac` carry a ~10% price
> premium that the base per-token rates do **not** model — cost annotations for
> those regions read low. Operators who need exact regional billing override
> the affected entries in `pricing.yaml`.
## AWS Bedrock (`bedrock_api`)
### Catalog entry
`KindProvider`, upstream host `bedrock-runtime.<region>.amazonaws.com`. Metered
models are the Anthropic-on-Bedrock lineup (`anthropic.claude-*`) plus Amazon
Nova and Llama 3.3 entries
([catalog.go:300-332](../../../management/server/agentnetwork/catalog/catalog.go)).
Anthropic-on-Bedrock reuses the first-party Claude prices (with additive cache
buckets); Nova / Llama report no cache, so cost is `input + output`.
### Credential — static bearer token
Bedrock uses the **AWS Bedrock API key** as a static bearer. The operator sets
the provider `api_key` directly (no `keyfile::` prefix); the catalog template
is `Authorization: Bearer ${API_KEY}`
([catalog.go:306-307](../../../management/server/agentnetwork/catalog/catalog.go)).
No token minting — the synthesiser substitutes the key into the template and
the router injects the resulting `Authorization` header after stripping inbound
vendor auth (including client-supplied AWS SigV4 material: `X-Amz-Date`,
`X-Amz-Security-Token`, `X-Amz-Content-Sha256`, see `strippedAuthHeaders`).
### Model id form — cross-region inference profiles
Bedrock model ids in the request path must be the cross-region
**inference-profile** form, e.g.
`eu.anthropic.claude-sonnet-4-5-20250929-v1:0`. The bare
`anthropic.claude-…` id is rejected by AWS. `normalizeBedrockModel`
([llm_request_parser/middleware.go:398-414](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go))
strips the region prefix (`us.` / `eu.` / `apac.` / `global.`), an optional ARN
wrapper, and the `-YYYYMMDD-vN[:N]` version/throughput suffix so the normalised
id (`anthropic.claude-sonnet-4-5`) matches the catalog/pricing key.
### Supported endpoints + actions
`/model/{modelId}/{action}` where action ∈ `invoke`,
`invoke-with-response-stream`, `converse`, `converse-stream`
([llm_request_parser/middleware.go:363-390](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)).
`invoke` / `converse` are non-streaming; the `-stream` actions set the streaming
flag.
- **InvokeModel** body uses the vendor-native shape — for Anthropic that means
`"anthropic_version":"bedrock-2023-05-31"` and snake_case usage with additive
cache buckets.
- **Converse** uses the unified camelCase shape with a precomputed `totalTokens`.
- The `BedrockParser` reads both shapes on the response leg
([bedrock.go](../../../proxy/internal/llm/bedrock.go)); the request parser
doesn't need to distinguish them (`ParseRequest` is a no-op — model + stream
come from the path).
### Streaming — AWS binary event-stream
The `-stream` actions return `application/vnd.amazon.eventstream` (the AWS
binary event-stream framing), and streaming **is metered**.
`accumulateBedrockStream`
([llm_response_parser/streaming_bedrock.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go))
decodes the frames with `aws-sdk-go-v2/aws/protocol/eventstream`:
- InvokeModel `chunk` frames wrap a base64 `{"bytes":…}` payload carrying a
vendor-native (Anthropic) stream event — folded through the shared Anthropic
stream accumulator.
- Converse `contentBlockDelta` frames carry text; the trailing `metadata` frame
carries the final usage block.
- A truncated stream (cut at the body-tap capture cap) decodes best-effort:
frames up to the cut are applied and partial usage is returned.
### Optional `/bedrock` gateway-namespace prefix
Clients may place an optional `/bedrock` prefix before the native path
(`/bedrock/model/{modelId}/{action}`) to disambiguate Bedrock from other
providers that also use `/model/...`. Both the request parser
(`trimBedrockNamespace`) and the router (`splitBedrockNamespace`) accept it.
When the prefix is present, the router sets
`RewriteUpstream.StripPathPrefix = "/bedrock"` so the **native** path
(`/model/...`) is what reaches `bedrock-runtime.<region>.amazonaws.com`
([llm_router/middleware.go:168-184, 320-348](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)).
## Model allowlist on path-routed providers
Because the model lives in the URL rather than the body, a path-routed provider
credential could otherwise be used for any model the upstream supports. The
router still enforces the route's `Models` allowlist via `matchPathRoute`
([llm_router/middleware.go:370-416](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)):
1. Filter to routes of the matching style (`Vertex` / `Bedrock`).
2. Filter to routes whose `AllowedGroupIDs` authorise the caller's groups
(else `no_authorised_provider`).
3. Filter to routes that **claim the requested model**. As with body-routed
providers, an **empty `Models` list = catch-all** (serve any model);
a non-empty list serves only the listed models (else `model_not_routable`).
4. Multiple survivors disambiguate by longest `UpstreamPath` prefix match.
So an operator who lists explicit models on a Vertex/Bedrock provider gets a
hard allowlist; an operator who leaves `Models` empty accepts every model the
upstream serves (still subject to the unmeterable-publisher gate on Vertex).
Model-less OpenAI endpoints (`GET /v1/models`) are **never** routed to a
Vertex/Bedrock provider — `matchModelless` skips path-routed routes
([llm_router/middleware.go:427-462](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
so a model-listing call can't be rewritten onto an upstream that would 404 it.
## Catalog ↔ pricing cross-check
Catalog prices and context windows are cross-checked against LiteLLM's
`model_prices_and_context_window.json`. The proxy's embedded
`defaults_pricing.yaml` covers **every metered first-party model** the catalog
enumerates — guarded by
`TestDefaultTable_FirstPartyModelCoverage`
([pricing/defaults_coverage_test.go](../../../proxy/internal/llm/pricing/defaults_coverage_test.go)),
which fails if a catalog model has no embedded price. Bedrock entries are keyed
by the **normalised** id the request parser emits (region prefix + version
suffix stripped). Vertex Claude carries no Bedrock-style prefix, so it prices
straight off the `anthropic` block.
## Things to scrutinise
**Security.** The Vertex service-account key is never forwarded — only a minted
short-lived bearer. Confirm the key material stays out of access logs (it lives
on `ProviderRoute.GCPServiceAccountKeyB64`, not in any emitted metadata key).
The unmeterable-publisher deny is the only thing standing between an
operator-misconfigured Vertex provider and unmetered Gemini traffic; verify
`vertexPublisherVendor` stays conservative (deny by default for unknown
publishers).
**Correctness.** `normalizeBedrockModel` is the join between the wire id and the
pricing key — a model that normalises to something not in `defaults_pricing.yaml`
meters at `cost.skipped=unknown_model` rather than failing the request. The
`/bedrock` prefix strip must run on both the parser side (so the model is
extracted) and the router side (so the upstream path is native); a regression in
either silently breaks the other.
**Metering caveats.** eu/apac cross-region Bedrock + Vertex profiles carry a
~10% premium not modelled by base pricing — flagged in both the catalog comment
and `defaults_pricing.yaml`. Operators needing exact regional billing override
the relevant entries.
## Cross-references
- Router + request-parser detail: [31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md)
- Bedrock parser + pricing + SSE / event-stream: [32-proxy-llm-parsers.md](32-proxy-llm-parsers.md)
- Catalog → route synthesis + `keyfile::` handling: [21-management-agentnetwork.md](21-management-agentnetwork.md)
- Overview: [../00-overview.md](../00-overview.md)

View File

@@ -1,78 +0,0 @@
# Privileged tests
Some tests in this repo need `root` or mutate host network state: they create
TUN/WireGuard interfaces, open netlink/raw sockets, run eBPF programs, or shell
out to `ip`/`iptables`/`nft`/`ifconfig`/`route`. Running them on a developer
machine would require `sudo` and could leave stray interfaces or routes behind.
These tests are gated behind the **`privileged` build tag** so the default test
run is host-safe.
## Running tests
```bash
# Host-safe: excludes privileged tests. Runs as a normal user, no sudo.
make test-unit
# equivalently:
go test -tags devcert ./...
# Privileged suite: runs the privileged-tagged tests inside a
# --privileged --cap-add=NET_ADMIN container (requires Docker).
make test-privileged
# Narrow the container run to a single test / package:
PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
```
`PRIV_RUN` adds a `-run` test-name filter and `PRIV_PKGS` overrides the package
list; both are optional and default to the full privileged suite.
`make test-privileged` invokes the `ory/dockertest` harness in
`client/testutil/privileged/`. The harness:
1. Skips immediately when it detects it is already inside the container
(`DOCKER_CI=true`), so the privileged tests run in place instead of recursing.
2. Otherwise spins up a `golang:1.25-alpine` container (matching CI),
bind-mounts the repo and the host Go build/module caches, installs the
required packages, and runs `go test -tags 'devcert privileged'` over the
client packages.
3. Streams the container's output to the test log and fails if the suite fails.
## Adding a privileged test
A test is privileged if it does any of:
- creates a real interface via `iface.NewWGIFace(...).Create()`,
- opens a netlink or raw socket that hard-fails without `CAP_NET_ADMIN`,
- runs an eBPF program (`ebpf.*.Listen()`),
- shells out to `ip`, `iptables`, `nft`, `ifconfig`, or `route` to change state.
Add the tag to the **top** of the file, combined with any existing platform
constraint:
```go
//go:build privileged && linux
package foo
```
If a file mixes privileged and pure-logic tests, **split it**: keep the pure
tests (and any shared data — type/var declarations, table-driven `testCases`,
helper interfaces) in an untagged file, and move the privileged tests into a
`*_privileged_test.go` file with the tag. Shared declarations must stay untagged,
otherwise the unprivileged files in the package will not compile.
Always verify both build modes compile on every target platform:
```bash
go vet -tags devcert ./...
go vet -tags 'devcert privileged' ./...
```
## CI
- The `Client / Unit` job runs `go test -tags devcert` with **no** `sudo` — only
host-safe tests.
- The `Client (Docker) / Unit` job runs `go test -tags 'devcert privileged'`
inside a `--privileged --cap-add=NET_ADMIN` container, which is where the
privileged tests actually execute.

11
go.mod
View File

@@ -35,6 +35,7 @@ require (
github.com/DeRuina/timberjack v1.4.2
github.com/awnumar/memguard v0.23.0
github.com/aws/aws-sdk-go-v2 v1.38.3
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1
github.com/aws/aws-sdk-go-v2/config v1.31.6
github.com/aws/aws-sdk-go-v2/credentials v1.18.10
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
@@ -78,12 +79,10 @@ require (
github.com/mdp/qrterminal/v3 v3.2.1
github.com/miekg/dns v1.1.72
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/moby/moby/api v1.54.1
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/ory/dockertest/v4 v4.0.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203
@@ -147,7 +146,7 @@ require (
dario.cat/mergo v1.0.1 // indirect
filippo.io/edwards25519 v1.1.1 // indirect
github.com/AppsFlyer/go-sundheit v0.6.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/Azure/go-ntlmssp v0.1.0 // indirect
github.com/BurntSushi/toml v1.5.0 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
@@ -158,7 +157,6 @@ require (
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
github.com/awnumar/memcall v0.4.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6 // indirect
@@ -179,8 +177,6 @@ require (
github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
@@ -275,12 +271,11 @@ require (
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/moby/client v0.4.0 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.5.0 // indirect
github.com/moby/sys/user v0.3.0 // indirect
github.com/moby/sys/userns v0.1.0 // indirect
github.com/moby/term v0.5.2 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect

24
go.sum
View File

@@ -23,8 +23,8 @@ github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
github.com/AppsFlyer/go-sundheit v0.6.0/go.mod h1:LDdBHD6tQBtmHsdW+i1GwdTt6Wqc0qazf5ZEJVTbTME=
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A=
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
@@ -117,10 +117,6 @@ github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
@@ -484,10 +480,6 @@ github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zx
github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/moby/api v1.54.1 h1:TqVzuJkOLsgLDDwNLmYqACUuTehOHRGKiPhvH8V3Nn4=
github.com/moby/moby/api v1.54.1/go.mod h1:+RQ6wluLwtYaTd1WnPLykIDPekkuyD/ROWQClE83pzs=
github.com/moby/moby/client v0.4.0 h1:S+2XegzHQrrvTCvF6s5HFzcrywWQmuVnhOXe2kiWjIw=
github.com/moby/moby/client v0.4.0/go.mod h1:QWPbvWchQbxBNdaLSpoKpCdf5E+WxFAgNHogCWDoa7g=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
github.com/moby/sys/sequential v0.5.0 h1:OPvI35Lzn9K04PBbCLW0g4LcFAJgHsvXsRyewg5lXtc=
@@ -496,8 +488,8 @@ github.com/moby/sys/user v0.3.0 h1:9ni5DlcW5an3SvRSx4MouotOygvzaXbaSrc/wGDFWPo=
github.com/moby/sys/user v0.3.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
@@ -550,8 +542,6 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/ory/dockertest/v4 v4.0.0 h1:i19aFsO/VXE0VrMk4ifnKW4G/KIJ93PCjLOslxXoPME=
github.com/ory/dockertest/v4 v4.0.0/go.mod h1:b5Ofu8VIxWNhXFvQcLu17pRNQdoUBKtXBW74G4Ygzx8=
github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs=
github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
@@ -983,13 +973,11 @@ gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDa
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM=
howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=

View File

@@ -398,7 +398,42 @@ configure_domain() {
return 0
}
apply_agent_network_preset() {
# Agent-network turnkey install: built-in Traefik + NetBird Proxy with
# NB_PROXY_PRIVATE=true, dashboard locked to agent-network-only mode.
# Bypasses every reverse-proxy / proxy / CrowdSec prompt. The only
# inputs we still need from the operator are the domain (handled by
# configure_domain via NETBIRD_DOMAIN env var or interactive prompt)
# and the ACME email — both honor env vars first and fall back to a
# prompt only when unset. CrowdSec is intentionally off.
REVERSE_PROXY_TYPE="0"
ENABLE_PROXY="true"
ENABLE_CROWDSEC="false"
if [[ -n "${NETBIRD_LETSENCRYPT_EMAIL}" ]]; then
TRAEFIK_ACME_EMAIL="${NETBIRD_LETSENCRYPT_EMAIL}"
else
TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email)
fi
echo "" > /dev/stderr
echo "Agent-network preset enabled (NETBIRD_AGENT_NETWORK=true):" > /dev/stderr
echo " - reverse proxy: built-in Traefik" > /dev/stderr
echo " - NetBird Proxy: enabled with NB_PROXY_PRIVATE=true" > /dev/stderr
echo " - dashboard: NETBIRD_AGENT_NETWORK_ONLY=true" > /dev/stderr
echo " - CrowdSec: disabled" > /dev/stderr
echo " - Let's Encrypt email: ${TRAEFIK_ACME_EMAIL}" > /dev/stderr
echo "" > /dev/stderr
}
configure_reverse_proxy() {
# Short-circuit: agent-network preset locks every reverse-proxy /
# proxy / CrowdSec choice and bypasses the interactive prompts.
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
apply_agent_network_preset
return 0
fi
# Prompt for reverse proxy type
REVERSE_PROXY_TYPE=$(read_reverse_proxy_type)
@@ -910,6 +945,15 @@ NGINX_SSL_PORT=443
# Letsencrypt
LETSENCRYPT_DOMAIN=none
EOF
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
cat <<EOF
# Agent-network preset: dashboard hides the standard NetBird surfaces
# and exposes only the AI Observability + agent-network configuration
# pages. Paired with NB_PROXY_PRIVATE=true on the proxy side.
NETBIRD_AGENT_NETWORK_ONLY=true
EOF
fi
return 0
}
@@ -946,6 +990,17 @@ NB_PROXY_PROXY_PROTOCOL=true
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
EOF
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
cat <<EOF
# Agent-network preset: turn the proxy into the private reverse-proxy
# ingress for agent-network synth services. Disables the public-facing
# surface so the proxy serves only synth-generated routes (the
# llm_router-driven LLM endpoints) and the per-account inbound
# listeners on the embedded netstack.
NB_PROXY_PRIVATE=true
EOF
fi
if [[ "$ENABLE_CROWDSEC" == "true" && -n "$CROWDSEC_BOUNCER_KEY" ]]; then
cat <<EOF
NB_PROXY_CROWDSEC_API_URL=http://crowdsec:8080

View File

@@ -116,6 +116,24 @@ func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, p
c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer)
}
// injectAllProxyPolicies prepares an account for the per-peer network-map
// computation. It prepends the in-memory agent-network services synthesised
// from the account's current provider/policy state to account.Services so
// the existing InjectProxyPolicies + injectPrivateServicePolicies walks pick
// them up alongside persisted reverse-proxy services. Synthesised services
// are never persisted; the account is loaded fresh per cycle so re-prepending
// is safe and idempotent. Accounts without agent-network providers get an
// empty synth slice — no behaviour change.
func (c *Controller) injectAllProxyPolicies(ctx context.Context, account *types.Account) {
synth, err := c.repo.SynthesizeAgentNetworkServices(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Warnf("synthesise agent-network services for account %s: %v", account.Id, err)
} else if len(synth) > 0 {
account.Services = append(synth, account.Services...)
}
account.InjectProxyPolicies(ctx)
}
func (c *Controller) CountStreams() int {
return c.peersUpdateManager.CountStreams()
}
@@ -150,7 +168,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
account.InjectProxyPolicies(ctx)
c.injectAllProxyPolicies(ctx, account)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
@@ -281,7 +299,15 @@ func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID s
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
account.InjectProxyPolicies(ctx)
// The affected-peer path MUST mirror sendUpdateAccountPeers (line 171)
// here: injectAllProxyPolicies prepends the synthesised agent-network
// services BEFORE InjectProxyPolicies + private-service policies run.
// Previously this path called only account.InjectProxyPolicies, which
// skipped the synth-services prepend — so peer-level changes
// (proxy restart, embedded peer connect/disconnect) propagated a
// network map that omitted the synth DNS zone, and the agent kept
// resolving against the stale or absent record.
c.injectAllProxyPolicies(ctx, account)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
@@ -399,7 +425,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return fmt.Errorf("failed to get validated peers: %v", err)
}
account.InjectProxyPolicies(ctx)
c.injectAllProxyPolicies(ctx, account)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
@@ -497,7 +523,7 @@ func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID st
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s with reason %s/%s", len(peerIDs), accountID, util.GetCallerName(), reason.Operation, reason.Resource)
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
peerIDs: make(map[string]struct{}),
@@ -603,19 +629,17 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return nil, nil, 0, err
}
account.InjectProxyPolicies(ctx)
c.injectAllProxyPolicies(ctx, account)
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, 0, err
}
startPosture := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, peerID)
if err != nil {
return nil, nil, 0, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
@@ -876,7 +900,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
return nil, err
}
account.InjectProxyPolicies(ctx)
c.injectAllProxyPolicies(ctx, account)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()

View File

@@ -3,7 +3,9 @@ package controller
import (
"context"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -16,6 +18,10 @@ type Repository interface {
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error)
// SynthesizeAgentNetworkServices returns the in-memory reverse-proxy
// services synthesised from the account's agent-network provider/policy
// state. Empty for accounts without agent-network providers.
SynthesizeAgentNetworkServices(ctx context.Context, accountID string) ([]*service.Service, error)
}
type repository struct {
@@ -50,6 +56,10 @@ func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID s
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
func (r *repository) SynthesizeAgentNetworkServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return agentnetwork.SynthesizeServices(ctx, r.store, accountID)
}
func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) {
return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
}

View File

@@ -0,0 +1,15 @@
package agentnetwork
import "github.com/netbirdio/netbird/management/server/affectedpeers"
// init registers the agent-network service synthesiser with the affectedpeers
// resolver. Agent-network reverse-proxy services are synthesised on demand and
// never persisted, so the resolver can't load them from the store; without them
// it can't fold the embedded proxy peer into the affected set on a client
// group/peer change, and the proxy never learns a newly authorised client until
// it reconnects. Registered here (rather than via a direct
// affectedpeers→agentnetwork import) to avoid an import cycle
// (agentnetwork → account → affectedpeers).
func init() {
affectedpeers.SetAgentNetworkSynthesizer(SynthesizeServices)
}

View File

@@ -0,0 +1,749 @@
// Package catalog defines the static set of Agent Network providers
// recognized by the management server. The catalog is consulted both to
// validate provider_id on create/update and to surface the available
// providers (and their models) to the dashboard.
package catalog
import "github.com/netbirdio/netbird/shared/management/http/api"
// Model is the in-memory representation of a catalog model.
type Model struct {
ID string
Label string
InputPer1k float64
OutputPer1k float64
ContextWindow int
}
// ProviderKind groups catalog entries for UI presentation. The split
// is semantic, not technical:
// - KindProvider: the upstream is a vendor's first-party API (OpenAI,
// Anthropic, Mistral, Bedrock, etc.) — NetBird talks straight to
// the model provider.
// - KindGateway: the upstream is itself a routing / aggregation layer
// in front of multiple providers (LiteLLM, Portkey, Helicone, …).
// These typically need NetBird identity stamped onto upstream
// requests so the gateway's analytics and budgets attribute to the
// real caller; that's what IdentityInjection is for.
// - KindCustom: the catch-all "OpenAI-compatible self-hosted endpoint"
// entry (vLLM, Ollama, custom inference servers).
//
// Frontend uses Kind to group the provider Select in the modal so an
// operator can spot at a glance which catalog entries proxy other
// providers vs. talk straight to one. Backend doesn't dispatch on Kind
// today; it's purely a presentation hint.
type ProviderKind string
const (
KindProvider ProviderKind = "provider"
KindGateway ProviderKind = "gateway"
KindCustom ProviderKind = "custom"
)
// Provider is the in-memory representation of a catalog provider.
type Provider struct {
ID string
Name string
Description string
DefaultHost string
// Kind groups this entry for UI presentation; see ProviderKind.
Kind ProviderKind
// AuthHeaderName is the HTTP header the provider's API expects
// the credential under (e.g. "Authorization" for OpenAI,
// "x-api-key" for Anthropic). Combined with AuthHeaderTemplate
// at synthesis time to inject the auth header on every upstream
// request.
AuthHeaderName string
AuthHeaderTemplate string
DefaultContentType string
BrandColor string
// ParserID names the proxy LLM parser surface this provider
// speaks (matches llm.Parser.ProviderName: "openai",
// "anthropic"). Multiple catalog ids may share a parser surface
// (e.g. azure_openai_api and mistral_api both speak the OpenAI
// shape). Empty when no parser is yet implemented for the
// surface — the proxy middleware then falls back to URL sniffing
// or skips request-side enrichment.
ParserID string
// IdentityInjection, when non-nil, instructs the proxy to stamp
// the caller's NetBird identity onto upstream requests under the
// configured header names. Used for gateways like LiteLLM that
// key budgets and attribution off request headers (the gateway
// otherwise has no way to learn which user / group made the call).
// The proxy strips the same header names from the inbound request
// before stamping ours, so an app can't spoof identity by setting
// these headers itself.
IdentityInjection *IdentityInjection
// ExtraHeaders is a catalog-declared list of additional per-
// provider routing/config headers the proxy stamps on every
// upstream request. Distinct from AuthHeaderName/Template (which
// always carries the API_KEY) and from IdentityInjection (caller
// identity). Each entry surfaces an optional input on the
// dashboard's provider modal whose value lives on the provider
// record's ExtraValues map (keyed by ExtraHeader.Name). Empty
// list = no extra inputs rendered. Used today by Portkey for
// "x-portkey-config: pc-..." (a saved-config id that resolves
// upstream provider + credentials on Portkey's hosted side).
ExtraHeaders []ExtraHeader
Models []Model
}
// ExtraHeader names a single optional per-provider routing/config
// header. Catalog declares N of these per provider type; the operator
// fills any subset on the provider record (see Provider.ExtraValues).
// At synth time, only entries with a non-empty operator value are
// stamped; the proxy's identity-inject middleware applies anti-spoof
// (Remove + Add) so a client can't supply these headers themselves.
//
// UI copy (label / help text / tooltip) for each known Name lives on
// the dashboard, not here — the backend's job is just to declare
// which wire headers are accepted. New provider needs an extra
// header? Add the Name here AND the matching UI copy on the dashboard.
type ExtraHeader struct {
// Name is the wire header name, e.g. "x-portkey-config".
Name string
}
// IdentityInjection describes how the proxy stamps NetBird identity onto
// upstream gateway requests. Exactly one shape must be set — they're
// mutually exclusive and dispatched by the inject middleware.
//
// Shape choice tracks the wire convention the upstream gateway uses,
// not the vendor name. New gateways with a known shape become a catalog
// entry, not a new code path.
type IdentityInjection struct {
// HeaderPair emits separate headers per identity dimension
// (end-user id, tags as CSV). LiteLLM and OpenAI-compatible
// self-hosted gateways that read identity from dedicated headers.
HeaderPair *HeaderPairInjection
// JSONMetadata emits a single header carrying a JSON object with
// reserved keys for user / groups / etc. Portkey, Helicone-style
// metadata headers, anything that wants a structured envelope.
JSONMetadata *JSONMetadataInjection
}
// HeaderPairInjection is the LiteLLM-style wire convention.
type HeaderPairInjection struct {
// Customizable, when true, marks the wire header names as
// operator-overridable: the dashboard surfaces EndUserIDHeader
// and TagsHeader as editable inputs (defaults shown as
// placeholders) and the synthesizer pulls the actual values from
// the provider record's IdentityHeader* fields rather than from
// these defaults. An empty operator value disables stamping for
// that dimension. Used today for Bifrost, whose log-metadata /
// telemetry header prefix (x-bf-lh-* vs x-bf-dim-*) is a
// per-operator choice; LiteLLM and similar gateways with a fixed
// wire protocol leave this false so the catalog defaults are
// authoritative.
Customizable bool
// EndUserIDHeader receives the caller's display identity (user
// email when the peer is attached to a user, else peer.Name),
// e.g. "x-litellm-end-user-id".
EndUserIDHeader string
// TagsHeader receives the caller's NetBird group display names
// as a CSV, e.g. "x-litellm-tags".
TagsHeader string
// TagsInBody, when true, additionally writes the tag list into
// the request body's metadata.tags array (a JSON path the
// gateway parses for budget enforcement). LiteLLM only honours
// metadata.tags for tag-budget gating — its x-litellm-tags
// header path feeds spend tracking but bypasses
// _tag_max_budget_check entirely. Body inject is skipped when
// the request body is empty, truncated, non-JSON, or when an
// existing metadata field is a non-object value (defensive: we
// never clobber a client-supplied non-object). The header path
// remains a robust fallback for spend tracking in those cases.
TagsInBody bool
// EndUserIDInBody, when true, additionally writes the display
// identity into the request body's top-level "user" field (the
// OpenAI-standard end-user identifier). LiteLLM resolves the end
// user id from headers first then body, so for LiteLLM this is
// belt-and-suspenders. It matters when an OpenAI-compatible
// gateway downstream of LiteLLM (or OpenAI direct, bypassing
// LiteLLM) only reads the body, and as anti-spoof: client-
// supplied "user" values are overwritten with our trusted
// identity. Same skip rules as TagsInBody.
EndUserIDInBody bool
}
// JSONMetadataInjection is the Portkey-style wire convention: a single
// header carrying a JSON object. NetBird identity fields land under the
// configured reserved keys; missing keys (empty string) are skipped at
// emit time.
type JSONMetadataInjection struct {
// Customizable, when true, marks the JSON keys as operator-
// overridable. The dashboard surfaces UserKey and GroupsKey as
// editable inputs (the catalog values shown as placeholders) and
// the synthesizer pulls the actual JSON-key names from the
// provider record's IdentityHeader* fields. Same field reuse as
// HeaderPair's customizable path — the dimensions (user identity,
// groups) are the same, only the wire encoding differs (JSON key
// vs HTTP header name). An empty operator value disables emission
// for that dimension. Used today for Cloudflare AI Gateway, whose
// cf-aig-metadata header accepts arbitrary JSON keys; Portkey
// leaves this false because its keys are reserved by the Portkey
// schema.
Customizable bool
// Header is the wire header name carrying the JSON payload, e.g.
// "x-portkey-metadata".
Header string
// UserKey is the JSON key for the caller's display identity.
// Portkey reserves "_user" for this dimension.
UserKey string
// GroupsKey is the JSON key for the caller's NetBird groups,
// emitted as a CSV string value (Portkey requires string values).
GroupsKey string
// MaxValueLength caps each emitted JSON value, in bytes. Portkey
// enforces a 128-char limit per value; oversized values are
// truncated rather than failing the request. 0 disables the cap.
MaxValueLength int
}
// providers is the canonical list of supported Agent Network providers.
// Update this list together with the dashboard's PROVIDER_CATALOG.
var providers = []Provider{
{
ID: "openai_api",
Kind: KindProvider,
Name: "OpenAI API",
Description: "GPT, Responses API, and Embeddings",
DefaultHost: "api.openai.com",
AuthHeaderName: "Authorization",
AuthHeaderTemplate: "Bearer ${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#10A37F",
ParserID: "openai",
// Pricing + context windows cross-checked against LiteLLM's
// model_prices_and_context_window.json. Notable corrections from
// earlier values: o4-mini repriced from $4/$16 to $1.10/$4.40
// per MTok, gpt-4o from $5/$15 to $2.50/$10, and the GPT-5
// family context windows split between 1.05M for full-size
// models and 272K for mini/nano/codex variants.
Models: []Model{
{ID: "gpt-5.5", Label: "GPT-5.5", InputPer1k: 0.005, OutputPer1k: 0.030, ContextWindow: 1050000},
{ID: "gpt-5.5-pro", Label: "GPT-5.5 Pro", InputPer1k: 0.030, OutputPer1k: 0.180, ContextWindow: 1050000},
{ID: "gpt-5.4", Label: "GPT-5.4", InputPer1k: 0.0025, OutputPer1k: 0.015, ContextWindow: 1050000},
{ID: "gpt-5.4-pro", Label: "GPT-5.4 Pro", InputPer1k: 0.030, OutputPer1k: 0.180, ContextWindow: 1050000},
{ID: "gpt-5.4-mini", Label: "GPT-5.4 Mini", InputPer1k: 0.00075, OutputPer1k: 0.0045, ContextWindow: 272000},
{ID: "gpt-5.4-nano", Label: "GPT-5.4 Nano", InputPer1k: 0.0002, OutputPer1k: 0.00125, ContextWindow: 272000},
{ID: "gpt-5.3-codex", Label: "GPT-5.3 Codex", InputPer1k: 0.00175, OutputPer1k: 0.014, ContextWindow: 272000},
{ID: "gpt-5.3-chat-latest", Label: "GPT-5.3 Chat", InputPer1k: 0.00175, OutputPer1k: 0.014, ContextWindow: 128000},
{ID: "o4-mini", Label: "o4-mini", InputPer1k: 0.0011, OutputPer1k: 0.0044, ContextWindow: 200000},
{ID: "gpt-4.1", Label: "GPT-4.1", InputPer1k: 0.002, OutputPer1k: 0.008, ContextWindow: 1047576},
{ID: "gpt-4.1-mini", Label: "GPT-4.1 mini", InputPer1k: 0.0004, OutputPer1k: 0.0016, ContextWindow: 1047576},
{ID: "gpt-4.1-nano", Label: "GPT-4.1 nano", InputPer1k: 0.0001, OutputPer1k: 0.0004, ContextWindow: 1047576},
{ID: "gpt-4o", Label: "GPT-4o", InputPer1k: 0.0025, OutputPer1k: 0.010, ContextWindow: 128000},
{ID: "gpt-4o-mini", Label: "GPT-4o mini", InputPer1k: 0.00015, OutputPer1k: 0.0006, ContextWindow: 128000},
{ID: "gpt-4-turbo", Label: "GPT-4 Turbo", InputPer1k: 0.01, OutputPer1k: 0.03, ContextWindow: 128000},
{ID: "gpt-3.5-turbo", Label: "GPT-3.5 Turbo", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 16385},
{ID: "text-embedding-3-large", Label: "text-embedding-3-large", InputPer1k: 0.00013, OutputPer1k: 0, ContextWindow: 8191},
{ID: "text-embedding-3-small", Label: "text-embedding-3-small", InputPer1k: 0.00002, OutputPer1k: 0, ContextWindow: 8191},
},
},
{
ID: "anthropic_api",
Kind: KindProvider,
Name: "Anthropic API",
Description: "Claude Messages API",
DefaultHost: "api.anthropic.com",
AuthHeaderName: "x-api-key",
AuthHeaderTemplate: "${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#D97757",
ParserID: "anthropic",
// Per Anthropic's current model lineup. Pricing in USD per 1k
// tokens. Context windows: 4.6+ family is 1M; Haiku 4.5 stays at
// 200K. claude-3-7-sonnet and claude-3-5-haiku retired
// 2026-02-19 — dropped from the catalog. claude-opus-4-1
// deprecated, retires 2026-08-05 — kept until the cutover.
// claude-mythos-5 omitted: Project Glasswing access only, not a
// general-availability target. claude-fable-5 requires the
// account to be on >= 30-day data retention or all requests
// 400.
Models: []Model{
{ID: "claude-fable-5", Label: "Claude Fable 5", InputPer1k: 0.010, OutputPer1k: 0.050, ContextWindow: 1000000},
{ID: "claude-opus-4-8", Label: "Claude Opus 4.8", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
{ID: "claude-opus-4-7", Label: "Claude Opus 4.7", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
{ID: "claude-opus-4-6", Label: "Claude Opus 4.6", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
{ID: "claude-opus-4-1", Label: "Claude Opus 4.1 (deprecated, retires 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000},
{ID: "claude-sonnet-4-6", Label: "Claude Sonnet 4.6", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000},
{ID: "claude-sonnet-4-5", Label: "Claude Sonnet 4.5", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000},
{ID: "claude-haiku-4-5", Label: "Claude Haiku 4.5", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000},
},
},
{
ID: "azure_openai_api",
Kind: KindProvider,
Name: "Azure OpenAI API",
Description: "Azure-hosted OpenAI deployments",
DefaultHost: "<resource>.openai.azure.com",
AuthHeaderName: "api-key",
AuthHeaderTemplate: "${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#0078D4",
ParserID: "openai",
// Mirrors openai_api pricing — Azure resells OpenAI models at the
// same per-token rates, just under different deployment names.
Models: []Model{
{ID: "gpt-5.5", Label: "GPT-5.5 (Azure)", InputPer1k: 0.005, OutputPer1k: 0.030, ContextWindow: 1050000},
{ID: "gpt-5.4", Label: "GPT-5.4 (Azure)", InputPer1k: 0.0025, OutputPer1k: 0.015, ContextWindow: 1050000},
{ID: "gpt-5.4-mini", Label: "GPT-5.4 Mini (Azure)", InputPer1k: 0.00075, OutputPer1k: 0.0045, ContextWindow: 272000},
{ID: "gpt-5.4-nano", Label: "GPT-5.4 Nano (Azure)", InputPer1k: 0.0002, OutputPer1k: 0.00125, ContextWindow: 272000},
{ID: "o4-mini", Label: "o4-mini (Azure)", InputPer1k: 0.0011, OutputPer1k: 0.0044, ContextWindow: 200000},
{ID: "gpt-4.1", Label: "GPT-4.1 (Azure)", InputPer1k: 0.002, OutputPer1k: 0.008, ContextWindow: 1047576},
{ID: "gpt-4.1-mini", Label: "GPT-4.1 mini (Azure)", InputPer1k: 0.0004, OutputPer1k: 0.0016, ContextWindow: 1047576},
{ID: "gpt-4o", Label: "GPT-4o (Azure)", InputPer1k: 0.0025, OutputPer1k: 0.010, ContextWindow: 128000},
{ID: "gpt-4o-mini", Label: "GPT-4o mini (Azure)", InputPer1k: 0.00015, OutputPer1k: 0.0006, ContextWindow: 128000},
{ID: "gpt-35-turbo", Label: "GPT-3.5 Turbo (Azure)", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 16385},
},
},
{
ID: "bedrock_api",
Kind: KindProvider,
Name: "AWS Bedrock API",
Description: "Anthropic, Meta, Cohere via Bedrock",
DefaultHost: "bedrock-runtime.<region>.amazonaws.com",
AuthHeaderName: "Authorization",
AuthHeaderTemplate: "Bearer ${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#FF9900",
// Anthropic models on Bedrock take the anthropic.* prefix and
// follow the same lineup / pricing as the first-party Anthropic
// catalog entry above. claude-3-7-sonnet and claude-3-5-haiku
// were retired upstream on 2026-02-19 — dropped from the
// Bedrock list too. Amazon Nova entries cross-checked against
// LiteLLM (added Nova Micro + the new Nova 2 Lite preview).
// Llama 3.3 70B entry kept unchanged — LiteLLM tracks only
// per-region Llama 3 entries; standalone 3.3 not yet listed.
Models: []Model{
{ID: "anthropic.claude-opus-4-8", Label: "Claude Opus 4.8 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
{ID: "anthropic.claude-opus-4-7", Label: "Claude Opus 4.7 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
{ID: "anthropic.claude-opus-4-6", Label: "Claude Opus 4.6 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
{ID: "anthropic.claude-opus-4-1", Label: "Claude Opus 4.1 (Bedrock, deprecated 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000},
{ID: "anthropic.claude-sonnet-4-6", Label: "Claude Sonnet 4.6 (Bedrock)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000},
{ID: "anthropic.claude-sonnet-4-5", Label: "Claude Sonnet 4.5 (Bedrock)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000},
{ID: "anthropic.claude-haiku-4-5", Label: "Claude Haiku 4.5 (Bedrock)", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000},
{ID: "meta.llama3-3-70b-instruct", Label: "Llama 3.3 70B (Bedrock)", InputPer1k: 0.00072, OutputPer1k: 0.00072, ContextWindow: 128000},
{ID: "amazon.nova-2-lite", Label: "Amazon Nova 2 Lite (Bedrock, preview)", InputPer1k: 0.0003, OutputPer1k: 0.0025, ContextWindow: 1000000},
{ID: "amazon.nova-pro", Label: "Amazon Nova Pro (Bedrock)", InputPer1k: 0.0008, OutputPer1k: 0.0032, ContextWindow: 300000},
{ID: "amazon.nova-lite", Label: "Amazon Nova Lite (Bedrock)", InputPer1k: 0.00006, OutputPer1k: 0.00024, ContextWindow: 300000},
{ID: "amazon.nova-micro", Label: "Amazon Nova Micro (Bedrock)", InputPer1k: 0.000035, OutputPer1k: 0.00014, ContextWindow: 128000},
},
},
{
ID: "vertex_ai_api",
Kind: KindProvider,
Name: "Google Vertex AI API",
Description: "Anthropic Claude models hosted on Vertex AI",
DefaultHost: "<region>-aiplatform.googleapis.com",
AuthHeaderName: "Authorization",
AuthHeaderTemplate: "Bearer ${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#4285F4",
// Vertex carries the model in the URL path and authenticates with a
// service-account-minted OAuth token (api_key = "keyfile::<base64 SA>").
// Only Anthropic-on-Vertex is metered today: the request parser maps the
// anthropic publisher to the Anthropic parser, so the lineup + prices
// mirror the first-party Anthropic catalog (LiteLLM vertex_ai/claude-*
// confirms the same per-token rates; cross-region profiles in eu/apac
// carry a ~10% premium that base pricing does not model). Gemini (the
// google publisher) is intentionally omitted until a Gemini parser
// exists — the router denies unmeterable publishers rather than forward
// them uncounted.
Models: []Model{
{ID: "claude-fable-5", Label: "Claude Fable 5 (Vertex)", InputPer1k: 0.010, OutputPer1k: 0.050, ContextWindow: 1000000},
{ID: "claude-opus-4-8", Label: "Claude Opus 4.8 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
{ID: "claude-opus-4-7", Label: "Claude Opus 4.7 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
{ID: "claude-opus-4-6", Label: "Claude Opus 4.6 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
{ID: "claude-opus-4-1", Label: "Claude Opus 4.1 (Vertex, deprecated 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000},
{ID: "claude-sonnet-4-6", Label: "Claude Sonnet 4.6 (Vertex)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000},
{ID: "claude-sonnet-4-5", Label: "Claude Sonnet 4.5 (Vertex)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000},
{ID: "claude-haiku-4-5", Label: "Claude Haiku 4.5 (Vertex)", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000},
},
},
{
ID: "mistral_api",
Kind: KindProvider,
Name: "Mistral API",
Description: "Mistral cloud API",
DefaultHost: "api.mistral.ai",
AuthHeaderName: "Authorization",
AuthHeaderTemplate: "Bearer ${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#FF7000",
ParserID: "openai",
// Pricing + context windows cross-checked against LiteLLM. Key
// gotchas the marketing page hides:
// - `mistral-medium-latest` aliases to Medium 3.1 ($0.40/$2),
// NOT Medium 3.5 ($1.50/$7.50). Catalog exposes both.
// - `mistral-large-latest` aliases to Large 3 — 262K context,
// cheaper than Medium 3.5.
// - Magistral models are tuned for reasoning but cap context
// at only 40K (vs 128K-262K elsewhere).
// - `codestral-latest` still routes to the old 2405 build
// ($1/$3) per LiteLLM; the newer codestral-2508 is both
// cheaper and longer-context. Both exposed.
// - Pixtral was folded into the main Large/Medium series; no
// standalone vision entry.
Models: []Model{
{ID: "mistral-large-latest", Label: "Mistral Large 3", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 262144},
{ID: "mistral-medium-latest", Label: "Mistral Medium 3.1", InputPer1k: 0.0004, OutputPer1k: 0.002, ContextWindow: 131072},
{ID: "mistral-medium-3-5", Label: "Mistral Medium 3.5", InputPer1k: 0.0015, OutputPer1k: 0.0075, ContextWindow: 262144},
{ID: "mistral-small-latest", Label: "Mistral Small 3.2", InputPer1k: 0.00006, OutputPer1k: 0.00018, ContextWindow: 131072},
{ID: "magistral-medium-latest", Label: "Magistral Medium (reasoning)", InputPer1k: 0.002, OutputPer1k: 0.005, ContextWindow: 40000},
{ID: "magistral-small-latest", Label: "Magistral Small (reasoning)", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 40000},
{ID: "devstral-medium-latest", Label: "Devstral Medium 2 (coding)", InputPer1k: 0.0004, OutputPer1k: 0.002, ContextWindow: 256000},
{ID: "devstral-small-latest", Label: "Devstral Small 2 (coding)", InputPer1k: 0.0001, OutputPer1k: 0.0003, ContextWindow: 256000},
{ID: "codestral-2508", Label: "Codestral 2508", InputPer1k: 0.0003, OutputPer1k: 0.0009, ContextWindow: 256000},
{ID: "codestral-latest", Label: "Codestral (legacy 2405)", InputPer1k: 0.001, OutputPer1k: 0.003, ContextWindow: 32000},
{ID: "ministral-3-14b-2512", Label: "Ministral 3 14B", InputPer1k: 0.0002, OutputPer1k: 0.0002, ContextWindow: 262144},
{ID: "ministral-8b-latest", Label: "Ministral 8B", InputPer1k: 0.00015, OutputPer1k: 0.00015, ContextWindow: 262144},
{ID: "ministral-3-3b-2512", Label: "Ministral 3 3B", InputPer1k: 0.0001, OutputPer1k: 0.0001, ContextWindow: 131072},
{ID: "mistral-embed", Label: "Mistral Embed", InputPer1k: 0.0001, OutputPer1k: 0, ContextWindow: 8192},
},
},
{
ID: "litellm_proxy",
Kind: KindGateway,
Name: "LiteLLM Proxy",
Description: "Bring your own LiteLLM proxy with NetBird identity stamped on every request",
DefaultHost: "",
AuthHeaderName: "Authorization",
AuthHeaderTemplate: "Bearer ${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#0EA5E9",
ParserID: "openai",
// IdentityInjection requires a LiteLLM virtual key minted with
// metadata.allow_client_tags=true; the master key silently drops
// caller tags. Tags go out via both the x-litellm-tags header and
// body metadata.tags: LiteLLM enforces budgets from the body only,
// so the header is the spend-tracking fallback when body injection
// can't run. See the Agent Network provider docs for key setup.
IdentityInjection: &IdentityInjection{
HeaderPair: &HeaderPairInjection{
EndUserIDHeader: "x-litellm-end-user-id",
TagsHeader: "x-litellm-tags",
TagsInBody: true,
EndUserIDInBody: true,
},
},
Models: []Model{},
},
{
ID: "portkey",
Kind: KindGateway,
Name: "Portkey AI Gateway",
Description: "Portkey AI Gateway with NetBird identity stamped via x-portkey-metadata",
DefaultHost: "api.portkey.ai",
// Portkey hosted requires x-portkey-api-key (account key)
// plus a routing decision per request. The simplest routing
// path is a saved Portkey config id stamped via
// x-portkey-config — operators paste the pc-... id once and
// Portkey resolves the upstream provider + virtual key from
// it. ExtraHeaders below surfaces the input. Alternative:
// callers author "@org/model" in the body; both flows
// coexist (per-request authoring still works without a
// configured value).
AuthHeaderName: "x-portkey-api-key",
AuthHeaderTemplate: "${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#FF5C00",
ParserID: "openai",
IdentityInjection: &IdentityInjection{
JSONMetadata: &JSONMetadataInjection{
Header: "x-portkey-metadata",
UserKey: "_user",
GroupsKey: "groups",
MaxValueLength: 128,
},
},
ExtraHeaders: []ExtraHeader{
{Name: "x-portkey-config"},
},
Models: []Model{},
},
{
ID: "bifrost",
Kind: KindGateway,
Name: "Bifrost",
Description: "Maxim AI's Bifrost gateway. Point upstream URL at /openai/v1 or /anthropic/v1 on your Bifrost host depending on which body shape your apps use.",
DefaultHost: "",
AuthHeaderName: "Authorization",
AuthHeaderTemplate: "Bearer ${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#7C3AED",
// ParserID empty: the proxy's request parser sniffs the URL
// path. Bifrost's /openai/v1/... contains "/v1/chat/completions"
// (matches OpenAIParser.DetectFromURL); /anthropic/v1/messages
// contains "/v1/messages" (matches AnthropicParser). Operators
// who paste a different prefix get no usage parsing and the
// cost meter skips with skipMissingProvider — degraded but
// non-fatal.
ParserID: "",
// Identity-injection headers are operator-customisable. The
// HeaderPair values below are PLACEHOLDERS surfaced by the
// dashboard; the actual values stamped on the wire come from
// the provider record's IdentityHeaderUserID /
// IdentityHeaderGroups fields. An empty operator value
// disables stamping for that dimension (the inject middleware
// already no-ops on empty header names). Defaulting to the
// x-bf-dim- family so the values land in Bifrost's
// Prometheus/OTEL pipelines when the operator declares the
// label names in their client.prometheus_labels config — see
// docs.getbifrost.ai/features/telemetry. Operators who use
// the always-on x-bf-lh- log-metadata family (no Bifrost-side
// declaration required) just edit the inputs.
//
// Bifrost virtual keys (sk-bf-*) ride Authorization: Bearer.
// Operators provision the VK on their Bifrost (UI /
// config.json / POST /api/governance/virtual-keys) and paste
// the returned sk-bf-... as ${API_KEY}. Pin v1.4+ to avoid
// the v1.3.0 x-bf-vk regression (maximhq/bifrost#632).
IdentityInjection: &IdentityInjection{
HeaderPair: &HeaderPairInjection{
EndUserIDHeader: "x-bf-dim-netbird_user_id",
TagsHeader: "x-bf-dim-netbird_groups",
Customizable: true,
},
},
Models: []Model{},
},
{
ID: "cloudflare_ai_gateway",
Kind: KindGateway,
Name: "Cloudflare AI Gateway",
Description: "Cloudflare AI Gateway. Operator pastes the gateway URL (with the upstream provider slug like /openai or /anthropic so the URL sniffer dispatches to the right parser) and a per-gateway authentication token. Recommended setup is BYOK / Stored Keys: Cloudflare manages the upstream provider credential and the gateway token is the only secret NetBird needs.",
DefaultHost: "",
AuthHeaderName: "cf-aig-authorization",
AuthHeaderTemplate: "Bearer ${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#F38020",
// ParserID empty: like Bifrost, the proxy's parser-detect
// sniffs the URL path. /openai/... contains the OpenAI hint
// substrings; /anthropic/v1/messages contains /v1/messages
// (matches AnthropicParser). The /compat universal endpoint
// also speaks OpenAI shape so OpenAIParser handles it.
// Operators who paste a different prefix degrade to no-cost
// (skipMissingProvider) but the request still flows.
ParserID: "",
// cf-aig-metadata is a single header carrying a JSON object;
// up to five string/number/boolean values per request. NetBird
// occupies two slots (user id + groups CSV) and leaves three
// for operator-added context. JSON keys are operator-
// customisable so Cloudflare-side log filters can use the
// operator's existing label conventions instead of NetBird's
// defaults — hence Customizable=true. The dashboard surfaces
// the catalog values as placeholders; only the values stored
// on the provider record's IdentityHeader* fields land on the
// wire (empty operator value = key is omitted from the JSON,
// since applyJSONMetadata already skips empty keys).
IdentityInjection: &IdentityInjection{
JSONMetadata: &JSONMetadataInjection{
Header: "cf-aig-metadata",
UserKey: "netbird_user_id",
GroupsKey: "netbird_groups",
Customizable: true,
// Cloudflare's docs don't specify a per-value cap;
// leaving 0 disables the truncate path. Header-level
// constraint is "5 entries max" rather than length.
MaxValueLength: 0,
},
},
Models: []Model{},
},
{
ID: "vercel_ai_gateway",
Kind: KindGateway,
Name: "Vercel AI Gateway",
Description: "Vercel's unified API for hundreds of models. Single endpoint, OpenAI-compatible body, model dispatch via prefix (openai/..., anthropic/..., google/..., xai/...). Per-user / per-tag attribution lands in Vercel's Custom Reporting API and observability dashboard.",
DefaultHost: "",
AuthHeaderName: "Authorization",
AuthHeaderTemplate: "Bearer ${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#000000",
// Vercel always speaks OpenAI shape on /v1/chat/completions —
// the model prefix in the body picks the upstream provider.
// No URL sniffing needed; pin the parser directly.
ParserID: "openai",
// HeaderPair shape with fixed wire names dictated by Vercel's
// Custom Reporting API contract. Customizable=false because
// renaming the headers makes Vercel silently stop attributing
// — the gateway's reporting endpoint only matches its own
// header names. Same fixed-protocol position as LiteLLM.
//
// Caveats operators should know:
// - up to 10 tags total per request (deduped); 11+ → HTTP 400
// - each tag must be 1-64 chars
// - user up to 256 chars (NetBird user emails fit)
// - $0.075 per 1k unique user/tag values written
// We don't enforce the caps in the inject middleware today;
// operators in groups beyond the 10-tag limit will see Vercel
// 400s and need to re-scope their group memberships.
IdentityInjection: &IdentityInjection{
HeaderPair: &HeaderPairInjection{
EndUserIDHeader: "ai-reporting-user",
TagsHeader: "ai-reporting-tags",
},
},
Models: []Model{},
},
{
ID: "openrouter",
Kind: KindGateway,
Name: "OpenRouter",
Description: "OpenRouter's unified API for hundreds of models. Single endpoint at openrouter.ai/api/v1, OpenAI-compatible body, model dispatch via prefix (anthropic/claude-..., openai/gpt-..., google/gemini-..., etc.). Per-user attribution lands in OpenRouter's analytics via the OpenAI-standard `user` body field; OpenRouter has no groups / tags dimension at request time.",
DefaultHost: "openrouter.ai/api/v1",
AuthHeaderName: "Authorization",
AuthHeaderTemplate: "Bearer ${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#6F4FF2",
// OpenRouter is single-endpoint OpenAI-shape on /api/v1/chat/completions —
// model prefix in the body picks the upstream provider.
// Pinning the parser saves URL sniffing.
ParserID: "openai",
// HeaderPair shape with EndUserIDInBody as the only active
// dimension. OpenRouter's per-user attribution is the
// OpenAI-standard `user` body field, not a header — and
// OpenRouter offers no per-request groups / tags dimension at
// all. Customizable=false because the field name is locked by
// OpenAI's spec; renaming would just defeat the inject.
IdentityInjection: &IdentityInjection{
HeaderPair: &HeaderPairInjection{
EndUserIDInBody: true,
},
},
// HTTP-Referer + X-OpenRouter-Title surface in OpenRouter's
// app rankings and per-app analytics. Operators paste their
// own app URL + display name on the provider record so their
// requests show under their brand instead of "no app". Both
// are static per-deployment, not per-request, hence the
// ExtraHeaders mechanism (operator-typed value, stamped on
// every request to this provider). Skip X-OpenRouter-Categories
// for now — the marketplace-categories dimension is
// niche-enough that we'd add it on demand.
ExtraHeaders: []ExtraHeader{
{Name: "HTTP-Referer"},
{Name: "X-OpenRouter-Title"},
},
Models: []Model{},
},
{
ID: "custom",
Kind: KindCustom,
Name: "Custom / Self-hosted",
Description: "OpenAI-compatible endpoint (vLLM, Ollama, …)",
DefaultHost: "",
AuthHeaderName: "Authorization",
AuthHeaderTemplate: "Bearer ${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#9CA3AF",
Models: []Model{},
},
}
// All returns a copy of the full catalog.
func All() []Provider {
out := make([]Provider, len(providers))
copy(out, providers)
return out
}
// Lookup returns the catalog entry with the given id, if any.
func Lookup(id string) (Provider, bool) {
for _, p := range providers {
if p.ID == id {
return p, true
}
}
return Provider{}, false
}
// IsKnown reports whether the given id refers to a catalog entry.
func IsKnown(id string) bool {
_, ok := Lookup(id)
return ok
}
// IsVertexPathStyle reports whether a provider uses the Google Vertex AI
// request shape — the model is carried in the URL path
// (/v1/projects/{p}/locations/{r}/publishers/{pub}/models/{model}:{action})
// rather than the body, so the proxy routes it by path instead of by model.
func IsVertexPathStyle(providerID string) bool {
return providerID == "vertex_ai_api"
}
// IsBedrockPathStyle reports whether a provider uses the AWS Bedrock request
// shape — the model is carried in the URL path (/model/{modelId}/{action},
// action being invoke, invoke-with-response-stream, converse, or
// converse-stream) rather than the body, so the proxy routes it by path.
func IsBedrockPathStyle(providerID string) bool {
return providerID == "bedrock_api"
}
// ToAPIResponse renders a catalog provider as the API representation.
func (p Provider) ToAPIResponse() api.AgentNetworkCatalogProvider {
models := make([]api.AgentNetworkCatalogModel, 0, len(p.Models))
for _, m := range p.Models {
models = append(models, api.AgentNetworkCatalogModel{
Id: m.ID,
Label: m.Label,
InputPer1k: m.InputPer1k,
OutputPer1k: m.OutputPer1k,
ContextWindow: m.ContextWindow,
})
}
kind := api.AgentNetworkCatalogProviderKindProvider
switch p.Kind {
case KindGateway:
kind = api.AgentNetworkCatalogProviderKindGateway
case KindCustom:
kind = api.AgentNetworkCatalogProviderKindCustom
}
resp := api.AgentNetworkCatalogProvider{
Id: p.ID,
Name: p.Name,
Description: p.Description,
DefaultHost: p.DefaultHost,
Kind: kind,
AuthHeaderTemplate: p.AuthHeaderTemplate,
DefaultContentType: p.DefaultContentType,
BrandColor: p.BrandColor,
Models: models,
}
if len(p.ExtraHeaders) > 0 {
extras := make([]api.AgentNetworkCatalogExtraHeader, 0, len(p.ExtraHeaders))
for _, h := range p.ExtraHeaders {
extras = append(extras, api.AgentNetworkCatalogExtraHeader{
Name: h.Name,
})
}
resp.ExtraHeaders = &extras
}
// Surface IdentityInjection so the dashboard can decide whether
// to render editable inputs vs. a read-only mappings strip per
// shape's customizable flag. HeaderPair (Bifrost) and
// JSONMetadata (Cloudflare, Portkey) are mutually exclusive on a
// given catalog entry; emit whichever shape is set.
if p.IdentityInjection != nil {
injection := &api.AgentNetworkCatalogIdentityInjection{}
if hp := p.IdentityInjection.HeaderPair; hp != nil {
injection.HeaderPair = &api.AgentNetworkCatalogHeaderPairInjection{
Customizable: hp.Customizable,
EndUserIdHeader: hp.EndUserIDHeader,
TagsHeader: hp.TagsHeader,
}
}
if jm := p.IdentityInjection.JSONMetadata; jm != nil {
injection.JsonMetadata = &api.AgentNetworkCatalogJSONMetadataInjection{
Customizable: jm.Customizable,
Header: jm.Header,
UserKey: jm.UserKey,
GroupsKey: jm.GroupsKey,
}
}
if injection.HeaderPair != nil || injection.JsonMetadata != nil {
resp.IdentityInjection = injection
}
}
return resp
}

View File

@@ -0,0 +1,91 @@
package handlers
import (
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
// addAccessLogEndpoints registers the read-only, server-side-filtered
// agent-network access-log listing and the aggregated usage overview.
func (h *handler) addAccessLogEndpoints(router *mux.Router) {
router.HandleFunc("/agent-network/access-logs", h.listAccessLogs).Methods("GET", "OPTIONS")
router.HandleFunc("/agent-network/usage/overview", h.getUsageOverview).Methods("GET", "OPTIONS")
}
func (h *handler) getUsageOverview(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
// Reuse the access-log filter for the shared date/user/group/provider/model
// params; pagination/sort/search are irrelevant for an aggregate.
var filter types.AgentNetworkAccessLogFilter
if err := filter.ParseFromRequest(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
// Bound the aggregation window so an unbounded or over-wide query can't load
// an account's entire usage history into memory.
filter.ApplyUsageOverviewBounds(time.Now())
granularity := types.ParseUsageGranularity(r.URL.Query().Get("granularity"))
buckets, err := h.manager.GetUsageOverview(r.Context(), userAuth.AccountId, userAuth.UserId, filter, granularity)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
out := make([]api.AgentNetworkUsageBucket, 0, len(buckets))
for _, b := range buckets {
out = append(out, b.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, out)
}
func (h *handler) listAccessLogs(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var filter types.AgentNetworkAccessLogFilter
if err := filter.ParseFromRequest(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
rows, total, err := h.manager.ListAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId, filter)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
data := make([]api.AgentNetworkAccessLog, 0, len(rows))
for _, row := range rows {
data = append(data, row.ToAPIResponse())
}
pageSize := filter.GetLimit()
totalPages := 0
if pageSize > 0 {
totalPages = int((total + int64(pageSize) - 1) / int64(pageSize))
}
util.WriteJSONObject(r.Context(), w, api.AgentNetworkAccessLogsResponse{
Data: data,
Page: filter.Page,
PageSize: pageSize,
TotalRecords: int(total),
TotalPages: totalPages,
})
}

View File

@@ -0,0 +1,172 @@
package handlers
import (
"encoding/json"
"net/http"
"strings"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// addBudgetRuleEndpoints registers the account-level budget rule routes.
func (h *handler) addBudgetRuleEndpoints(router *mux.Router) {
router.HandleFunc("/agent-network/budget-rules", h.getAllBudgetRules).Methods("GET", "OPTIONS")
router.HandleFunc("/agent-network/budget-rules", h.createBudgetRule).Methods("POST", "OPTIONS")
router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.getBudgetRule).Methods("GET", "OPTIONS")
router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.updateBudgetRule).Methods("PUT", "OPTIONS")
router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.deleteBudgetRule).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllBudgetRules(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
rules, err := h.manager.GetAllBudgetRules(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
out := make([]*api.AgentNetworkBudgetRule, 0, len(rules))
for _, rule := range rules {
out = append(out, rule.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, out)
}
func (h *handler) getBudgetRule(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ruleID := mux.Vars(r)["ruleId"]
if ruleID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w)
return
}
rule, err := h.manager.GetBudgetRule(r.Context(), userAuth.AccountId, userAuth.UserId, ruleID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, rule.ToAPIResponse())
}
func (h *handler) createBudgetRule(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.AgentNetworkBudgetRuleRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if err := validateBudgetRule(&req); err != nil {
util.WriteError(r.Context(), err, w)
return
}
rule := types.NewAccountBudgetRule(userAuth.AccountId)
rule.FromAPIRequest(&req)
created, err := h.manager.CreateBudgetRule(r.Context(), userAuth.UserId, rule)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
}
func (h *handler) updateBudgetRule(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ruleID := mux.Vars(r)["ruleId"]
if ruleID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w)
return
}
var req api.AgentNetworkBudgetRuleRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if err := validateBudgetRule(&req); err != nil {
util.WriteError(r.Context(), err, w)
return
}
rule := &types.AccountBudgetRule{ID: ruleID, AccountID: userAuth.AccountId}
rule.FromAPIRequest(&req)
updated, err := h.manager.UpdateBudgetRule(r.Context(), userAuth.UserId, rule)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
}
func (h *handler) deleteBudgetRule(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ruleID := mux.Vars(r)["ruleId"]
if ruleID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w)
return
}
if err := h.manager.DeleteBudgetRule(r.Context(), userAuth.AccountId, userAuth.UserId, ruleID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
// validateBudgetRule rejects malformed budget rules. It reuses the policy limit
// validation since the cap shape is identical, and rejects empty target entries.
func validateBudgetRule(req *api.AgentNetworkBudgetRuleRequest) error {
if strings.TrimSpace(req.Name) == "" {
return status.Errorf(status.InvalidArgument, "name is required")
}
if req.TargetGroups != nil {
for _, id := range *req.TargetGroups {
if strings.TrimSpace(id) == "" {
return status.Errorf(status.InvalidArgument, "target_groups must not contain empty entries")
}
}
}
if req.TargetUsers != nil {
for _, id := range *req.TargetUsers {
if strings.TrimSpace(id) == "" {
return status.Errorf(status.InvalidArgument, "target_users must not contain empty entries")
}
}
}
return validatePolicyLimits(req.Limits)
}

View File

@@ -0,0 +1,131 @@
package handlers
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// TestBudgetRuleHandler_RoundTrip seeds a budget rule via the store and asserts
// the GET wire shape carries targets and the reused PolicyLimits cap shape. The
// create/update/delete success paths go through accountManager.StoreEvent which
// this fixture doesn't wire — they are covered by the manager-level no-mock
// test (TestAgentNetwork_BudgetRuleCRUD_RealManager).
func TestBudgetRuleHandler_RoundTrip(t *testing.T) {
f := newAgentNetworkHandlerFixture(t)
rule := &agentNetworkTypes.AccountBudgetRule{
ID: "ainbud_test",
AccountID: testAccountID,
Name: "org-monthly",
Enabled: true,
TargetGroups: []string{"grp-eng"},
TargetUsers: []string{"user-alice"},
Limits: agentNetworkTypes.PolicyLimits{
TokenLimit: agentNetworkTypes.PolicyTokenLimit{Enabled: true, GroupCap: 100000, UserCap: 10000, WindowSeconds: 2_592_000},
BudgetLimit: agentNetworkTypes.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 500, WindowSeconds: 2_592_000},
},
}
require.NoError(t, f.store.SaveAgentNetworkBudgetRule(context.Background(), rule))
rec := f.do(t, http.MethodGet, "/agent-network/budget-rules/"+rule.ID, "")
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
var got api.AgentNetworkBudgetRule
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
assert.Equal(t, "org-monthly", got.Name, "name must round-trip")
assert.Equal(t, []string{"grp-eng"}, got.TargetGroups, "target groups must round-trip")
assert.Equal(t, []string{"user-alice"}, got.TargetUsers, "target users must round-trip")
assert.Equal(t, int64(100000), got.Limits.TokenLimit.GroupCap, "token group cap must round-trip")
assert.Equal(t, int64(2_592_000), got.Limits.BudgetLimit.WindowSeconds, "budget window must round-trip")
}
// TestBudgetRuleHandler_ListReturnsArray asserts the list endpoint returns a
// JSON array (never null) for an account with no rules.
func TestBudgetRuleHandler_ListReturnsArray(t *testing.T) {
f := newAgentNetworkHandlerFixture(t)
rec := f.do(t, http.MethodGet, "/agent-network/budget-rules", "")
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
assert.Equal(t, "[]", trimSpace(rec.Body.String()), "empty account must return an empty array, not null")
}
// TestBudgetRuleHandler_RejectsMissingName covers the validation path (which
// runs before the manager call, so it works without a wired accountManager).
func TestBudgetRuleHandler_RejectsMissingName(t *testing.T) {
f := newAgentNetworkHandlerFixture(t)
body := `{
"name": "",
"limits": {
"token_limit": {"enabled": false, "group_cap": 0, "user_cap": 0, "window_seconds": 0},
"budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0}
}
}`
rec := f.do(t, http.MethodPost, "/agent-network/budget-rules", body)
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code,
"missing name must be rejected as a validation error (not a route/auth 4xx): got %d body=%s", rec.Code, rec.Body.String())
assert.Contains(t, rec.Body.String(), "name",
"rejection body must name the offending field, proving the validation path: %s", rec.Body.String())
}
// TestBudgetRuleHandler_RejectsSubMinuteWindow proves budget rules reuse the
// policy-limit validation (enabled limit needs window >= 60s).
func TestBudgetRuleHandler_RejectsSubMinuteWindow(t *testing.T) {
f := newAgentNetworkHandlerFixture(t)
body := `{
"name": "bad-window",
"limits": {
"token_limit": {"enabled": true, "group_cap": 1000, "user_cap": 0, "window_seconds": 30},
"budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0}
}
}`
rec := f.do(t, http.MethodPost, "/agent-network/budget-rules", body)
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code,
"sub-minute window must be rejected as a validation error (not a route/auth 4xx): got %d body=%s", rec.Code, rec.Body.String())
assert.Contains(t, rec.Body.String(), "window_seconds",
"rejection body must name the offending window_seconds field, proving the validation path: %s", rec.Body.String())
}
// TestSettingsHandler_GetExposesCollectionToggles asserts the GET settings wire
// shape carries the account-level collection toggles after a store seed.
func TestSettingsHandler_GetExposesCollectionToggles(t *testing.T) {
f := newAgentNetworkHandlerFixture(t)
require.NoError(t, f.store.SaveAgentNetworkSettings(context.Background(), &agentNetworkTypes.Settings{
AccountID: testAccountID,
Cluster: "eu.proxy.netbird.io",
Subdomain: "violet",
EnableLogCollection: true,
EnablePromptCollection: true,
RedactPii: false,
}))
rec := f.do(t, http.MethodGet, "/agent-network/settings", "")
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
var got api.AgentNetworkSettings
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
assert.True(t, got.EnableLogCollection, "log collection toggle must surface on the wire")
assert.True(t, got.EnablePromptCollection, "prompt collection toggle must surface on the wire")
assert.False(t, got.RedactPii, "redact toggle must surface its false value")
assert.Equal(t, "violet.eu.proxy.netbird.io", got.Endpoint, "endpoint stays computed from immutable cluster+subdomain")
}
func trimSpace(s string) string {
for len(s) > 0 && (s[len(s)-1] == '\n' || s[len(s)-1] == ' ' || s[len(s)-1] == '\t' || s[len(s)-1] == '\r') {
s = s[:len(s)-1]
}
for len(s) > 0 && (s[0] == '\n' || s[0] == ' ' || s[0] == '\t' || s[0] == '\r') {
s = s[1:]
}
return s
}

View File

@@ -0,0 +1,53 @@
package handlers
import (
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
// addConsumptionEndpoints registers the read-only Agent Network
// consumption listing — backs the dashboard's basic counter view.
func (h *handler) addConsumptionEndpoints(router *mux.Router) {
router.HandleFunc("/agent-network/consumption", h.listConsumption).Methods("GET", "OPTIONS")
}
func (h *handler) listConsumption(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
rows, err := h.manager.ListConsumption(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
out := make([]api.AgentNetworkConsumption, 0, len(rows))
for _, row := range rows {
out = append(out, consumptionToAPI(row))
}
util.WriteJSONObject(r.Context(), w, out)
}
func consumptionToAPI(c *types.Consumption) api.AgentNetworkConsumption {
windowStart := c.WindowStartUTC
updatedAt := c.UpdatedAt
return api.AgentNetworkConsumption{
DimensionKind: api.AgentNetworkConsumptionDimensionKind(c.DimensionKind),
DimensionId: c.DimensionID,
WindowSeconds: c.WindowSeconds,
WindowStartUtc: windowStart,
TokensInput: c.TokensInput,
TokensOutput: c.TokensOutput,
CostUsd: c.CostUSD,
UpdatedAt: &updatedAt,
}
}

View File

@@ -0,0 +1,171 @@
package handlers
import (
"encoding/json"
"net/http"
"strings"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// addGuardrailEndpoints registers all Agent Network guardrail routes.
func (h *handler) addGuardrailEndpoints(router *mux.Router) {
router.HandleFunc("/agent-network/guardrails", h.getAllGuardrails).Methods("GET", "OPTIONS")
router.HandleFunc("/agent-network/guardrails", h.createGuardrail).Methods("POST", "OPTIONS")
router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.getGuardrail).Methods("GET", "OPTIONS")
router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.updateGuardrail).Methods("PUT", "OPTIONS")
router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.deleteGuardrail).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllGuardrails(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
guardrails, err := h.manager.GetAllGuardrails(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
out := make([]*api.AgentNetworkGuardrail, 0, len(guardrails))
for _, g := range guardrails {
out = append(out, g.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, out)
}
func (h *handler) getGuardrail(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
guardrailID := mux.Vars(r)["guardrailId"]
if guardrailID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w)
return
}
guardrail, err := h.manager.GetGuardrail(r.Context(), userAuth.AccountId, userAuth.UserId, guardrailID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, guardrail.ToAPIResponse())
}
func (h *handler) createGuardrail(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.AgentNetworkGuardrailRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if err := validateGuardrail(&req); err != nil {
util.WriteError(r.Context(), err, w)
return
}
guardrail := types.NewGuardrail(userAuth.AccountId)
guardrail.FromAPIRequest(&req)
created, err := h.manager.CreateGuardrail(r.Context(), userAuth.UserId, guardrail)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
}
func (h *handler) updateGuardrail(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
guardrailID := mux.Vars(r)["guardrailId"]
if guardrailID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w)
return
}
var req api.AgentNetworkGuardrailRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if err := validateGuardrail(&req); err != nil {
util.WriteError(r.Context(), err, w)
return
}
guardrail := &types.Guardrail{
ID: guardrailID,
AccountID: userAuth.AccountId,
}
guardrail.FromAPIRequest(&req)
updated, err := h.manager.UpdateGuardrail(r.Context(), userAuth.UserId, guardrail)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
}
func (h *handler) deleteGuardrail(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
guardrailID := mux.Vars(r)["guardrailId"]
if guardrailID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w)
return
}
if err := h.manager.DeleteGuardrail(r.Context(), userAuth.AccountId, userAuth.UserId, guardrailID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func validateGuardrail(req *api.AgentNetworkGuardrailRequest) error {
if strings.TrimSpace(req.Name) == "" {
return status.Errorf(status.InvalidArgument, "name is required")
}
c := req.Checks
if c.ModelAllowlist.Enabled {
for _, id := range c.ModelAllowlist.Models {
if strings.TrimSpace(id) == "" {
return status.Errorf(status.InvalidArgument, "model_allowlist.models must not contain empty entries")
}
}
}
return nil
}

View File

@@ -0,0 +1,256 @@
package handlers
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"runtime"
"strings"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/store"
nbtypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
)
const (
testAccountID = "acc-1"
testUserID = "user-bob"
)
// agentNetworkHandlerFixture builds a real agentnetwork.Manager with
// a sqlite store and an always-allow permissions mock, then exposes
// the HTTP handlers via a gorilla router. Tests issue requests
// through httptest and assert on the wire shape — the same path the
// dashboard exercises.
type agentNetworkHandlerFixture struct {
store store.Store
manager agentnetwork.Manager
router *mux.Router
}
func newAgentNetworkHandlerFixture(t *testing.T) *agentNetworkHandlerFixture {
t.Helper()
if runtime.GOOS == "windows" {
t.Skip("sqlite store not properly supported on Windows yet")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(nbtypes.SqliteStoreEngine))
st, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanUp)
ctrl := gomock.NewController(t)
perms := permissions.NewMockManager(ctrl)
// Always-allow: the handler tests are about wire shape, not
// authz. Authz is covered by the manager's own tests.
perms.EXPECT().
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(true, context.Background(), nil).
AnyTimes()
manager := agentnetwork.NewManager(st, perms, nil, nil)
h := &handler{manager: manager}
router := mux.NewRouter()
h.addPolicyEndpoints(router)
h.addConsumptionEndpoints(router)
h.addBudgetRuleEndpoints(router)
h.addSettingsEndpoints(router)
return &agentNetworkHandlerFixture{
store: st,
manager: manager,
router: router,
}
}
func (f *agentNetworkHandlerFixture) do(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
t.Helper()
var reader io.Reader
if body != "" {
reader = strings.NewReader(body)
}
req := httptest.NewRequest(method, path, reader)
if body != "" {
req.Header.Set("Content-Type", "application/json")
}
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
rec := httptest.NewRecorder()
f.router.ServeHTTP(rec, req)
return rec
}
// seedProvider persists a minimal provider record so policy create
// passes the manager's destination_provider_ids existence check.
func (f *agentNetworkHandlerFixture) seedProvider(t *testing.T, id string) {
t.Helper()
require.NoError(t, f.store.SaveAgentNetworkProvider(context.Background(), &agentNetworkTypes.Provider{
ID: id,
AccountID: testAccountID,
ProviderID: "openai_api",
Name: "test-" + id,
UpstreamURL: "https://api.openai.com",
APIKey: "sk-test",
Enabled: true,
SessionPrivateKey: "test-priv-key",
SessionPublicKey: "test-pub-key",
}))
}
// TestPolicyHandler_WindowSecondsRoundTrip ports bash 10 to Go:
// assert that a policy with window_seconds on both Token + Budget
// halves round-trips through GET unchanged AND that legacy
// window_hours / window_days are absent from the JSON response. We
// seed the policy directly via the store rather than POST-ing
// because the create path goes through the manager's
// accountManager.StoreEvent which we don't wire in this fixture; the
// on-wire shape is what matters here, and the POST validation path
// is covered separately by the RejectsSubMinuteWindow test.
func TestPolicyHandler_WindowSecondsRoundTrip(t *testing.T) {
f := newAgentNetworkHandlerFixture(t)
policy := &agentNetworkTypes.Policy{
ID: "ainpol_test",
AccountID: testAccountID,
Name: "round-trip",
Enabled: true,
SourceGroups: []string{"grp-engineers"},
DestinationProviderIDs: []string{"prov-1"},
Limits: agentNetworkTypes.PolicyLimits{
TokenLimit: agentNetworkTypes.PolicyTokenLimit{Enabled: true, GroupCap: 10000, UserCap: 5000, WindowSeconds: 86_400},
BudgetLimit: agentNetworkTypes.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 10.0, UserCapUsd: 2.5, WindowSeconds: 2_592_000},
},
}
require.NoError(t, f.store.SaveAgentNetworkPolicy(context.Background(), policy))
rec := f.do(t, http.MethodGet, "/agent-network/policies/"+policy.ID, "")
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
var got api.AgentNetworkPolicy
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
assert.Equal(t, int64(86_400), got.Limits.TokenLimit.WindowSeconds, "token_limit.window_seconds must round-trip")
assert.Equal(t, int64(2_592_000), got.Limits.BudgetLimit.WindowSeconds, "budget_limit.window_seconds must round-trip")
// Legacy field names must NOT appear in the response — would
// signal that the management server is still emitting the old
// shape and would fool a v1 dashboard into rendering days/hours.
assert.NotContains(t, rec.Body.String(), "window_hours",
"legacy window_hours field must be absent from the on-wire response")
assert.NotContains(t, rec.Body.String(), "window_days",
"legacy window_days field must be absent from the on-wire response")
}
// TestPolicyHandler_RejectsSubMinuteWindow ports bash 20 to Go: an
// enabled limit with window_seconds < 60 must surface as a 4xx
// because anything finer than per-minute produces an untenable
// volume of consumption rows for a feature whose value comes from
// per-window cap enforcement.
func TestPolicyHandler_RejectsSubMinuteWindow(t *testing.T) {
f := newAgentNetworkHandlerFixture(t)
f.seedProvider(t, "prov-1")
body := `{
"name": "sub-minute-window",
"enabled": true,
"source_groups": ["grp-engineers"],
"destination_provider_ids": ["prov-1"],
"guardrail_ids": [],
"limits": {
"token_limit": {"enabled": true, "group_cap": 10000, "user_cap": 5000, "window_seconds": 30},
"budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0}
}
}`
rec := f.do(t, http.MethodPost, "/agent-network/policies", body)
// 422 specifically (InvalidArgument) proves the window-validation path —
// a route miss would be 404 and an auth failure 403, so a generic 4xx
// would let those false-pass.
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code,
"enabled token_limit with window_seconds<60 must be rejected as a validation error: got %d body=%s", rec.Code, rec.Body.String())
assert.Contains(t, rec.Body.String(), "window_seconds",
"rejection body must name the offending window_seconds field, proving it's the validation path: %s", rec.Body.String())
}
// TestConsumptionHandler_EmptyAccountReturnsArray ports bash 30 to
// Go: GET /agent-network/consumption on a clean account always
// returns a JSON array (possibly empty), never a 404 / 500. The
// dashboard depends on this shape to render its empty state.
func TestConsumptionHandler_EmptyAccountReturnsArray(t *testing.T) {
f := newAgentNetworkHandlerFixture(t)
rec := f.do(t, http.MethodGet, "/agent-network/consumption", "")
require.Equal(t, http.StatusOK, rec.Code)
var rows []api.AgentNetworkConsumption
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rows),
"response must always be a JSON array — even when empty: %s", rec.Body.String())
assert.Empty(t, rows)
}
// TestConsumptionHandler_PopulatedAccountListsRows mirrors the
// /consumption read after a few RecordConsumption calls. Validates
// the wire shape carries every field the dashboard reads (dim_kind,
// dim_id, window_seconds, window_start_utc, tokens, cost_usd) and
// rows are ordered window-newest-first.
func TestConsumptionHandler_PopulatedAccountListsRows(t *testing.T) {
f := newAgentNetworkHandlerFixture(t)
require.NoError(t, f.manager.RecordConsumption(
context.Background(), testAccountID,
agentNetworkTypes.DimensionGroup, "grp-engineers",
86_400, 100, 50, 0.0125,
))
require.NoError(t, f.manager.RecordConsumption(
context.Background(), testAccountID,
agentNetworkTypes.DimensionUser, testUserID,
86_400, 100, 50, 0.0125,
))
rec := f.do(t, http.MethodGet, "/agent-network/consumption", "")
require.Equal(t, http.StatusOK, rec.Code)
var rows []api.AgentNetworkConsumption
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rows))
require.Len(t, rows, 2, "two RecordConsumption calls must yield two rows")
// Index by dim_kind so we can assert the full wire shape of each row,
// including the dimension id and the aligned window start the dashboard
// keys on. Both rows share totals and window.
byKind := make(map[string]api.AgentNetworkConsumption, len(rows))
for _, row := range rows {
assert.Equal(t, int64(100), row.TokensInput)
assert.Equal(t, int64(50), row.TokensOutput)
assert.InDelta(t, 0.0125, row.CostUsd, 1e-9)
assert.Equal(t, int64(86_400), row.WindowSeconds)
assert.False(t, row.WindowStartUtc.IsZero(), "window_start_utc must be set on every row")
byKind[string(row.DimensionKind)] = row
}
groupRow, ok := byKind["group"]
require.True(t, ok, "group dimension must surface")
assert.Equal(t, "grp-engineers", groupRow.DimensionId, "group row must carry the source group id as dimension_id")
userRow, ok := byKind["user"]
require.True(t, ok, "user dimension must surface")
assert.Equal(t, testUserID, userRow.DimensionId, "user row must carry the user id as dimension_id")
// Both rows fall in the same aligned window (same length, recorded
// together), so window_start_utc must match across them.
assert.Equal(t, groupRow.WindowStartUtc, userRow.WindowStartUtc,
"rows recorded in the same window must share the aligned window_start_utc")
}

View File

@@ -0,0 +1,228 @@
package handlers
import (
"encoding/json"
"net/http"
"strings"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// minWindowSeconds is the floor enforced on enabled token / budget
// limit windows. One minute is short enough for fine-grained burst
// control without producing untenable consumption-row volume at scale.
const minWindowSeconds int64 = 60
// addPolicyEndpoints registers all Agent Network policy routes on the
// shared handler.
func (h *handler) addPolicyEndpoints(router *mux.Router) {
router.HandleFunc("/agent-network/policies", h.getAllPolicies).Methods("GET", "OPTIONS")
router.HandleFunc("/agent-network/policies", h.createPolicy).Methods("POST", "OPTIONS")
router.HandleFunc("/agent-network/policies/{policyId}", h.getPolicy).Methods("GET", "OPTIONS")
router.HandleFunc("/agent-network/policies/{policyId}", h.updatePolicy).Methods("PUT", "OPTIONS")
router.HandleFunc("/agent-network/policies/{policyId}", h.deletePolicy).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policies, err := h.manager.GetAllPolicies(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
out := make([]*api.AgentNetworkPolicy, 0, len(policies))
for _, p := range policies {
out = append(out, p.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, out)
}
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policyID := mux.Vars(r)["policyId"]
if policyID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w)
return
}
policy, err := h.manager.GetPolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policyID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, policy.ToAPIResponse())
}
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.AgentNetworkPolicyRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if err := validatePolicy(&req); err != nil {
util.WriteError(r.Context(), err, w)
return
}
policy := types.NewPolicy(userAuth.AccountId)
policy.FromAPIRequest(&req)
created, err := h.manager.CreatePolicy(r.Context(), userAuth.UserId, policy)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
}
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policyID := mux.Vars(r)["policyId"]
if policyID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w)
return
}
var req api.AgentNetworkPolicyRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if err := validatePolicy(&req); err != nil {
util.WriteError(r.Context(), err, w)
return
}
policy := &types.Policy{
ID: policyID,
AccountID: userAuth.AccountId,
}
policy.FromAPIRequest(&req)
updated, err := h.manager.UpdatePolicy(r.Context(), userAuth.UserId, policy)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
}
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policyID := mux.Vars(r)["policyId"]
if policyID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w)
return
}
if err := h.manager.DeletePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policyID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func validatePolicy(req *api.AgentNetworkPolicyRequest) error {
if strings.TrimSpace(req.Name) == "" {
return status.Errorf(status.InvalidArgument, "name is required")
}
if len(req.SourceGroups) == 0 {
return status.Errorf(status.InvalidArgument, "source_groups must contain at least one group id")
}
for _, id := range req.SourceGroups {
if strings.TrimSpace(id) == "" {
return status.Errorf(status.InvalidArgument, "source_groups must not contain empty entries")
}
}
if len(req.DestinationProviderIds) == 0 {
return status.Errorf(status.InvalidArgument, "destination_provider_ids must contain at least one provider id")
}
for _, id := range req.DestinationProviderIds {
if strings.TrimSpace(id) == "" {
return status.Errorf(status.InvalidArgument, "destination_provider_ids must not contain empty entries")
}
}
if req.GuardrailIds != nil {
for _, id := range *req.GuardrailIds {
if strings.TrimSpace(id) == "" {
return status.Errorf(status.InvalidArgument, "guardrail_ids must not contain empty entries")
}
}
}
if req.Limits != nil {
if err := validatePolicyLimits(*req.Limits); err != nil {
return err
}
}
return nil
}
func validatePolicyLimits(l api.AgentNetworkPolicyLimits) error {
if l.TokenLimit.Enabled {
if l.TokenLimit.WindowSeconds < minWindowSeconds {
return status.Errorf(status.InvalidArgument, "limits.token_limit.window_seconds must be at least %d (one minute) when enabled", minWindowSeconds)
}
if l.TokenLimit.GroupCap < 0 {
return status.Errorf(status.InvalidArgument, "limits.token_limit.group_cap must not be negative")
}
if l.TokenLimit.UserCap < 0 {
return status.Errorf(status.InvalidArgument, "limits.token_limit.user_cap must not be negative")
}
if l.TokenLimit.GroupCap == 0 && l.TokenLimit.UserCap == 0 {
return status.Errorf(status.InvalidArgument, "limits.token_limit requires group_cap or user_cap to be greater than zero when enabled")
}
}
if l.BudgetLimit.Enabled {
if l.BudgetLimit.WindowSeconds < minWindowSeconds {
return status.Errorf(status.InvalidArgument, "limits.budget_limit.window_seconds must be at least %d (one minute) when enabled", minWindowSeconds)
}
if l.BudgetLimit.GroupCapUsd < 0 {
return status.Errorf(status.InvalidArgument, "limits.budget_limit.group_cap_usd must not be negative")
}
if l.BudgetLimit.UserCapUsd < 0 {
return status.Errorf(status.InvalidArgument, "limits.budget_limit.user_cap_usd must not be negative")
}
if l.BudgetLimit.GroupCapUsd == 0 && l.BudgetLimit.UserCapUsd == 0 {
return status.Errorf(status.InvalidArgument, "limits.budget_limit requires group_cap_usd or user_cap_usd to be greater than zero when enabled")
}
}
return nil
}

View File

@@ -0,0 +1,217 @@
// Package handlers serves the Agent Network HTTP API.
//
// All persistence is delegated to agentnetwork.Manager so this layer only
// translates between the wire format (api.AgentNetworkProvider*) and the
// domain types.
package handlers
import (
"encoding/json"
"net/http"
"net/url"
"strings"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/catalog"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager agentnetwork.Manager
}
// RegisterEndpoints registers all Agent Network routes.
func RegisterEndpoints(manager agentnetwork.Manager, router *mux.Router) {
h := &handler{manager: manager}
router.HandleFunc("/agent-network/catalog/providers", h.getCatalogProviders).Methods("GET", "OPTIONS")
router.HandleFunc("/agent-network/providers", h.getAllProviders).Methods("GET", "OPTIONS")
router.HandleFunc("/agent-network/providers", h.createProvider).Methods("POST", "OPTIONS")
router.HandleFunc("/agent-network/providers/{providerId}", h.getProvider).Methods("GET", "OPTIONS")
router.HandleFunc("/agent-network/providers/{providerId}", h.updateProvider).Methods("PUT", "OPTIONS")
router.HandleFunc("/agent-network/providers/{providerId}", h.deleteProvider).Methods("DELETE", "OPTIONS")
h.addPolicyEndpoints(router)
h.addGuardrailEndpoints(router)
h.addSettingsEndpoints(router)
h.addConsumptionEndpoints(router)
h.addAccessLogEndpoints(router)
h.addBudgetRuleEndpoints(router)
}
func (h *handler) getCatalogProviders(w http.ResponseWriter, r *http.Request) {
if _, err := nbcontext.GetUserAuthFromContext(r.Context()); err != nil {
util.WriteError(r.Context(), err, w)
return
}
entries := catalog.All()
out := make([]api.AgentNetworkCatalogProvider, 0, len(entries))
for _, e := range entries {
out = append(out, e.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, out)
}
func (h *handler) getAllProviders(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
providers, err := h.manager.GetAllProviders(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
out := make([]*api.AgentNetworkProvider, 0, len(providers))
for _, p := range providers {
out = append(out, p.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, out)
}
func (h *handler) getProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
providerID := mux.Vars(r)["providerId"]
if providerID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w)
return
}
provider, err := h.manager.GetProvider(r.Context(), userAuth.AccountId, userAuth.UserId, providerID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, provider.ToAPIResponse())
}
func (h *handler) createProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.AgentNetworkProviderRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if err := validate(&req, true); err != nil {
util.WriteError(r.Context(), err, w)
return
}
provider := types.NewProvider(userAuth.AccountId)
provider.FromAPIRequest(&req)
bootstrapCluster := ""
if req.BootstrapCluster != nil {
bootstrapCluster = *req.BootstrapCluster
}
created, err := h.manager.CreateProvider(r.Context(), userAuth.UserId, provider, bootstrapCluster)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
}
func (h *handler) updateProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
providerID := mux.Vars(r)["providerId"]
if providerID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w)
return
}
var req api.AgentNetworkProviderRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if err := validate(&req, false); err != nil {
util.WriteError(r.Context(), err, w)
return
}
provider := &types.Provider{
ID: providerID,
AccountID: userAuth.AccountId,
}
provider.FromAPIRequest(&req)
updated, err := h.manager.UpdateProvider(r.Context(), userAuth.UserId, provider)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
}
func (h *handler) deleteProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
providerID := mux.Vars(r)["providerId"]
if providerID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w)
return
}
if err := h.manager.DeleteProvider(r.Context(), userAuth.AccountId, userAuth.UserId, providerID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func validate(req *api.AgentNetworkProviderRequest, requireAPIKey bool) error {
if strings.TrimSpace(req.ProviderId) == "" {
return status.Errorf(status.InvalidArgument, "provider_id is required")
}
if !catalog.IsKnown(req.ProviderId) {
return status.Errorf(status.InvalidArgument, "provider_id %q is not a known catalog provider", req.ProviderId)
}
if strings.TrimSpace(req.Name) == "" {
return status.Errorf(status.InvalidArgument, "name is required")
}
if strings.TrimSpace(req.UpstreamUrl) == "" {
return status.Errorf(status.InvalidArgument, "upstream_url is required")
}
u, err := url.Parse(strings.TrimSpace(req.UpstreamUrl))
if err != nil || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
return status.Errorf(status.InvalidArgument, "upstream_url must be a full http(s) URL")
}
if requireAPIKey && (req.ApiKey == nil || strings.TrimSpace(*req.ApiKey) == "") {
return status.Errorf(status.InvalidArgument, "api_key is required")
}
return nil
}

View File

@@ -0,0 +1,74 @@
package handlers
import (
"encoding/json"
"errors"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// addSettingsEndpoints registers the Agent Network settings routes. The
// settings row is bootstrapped server-side on first provider create; GET reads
// it and PUT updates the mutable collection toggles (cluster/subdomain stay
// immutable).
func (h *handler) addSettingsEndpoints(router *mux.Router) {
router.HandleFunc("/agent-network/settings", h.getSettings).Methods("GET", "OPTIONS")
router.HandleFunc("/agent-network/settings", h.updateSettings).Methods("PUT", "OPTIONS")
}
// updateSettings applies the collection toggles to the account's settings row.
func (h *handler) updateSettings(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.AgentNetworkSettingsRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
settings := &types.Settings{AccountID: userAuth.AccountId}
settings.FromAPIRequest(&req)
updated, err := h.manager.UpdateSettings(r.Context(), userAuth.UserId, settings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
}
// getSettings returns the account's agent-network settings. The settings
// row is bootstrapped on first provider create, so freshly-onboarded
// accounts have nothing to read. Rather than 404-ing in that case (which
// the dashboard would have to special-case), return a JSON null with 200
// so consumers can branch on the body alone.
func (h *handler) getSettings(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
settings, err := h.manager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
var sErr *status.Error
if errors.As(err, &sErr) && sErr.Type() == status.NotFound {
util.WriteJSONObject(r.Context(), w, nil)
return
}
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, settings.ToAPIResponse())
}

View File

@@ -0,0 +1,66 @@
// Package labelgen produces DNS-safe Agent Network subdomain labels.
package labelgen
import (
"fmt"
"math/rand"
"sort"
"sync"
)
// pickAttempts caps the random retries before falling back to the
// suffixed form. Eight is a soft compromise: with a near-empty taken
// set the very first pick almost always succeeds; when the wordlist is
// densely populated the fallback eventually fires anyway.
const pickAttempts = 8
var (
dedupOnce sync.Once
uniqWords []string
)
// uniqueWords returns the wordlist deduplicated and sorted for
// deterministic exhaustion behaviour. Lazy-built once per process.
func uniqueWords() []string {
dedupOnce.Do(func() {
seen := make(map[string]struct{}, len(words))
uniqWords = make([]string, 0, len(words))
for _, w := range words {
if _, ok := seen[w]; ok {
continue
}
seen[w] = struct{}{}
uniqWords = append(uniqWords, w)
}
sort.Strings(uniqWords)
})
return uniqWords
}
// PickUnique selects a label not already in `taken`. It tries up to
// pickAttempts random picks; on exhaustion it scans the deduplicated
// wordlist for any remaining free entry, and if none is left appends
// `-<fallbackSuffix>` to a deterministic word and returns. The caller
// is responsible for seeding rng (math/rand).
func PickUnique(rng *rand.Rand, taken map[string]struct{}, fallbackSuffix string) string {
pool := uniqueWords()
if len(pool) == 0 {
return fallbackSuffix
}
for i := 0; i < pickAttempts; i++ {
w := pool[rng.Intn(len(pool))]
if _, ok := taken[w]; !ok {
return w
}
}
for _, w := range pool {
if _, ok := taken[w]; !ok {
return w
}
}
w := pool[rng.Intn(len(pool))]
return fmt.Sprintf("%s-%s", w, fallbackSuffix)
}

View File

@@ -0,0 +1,101 @@
package labelgen
import (
"math/rand"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPickUnique_DeterministicWithSeededRng locks the property the
// caller relies on: same seed + same taken set → same pick. Without
// that, the bootstrap flow can't reproduce a label across retries.
func TestPickUnique_DeterministicWithSeededRng(t *testing.T) {
taken := map[string]struct{}{}
rngA := rand.New(rand.NewSource(42))
rngB := rand.New(rand.NewSource(42))
a := PickUnique(rngA, taken, "abcd")
b := PickUnique(rngB, taken, "abcd")
assert.Equal(t, a, b, "Same seed and taken set must produce identical pick")
}
// TestPickUnique_AvoidsTakenWordsWhenMostAreReserved seeds taken with
// every word in the pool except a handful and confirms PickUnique
// finds one of the remaining free entries instead of returning the
// fallback form.
func TestPickUnique_AvoidsTakenWordsWhenMostAreReserved(t *testing.T) {
pool := uniqueWords()
require.NotEmpty(t, pool, "wordlist must be populated for the test to mean anything")
free := map[string]struct{}{
pool[0]: {},
pool[len(pool)/2]: {},
pool[len(pool)-1]: {},
}
taken := make(map[string]struct{}, len(pool))
for _, w := range pool {
if _, ok := free[w]; ok {
continue
}
taken[w] = struct{}{}
}
rng := rand.New(rand.NewSource(7))
got := PickUnique(rng, taken, "abcd")
_, isFree := free[got]
assert.True(t, isFree, "PickUnique must return one of the free words; got %q", got)
assert.NotContains(t, got, "-", "Free pick must not be the suffix fallback form")
}
// TestPickUnique_FallsBackWhenAllReserved exhausts the pool and
// confirms PickUnique appends the supplied suffix instead of
// returning a duplicate.
func TestPickUnique_FallsBackWhenAllReserved(t *testing.T) {
pool := uniqueWords()
taken := make(map[string]struct{}, len(pool))
for _, w := range pool {
taken[w] = struct{}{}
}
rng := rand.New(rand.NewSource(99))
got := PickUnique(rng, taken, "abcd")
assert.True(t, strings.HasSuffix(got, "-abcd"), "Exhausted pool must produce <word>-<suffix>; got %q", got)
prefix := strings.TrimSuffix(got, "-abcd")
found := false
for _, w := range pool {
if w == prefix {
found = true
break
}
}
assert.True(t, found, "Fallback prefix must be drawn from the wordlist; got %q", prefix)
}
// TestUniqueWords_DropsDuplicates guards against authoring slips in
// words.go: every entry must be unique and DNS-safe.
func TestUniqueWords_DropsDuplicates(t *testing.T) {
pool := uniqueWords()
seen := make(map[string]struct{}, len(pool))
for _, w := range pool {
_, dup := seen[w]
assert.False(t, dup, "Duplicate entry %q in deduplicated pool", w)
seen[w] = struct{}{}
assert.GreaterOrEqual(t, len(w), 4, "Word %q is shorter than 4 chars", w)
assert.LessOrEqual(t, len(w), 12, "Word %q is longer than 12 chars", w)
for _, r := range w {
ok := r >= 'a' && r <= 'z'
assert.True(t, ok, "Word %q contains non-lowercase-ASCII rune %q", w, r)
}
}
assert.GreaterOrEqual(t, len(pool), 500, "Pool must contain at least 500 unique words")
}

View File

@@ -0,0 +1,136 @@
// Package labelgen produces DNS-safe Agent Network subdomain labels.
//
// The wordlist below is a curated subset drawn from public-domain
// nature / common-noun pools (e.g. EFF's diceware lists). Every entry
// is lowercase ASCII, 412 chars, no hyphens, no digits, and was
// hand-checked to avoid offensive, brand, or region-specific terms.
package labelgen
// words is the pool PickUnique selects from. The slice is intentionally
// not sorted — random picks distribute across the list naturally.
var words = []string{
"acorn", "adobe", "agate", "alder", "almond", "alpine", "amber", "amethyst",
"anchor", "antler", "apple", "apricot", "arcade", "arctic", "arrow", "ashen",
"aspen", "atlas", "atom", "aurora", "autumn", "azure",
"badger", "bamboo", "banana", "banjo", "barley", "barn", "basalt", "basil",
"basin", "bayou", "beach", "beacon", "beaver", "beech", "beetle", "berry",
"birch", "bison", "blossom", "blue", "bobcat", "bonsai", "boulder", "branch",
"brass", "breeze", "bridge", "bright", "brook", "broom", "brown", "buffalo",
"bumble", "burrow", "butter", "button",
"cabin", "cactus", "calm", "camel", "campfire", "canary", "candle", "canoe",
"canyon", "cardinal", "carrot", "cascade", "castle", "cedar", "celery", "cello",
"cement", "cherry", "chestnut", "chime", "cinnamon", "cinder", "citron", "clay",
"clear", "cliff", "clock", "cloud", "clover", "coast", "cobalt", "cobble",
"cocoa", "coffee", "comet", "compass", "copper", "coral", "corner", "cosmos",
"cotton", "cougar", "country", "coyote", "cove", "crane", "crater", "creek",
"crescent", "crimson", "crocus", "crystal", "cypress",
"daffodil", "dahlia", "daisy", "dawn", "deer", "delta", "denim", "desert",
"dewdrop", "diamond", "dolphin", "doodle", "dove", "dragon", "drift", "drop",
"dune", "dusk", "dusty",
"eagle", "earth", "echo", "elder", "elkhorn", "ember", "emerald", "emperor",
"evergreen", "evening",
"falcon", "fawn", "feather", "fern", "fiddle", "field", "fiesta", "finch",
"firepit", "firefly", "fjord", "flame", "flax", "fleece", "flint", "floral",
"flower", "flute", "foal", "foggy", "forest", "fountain", "foxglove", "fresh",
"frost", "fuchsia", "fudge",
"gable", "galaxy", "garden", "garnet", "gazelle", "geode", "geyser", "ginger",
"glacier", "glade", "glass", "glow", "gold", "goose", "gorge", "gourd",
"granite", "grape", "grass", "gravel", "grayling", "greenery", "grizzly", "grove",
"gull", "gumdrop", "gust",
"hammock", "harbor", "harvest", "hawk", "hazel", "heather", "hedge", "heron",
"hibiscus", "hickory", "hideaway", "highland", "hill", "hive", "hollow", "honey",
"hopper", "horizon", "hummingbird", "husky",
"iceberg", "indigo", "iris", "island", "ivory", "ivybush",
"jade", "jasmine", "jasper", "jaybird", "jelly", "jewel", "jonquil", "journey",
"juniper", "jupiter", "jute",
"kale", "kangaroo", "kayak", "kelp", "kestrel", "kettle", "khaki", "kindling",
"kingfisher", "kiwi", "knapweed", "koala",
"lagoon", "lake", "lantern", "larch", "lark", "laurel", "lava", "lavender",
"leaf", "lemon", "lichen", "light", "lilac", "lily", "lime", "limestone",
"linden", "linen", "lion", "lobster", "locust", "loon", "lotus", "lumber",
"lunar", "lupine", "lynx",
"madrone", "magenta", "magnolia", "mahogany", "mallow", "mango", "manor", "maple",
"marble", "marigold", "marina", "marlin", "marsh", "mauve", "meadow", "melody",
"melon", "merlin", "metal", "midnight", "milk", "millet", "mineral", "mint",
"mirror", "mist", "mitten", "molasses", "moon", "moose", "morning", "moss",
"mountain", "mulberry", "muscat", "mustard",
"narwhal", "navy", "nectar", "needle", "nest", "nettle", "newt", "nightfall",
"noon", "nook", "north", "nova", "nutmeg",
"oaken", "oasis", "oatmeal", "ocean", "ochre", "octagon", "olive", "onyx",
"opal", "orange", "orbit", "orchard", "orchid", "oregano", "orion", "osprey",
"otter", "outpost", "owlet", "oyster",
"painter", "palace", "palm", "pansy", "panther", "papaya", "paprika", "parsley",
"partridge", "passage", "pastel", "patio", "peach", "peacock", "pear", "pearl",
"pebble", "pecan", "pelican", "penguin", "peony", "pepper", "perch", "peridot",
"pewter", "phoenix", "pier", "pillar", "pine", "pineapple", "pinto", "piper",
"pistachio", "plain", "planet", "plateau", "platinum", "plum", "plume", "polar",
"pollen", "pond", "poplar", "poppy", "porcelain", "portal", "portrait", "potato",
"prairie", "primrose", "prism", "puffin", "pumpkin",
"quail", "quartz", "quaver", "quill", "quince", "quinoa",
"rabbit", "raccoon", "radish", "rain", "rainbow", "raindrop", "rapids", "raspberry",
"raven", "ravine", "redwood", "reed", "reef", "ridge", "river", "robin",
"rocket", "rubyred", "rose", "rosemary", "rosewood", "ruffle", "rugby", "russet",
"rustic", "ryefield",
"saffron", "sage", "salmon", "sand", "sandstone", "sapphire", "savanna", "scarlet",
"scout", "seal", "season", "seaweed", "sequoia", "shadow", "shamrock", "shell",
"sherbet", "shore", "silver", "siskin", "skybloom", "skyline", "sleet", "smoke",
"snail", "snapdragon", "snow", "snowflake", "snowy", "solar", "song", "sonic",
"sorrel", "south", "sparkle", "sparrow", "spice", "spider", "spinach", "spire",
"spring", "sprout", "spruce", "squirrel", "starfish", "starlight", "stoat", "stone",
"stork", "storm", "stream", "studio", "summer", "sunbeam", "sundew", "sunny",
"sunrise", "sunset", "swallow", "swan", "sweet", "sycamore",
"tangelo", "tangerine", "tansy", "taupe", "teak", "teal", "thicket", "thistle",
"thrush", "thunder", "tide", "tiger", "tinder", "topaz", "torch", "tortoise",
"tower", "trail", "tranquil", "tundra", "tulip", "turquoise", "turtle", "twig",
"twilight",
"umber", "uplands",
"valley", "vanilla", "velvet", "venus", "verdant", "verdigris", "vermilion", "violet",
"vista", "vivid", "volcano", "vortex",
"walnut", "warbler", "watercress", "waterfall", "wave", "waxwing", "weasel", "westwind",
"whale", "whisker", "whisper", "wicker", "wildwood", "willow", "winter", "wisp",
"wisteria", "wolf", "wombat", "woodland", "woolly", "wren", "wreath",
"yarrow", "yellow", "yewtree", "yodel",
"zebra", "zenith", "zephyr", "zinnia",
"alabaster", "alfalfa", "almanac", "anise", "antelope", "arbor", "arena", "armadillo",
"avocet", "azalea", "balsam", "bayou", "beacon", "blizzard", "bluebell", "bluebird",
"bluejay", "bobolink", "borage", "boreal", "buckeye", "buckthorn", "buttercup",
"cabana", "calico", "canopy", "caraway", "cardamom", "cattail", "celadon", "centaur",
"chambray", "chamois", "champlain", "chestnuts", "chickadee", "chinook", "chipmunk", "cinnabar",
"cirrus", "citrine", "clematis", "copperhead",
"crocodile", "currant", "cuttlebone", "daffy", "dapple", "delphinium", "dervish", "diamondback",
"dogwood", "dolphins", "dragonfly", "driftwood", "dusk", "dustpan", "ebony", "edelweiss",
"emperor", "endive", "estuary", "everglade", "fairway", "feldspar", "fennel", "fieldstone",
"firebrand", "firefly", "fireweed", "firework", "flagstone", "fossil", "frostbite", "galleon",
"gardener", "geranium", "gingko", "ginseng", "goldfish", "goldfinch", "goldenrod", "graphite",
"greenfinch", "guppy", "haiku", "halibut", "hammerhead", "harbinger", "harvest", "hatchling",
"havana", "hawthorn", "hazelnut", "heartwood", "henna", "heron", "highrise", "homestead",
"honeycomb", "honeydew", "horseshoe", "hyacinth", "iceland", "icicle", "indigobird", "ironwood",
"jacaranda", "jamboree", "javelina", "jellyfish", "junebug", "kaleido", "kayaker", "kerchief",
"keystone", "kingdom", "labrador", "lacewing", "ladybug", "lakeside", "lamplight", "leopard",
"lighthouse", "lilypad", "lullaby", "magnet", "mahonia", "mandolin", "manzanita", "maraschino",
"mariner", "marsupial", "mastodon", "matterhorn", "mayflower", "mayfly", "meadowlark", "merlot",
"meteor", "midshipman", "millpond", "mimosa", "minnow", "mockingbird", "molten", "monarch",
"monsoon", "moondust", "moonlight", "moorland", "morning", "mossland", "mountain", "mulch",
"narcissus", "nautilus", "nettlebush", "northstar", "nuthatch", "obsidian", "okra", "olivine",
"opalescent", "orchidea", "orchard", "ornament", "outrigger", "oxalis", "paddler", "paintbrush",
"papyrus", "paradise", "pasture", "patchwork", "pathway", "peridot", "periwinkle", "petalbloom",
"petrel", "petunia", "phlox", "pikeperch", "pinecone", "pioneer", "pipevine", "platypus",
"pomelo", "pondweed", "porpoise", "powder", "promise", "puddle", "pumice", "puzzle",
"quetzal", "quicksilver", "raccoon", "ragwort", "rainforest", "ramble", "rapid", "rascal",
"raspberry", "redbud", "redfern", "redpoll", "reedling", "ringtail", "riverbed", "riverbird",
"riverstone", "rockcress", "roebuck", "rosebay", "rosehip", "rosemary", "rowan", "rumble",
"runaway", "rustler", "sagebrush", "sailcloth", "salamander", "salsify", "samphire", "sandbar",
"sanddollar", "sandpiper", "santolina", "sapodilla", "sassafras", "scallion", "schooner", "seafoam",
"seafrost", "seagrass", "seahorse", "seaport", "seashell", "seaspray", "shamble", "shimmer",
"shoreline", "silkmoth", "silverfox", "skylark", "snapdragon", "snowberry", "snowdrop", "snowfall",
"snowmelt", "softwood", "songbird", "sorghum", "southwind", "speedwell", "spinnaker", "spruce",
"starlight", "starling", "stormcloud", "summit", "sundance", "sundew", "sundial", "sunflower",
"surface", "swallowtail", "sweetcorn", "sycamore", "tabletop", "tamarack", "tamarind", "tangerine",
"tarragon", "telescope", "thicket", "thrasher", "thunder", "thyme", "tideline", "timberland",
"tinderbox", "topiary", "torchwood", "totem", "tradewind", "treasure", "tremolo", "trinket",
"trumpetvine", "tugboat", "tundra", "turnstone", "underbrush", "vagabond", "valerian", "vanilla",
"velveteen", "vermilion", "vinca", "vineyard", "violet", "voyager", "wagonwheel", "walnutwood",
"watermark", "watershed", "waterway", "wavefront", "westerly", "whaleback", "whetstone", "wicker",
"wildbloom", "wildflower", "wilderness", "windsong", "windward", "winterberry", "woodbine", "woodfern",
"woodland", "woodthrush", "woolgrass", "yellowfin", "zenithal", "zucchini",
}

View File

@@ -0,0 +1,896 @@
package agentnetwork
import (
"context"
"errors"
"fmt"
"math/rand"
"slices"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/labelgen"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
// ensureSessionKeys mints an ed25519 session keypair on the provider
// when one is missing. Idempotent: skips when both fields are already
// populated (e.g. update or migrated rows). The keys are used by the
// synthesised reverse-proxy service to sign / verify session JWTs
// after a successful OIDC handshake.
func ensureSessionKeys(p *types.Provider) error {
if p.SessionPrivateKey != "" && p.SessionPublicKey != "" {
return nil
}
pair, err := sessionkey.GenerateKeyPair()
if err != nil {
return fmt.Errorf("generate provider session keys: %w", err)
}
p.SessionPrivateKey = pair.PrivateKey
p.SessionPublicKey = pair.PublicKey
return nil
}
// Manager governs the lifecycle of Agent Network providers and policies.
type Manager interface {
GetAllProviders(ctx context.Context, accountID, userID string) ([]*types.Provider, error)
GetProvider(ctx context.Context, accountID, userID, providerID string) (*types.Provider, error)
CreateProvider(ctx context.Context, userID string, provider *types.Provider, bootstrapCluster string) (*types.Provider, error)
UpdateProvider(ctx context.Context, userID string, provider *types.Provider) (*types.Provider, error)
DeleteProvider(ctx context.Context, accountID, userID, providerID string) error
GetAllPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
GetPolicy(ctx context.Context, accountID, userID, policyID string) (*types.Policy, error)
CreatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error)
UpdatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error)
DeletePolicy(ctx context.Context, accountID, userID, policyID string) error
GetAllGuardrails(ctx context.Context, accountID, userID string) ([]*types.Guardrail, error)
GetGuardrail(ctx context.Context, accountID, userID, guardrailID string) (*types.Guardrail, error)
CreateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error)
UpdateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error)
DeleteGuardrail(ctx context.Context, accountID, userID, guardrailID string) error
GetAllBudgetRules(ctx context.Context, accountID, userID string) ([]*types.AccountBudgetRule, error)
GetBudgetRule(ctx context.Context, accountID, userID, ruleID string) (*types.AccountBudgetRule, error)
CreateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error)
UpdateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error)
DeleteBudgetRule(ctx context.Context, accountID, userID, ruleID string) error
GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error)
UpdateSettings(ctx context.Context, userID string, settings *types.Settings) (*types.Settings, error)
ListConsumption(ctx context.Context, accountID, userID string) ([]*types.Consumption, error)
ListAccessLogs(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error)
GetUsageOverview(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter, granularity types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error)
StartAccessLogCleanup(ctx context.Context, cleanupIntervalHours int)
RecordConsumption(ctx context.Context, accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds, tokensIn, tokensOut int64, costUSD float64) error
RecordAccountBudgetUsage(ctx context.Context, accountID, userID string, groupIDs []string, tokensIn, tokensOut int64, costUSD float64) error
RecordUsage(ctx context.Context, in RecordUsageInput) error
SelectPolicyForRequest(ctx context.Context, in PolicySelectionInput) (*PolicySelectionResult, error)
}
// PolicySelectionInput is the per-request selection envelope. The
// proxy populates it from CapturedData (account, user, groups) plus
// the provider llm_router resolved.
type PolicySelectionInput struct {
AccountID string
UserID string
GroupIDs []string
ProviderID string
}
// PolicySelectionResult names the policy that "pays" for this request
// plus the deny envelope when every applicable policy has exhausted
// every cap. AttributionGroupID is the lowest group id (string sort)
// of caller_groups ∩ selected_policy.source_groups; empty when no
// group dimension applies. WindowSeconds is the chosen policy's
// effective window length in seconds (token_limit's wins when both
// halves are enabled with mismatched windows; budget_limit's
// otherwise; 0 when no caps are configured at all).
type PolicySelectionResult struct {
Allow bool
SelectedPolicyID string
AttributionGroupID string
WindowSeconds int64
DenyCode string
DenyReason string
}
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyController proxy.Controller
// reconcileCache holds the last set of synthesised proxy mappings
// per account so reconcile can emit precise Create/Update/Delete
// updates instead of a full re-push on every mutation. Keyed by
// accountID, then by synthesised service ID.
reconcileMu sync.Mutex
reconcileCache map[string]map[string]*proto.ProxyMapping
// labelRngMu guards labelRng. PickUnique consumes math/rand.Source
// state; concurrent provider creates would otherwise race.
labelRngMu sync.Mutex
labelRng *rand.Rand
}
// NewManager constructs the persistent Agent Network manager. The
// manager persists provider/policy/guardrail configuration and, on
// every mutation, reconciles the in-memory synthesised reverse-proxy
// services with the proxy cluster via proxyController. Pass nil for
// proxyController to disable the reconcile push (useful in tests).
func NewManager(
store store.Store,
permissionsManager permissions.Manager,
accountManager account.Manager,
proxyController proxy.Controller,
) Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyController: proxyController,
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
labelRng: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
func (m *managerImpl) GetAllProviders(ctx context.Context, accountID, userID string) ([]*types.Provider, error) {
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
return nil, err
}
return m.store.GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) GetProvider(ctx context.Context, accountID, userID, providerID string) (*types.Provider, error) {
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
return nil, err
}
return m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthNone, accountID, providerID)
}
// CreateProvider persists a new provider for the account. bootstrapCluster
// is used only when the per-account agent-network Settings row hasn't
// been created yet; otherwise it is ignored (the cluster is pinned on
// Settings and every provider in the account routes through it).
func (m *managerImpl) CreateProvider(ctx context.Context, userID string, provider *types.Provider, bootstrapCluster string) (*types.Provider, error) {
if err := m.requirePermission(ctx, provider.AccountID, userID, operations.Create); err != nil {
return nil, err
}
// An empty api_key would silently produce a synthesised service
// that 401s on every upstream request. Surface the misconfiguration
// at create time instead.
if strings.TrimSpace(provider.APIKey) == "" {
return nil, status.Errorf(status.InvalidArgument, "api_key is required when creating an agent network provider")
}
if provider.ID == "" {
fresh := types.NewProvider(provider.AccountID)
provider.ID = fresh.ID
provider.CreatedAt = fresh.CreatedAt
provider.UpdatedAt = fresh.UpdatedAt
}
if err := ensureSessionKeys(provider); err != nil {
return nil, err
}
if err := m.store.SaveAgentNetworkProvider(ctx, provider); err != nil {
return nil, fmt.Errorf("save agent network provider: %w", err)
}
if strings.TrimSpace(bootstrapCluster) != "" {
if _, err := m.bootstrapSettingsIfNeeded(ctx, provider.AccountID, bootstrapCluster); err != nil {
// The provider create has already succeeded; logging the
// bootstrap miss matches the plan's PoC behaviour. The synth
// path treats a missing settings row as a no-op, and the next
// provider create retries the bootstrap.
log.WithContext(ctx).Debugf("agent-network bootstrap settings for account %s on cluster %s: %v", provider.AccountID, bootstrapCluster, err)
}
}
m.accountManager.StoreEvent(ctx, userID, provider.ID, provider.AccountID, activity.AgentNetworkProviderCreated, provider.EventMeta())
m.reconcile(ctx, provider.AccountID)
return provider, nil
}
func (m *managerImpl) UpdateProvider(ctx context.Context, userID string, provider *types.Provider) (*types.Provider, error) {
if err := m.requirePermission(ctx, provider.AccountID, userID, operations.Update); err != nil {
return nil, err
}
existing, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthUpdate, provider.AccountID, provider.ID)
if err != nil {
return nil, fmt.Errorf("failed to get agent network provider: %w", err)
}
// Preserve the API key if the caller didn't rotate it. A
// whitespace-only value is treated as "not rotated" rather than a
// real key, but it must not silently overwrite a valid stored key.
if provider.APIKey == "" {
provider.APIKey = existing.APIKey
} else if strings.TrimSpace(provider.APIKey) == "" {
return nil, status.Errorf(status.InvalidArgument, "api_key must be non-blank when rotating an agent network provider")
}
// Always preserve the session keypair across updates so existing
// session cookies stay valid. The keys are server-managed and
// never surfaced through the API.
provider.SessionPrivateKey = existing.SessionPrivateKey
provider.SessionPublicKey = existing.SessionPublicKey
if err := ensureSessionKeys(provider); err != nil {
return nil, err
}
provider.CreatedAt = existing.CreatedAt
provider.UpdatedAt = time.Now().UTC()
if err := m.store.SaveAgentNetworkProvider(ctx, provider); err != nil {
return nil, fmt.Errorf("save agent network provider: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, provider.ID, provider.AccountID, activity.AgentNetworkProviderUpdated, provider.EventMeta())
m.reconcile(ctx, provider.AccountID)
return provider, nil
}
func (m *managerImpl) DeleteProvider(ctx context.Context, accountID, userID, providerID string) error {
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
return err
}
provider, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthUpdate, accountID, providerID)
if err != nil {
return fmt.Errorf("failed to get agent network provider: %w", err)
}
// Refuse to delete while any policy still references this provider.
// The operator must detach it first.
policies, err := m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get agent network policies: %w", err)
}
var blocking []string
for _, p := range policies {
if slices.Contains(p.DestinationProviderIDs, providerID) {
blocking = append(blocking, p.Name)
}
}
if len(blocking) > 0 {
return status.Errorf(
status.InvalidArgument,
"provider is in use by %d %s (%s); detach it before deleting",
len(blocking),
pluralize(len(blocking), "policy", "policies"),
strings.Join(blocking, ", "),
)
}
if err := m.store.DeleteAgentNetworkProvider(ctx, accountID, providerID); err != nil {
return fmt.Errorf("failed to delete agent network provider: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, providerID, accountID, activity.AgentNetworkProviderDeleted, provider.EventMeta())
m.reconcile(ctx, accountID)
return nil
}
func pluralize(n int, singular, plural string) string {
if n == 1 {
return singular
}
return plural
}
func (m *managerImpl) GetAllPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) {
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
return nil, err
}
return m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) GetPolicy(ctx context.Context, accountID, userID, policyID string) (*types.Policy, error) {
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
return nil, err
}
return m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID)
}
func (m *managerImpl) CreatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error) {
if err := m.requirePermission(ctx, policy.AccountID, userID, operations.Create); err != nil {
return nil, err
}
if policy.ID == "" {
fresh := types.NewPolicy(policy.AccountID)
policy.ID = fresh.ID
policy.CreatedAt = fresh.CreatedAt
policy.UpdatedAt = fresh.UpdatedAt
}
if err := m.validateProviderRefs(ctx, policy.AccountID, policy.DestinationProviderIDs); err != nil {
return nil, err
}
if err := m.store.SaveAgentNetworkPolicy(ctx, policy); err != nil {
return nil, fmt.Errorf("failed to save agent network policy: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, policy.ID, policy.AccountID, activity.AgentNetworkPolicyCreated, policy.EventMeta())
m.reconcile(ctx, policy.AccountID)
return policy, nil
}
func (m *managerImpl) UpdatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error) {
if err := m.requirePermission(ctx, policy.AccountID, userID, operations.Update); err != nil {
return nil, err
}
existing, err := m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthUpdate, policy.AccountID, policy.ID)
if err != nil {
return nil, fmt.Errorf("failed to get agent network policy: %w", err)
}
if err := m.validateProviderRefs(ctx, policy.AccountID, policy.DestinationProviderIDs); err != nil {
return nil, err
}
policy.CreatedAt = existing.CreatedAt
policy.UpdatedAt = time.Now().UTC()
if err := m.store.SaveAgentNetworkPolicy(ctx, policy); err != nil {
return nil, fmt.Errorf("failed to save agent network policy: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, policy.ID, policy.AccountID, activity.AgentNetworkPolicyUpdated, policy.EventMeta())
m.reconcile(ctx, policy.AccountID)
return policy, nil
}
func (m *managerImpl) DeletePolicy(ctx context.Context, accountID, userID, policyID string) error {
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
return err
}
policy, err := m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID)
if err != nil {
return fmt.Errorf("failed to get agent network policy: %w", err)
}
if err := m.store.DeleteAgentNetworkPolicy(ctx, accountID, policyID); err != nil {
return fmt.Errorf("failed to delete agent network policy: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, policyID, accountID, activity.AgentNetworkPolicyDeleted, policy.EventMeta())
m.reconcile(ctx, accountID)
return nil
}
func (m *managerImpl) GetAllGuardrails(ctx context.Context, accountID, userID string) ([]*types.Guardrail, error) {
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
return nil, err
}
return m.store.GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) GetGuardrail(ctx context.Context, accountID, userID, guardrailID string) (*types.Guardrail, error) {
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
return nil, err
}
return m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthNone, accountID, guardrailID)
}
func (m *managerImpl) CreateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error) {
if err := m.requirePermission(ctx, guardrail.AccountID, userID, operations.Create); err != nil {
return nil, err
}
if guardrail.ID == "" {
fresh := types.NewGuardrail(guardrail.AccountID)
guardrail.ID = fresh.ID
guardrail.CreatedAt = fresh.CreatedAt
guardrail.UpdatedAt = fresh.UpdatedAt
}
if err := m.store.SaveAgentNetworkGuardrail(ctx, guardrail); err != nil {
return nil, fmt.Errorf("failed to save agent network guardrail: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, guardrail.ID, guardrail.AccountID, activity.AgentNetworkGuardrailCreated, guardrail.EventMeta())
m.reconcile(ctx, guardrail.AccountID)
return guardrail, nil
}
func (m *managerImpl) UpdateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error) {
if err := m.requirePermission(ctx, guardrail.AccountID, userID, operations.Update); err != nil {
return nil, err
}
existing, err := m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthUpdate, guardrail.AccountID, guardrail.ID)
if err != nil {
return nil, fmt.Errorf("failed to get agent network guardrail: %w", err)
}
guardrail.CreatedAt = existing.CreatedAt
guardrail.UpdatedAt = time.Now().UTC()
if err := m.store.SaveAgentNetworkGuardrail(ctx, guardrail); err != nil {
return nil, fmt.Errorf("failed to save agent network guardrail: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, guardrail.ID, guardrail.AccountID, activity.AgentNetworkGuardrailUpdated, guardrail.EventMeta())
m.reconcile(ctx, guardrail.AccountID)
return guardrail, nil
}
func (m *managerImpl) DeleteGuardrail(ctx context.Context, accountID, userID, guardrailID string) error {
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
return err
}
guardrail, err := m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthUpdate, accountID, guardrailID)
if err != nil {
return fmt.Errorf("failed to get agent network guardrail: %w", err)
}
if err := m.store.DeleteAgentNetworkGuardrail(ctx, accountID, guardrailID); err != nil {
return fmt.Errorf("failed to delete agent network guardrail: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, guardrailID, accountID, activity.AgentNetworkGuardrailDeleted, guardrail.EventMeta())
m.reconcile(ctx, accountID)
return nil
}
// GetAllBudgetRules returns every account-level budget rule for the account.
func (m *managerImpl) GetAllBudgetRules(ctx context.Context, accountID, userID string) ([]*types.AccountBudgetRule, error) {
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
return nil, err
}
return m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, accountID)
}
// GetBudgetRule returns a single account-level budget rule.
func (m *managerImpl) GetBudgetRule(ctx context.Context, accountID, userID, ruleID string) (*types.AccountBudgetRule, error) {
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
return nil, err
}
return m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthNone, accountID, ruleID)
}
// CreateBudgetRule persists a new account-level budget rule. Budget rules are
// enforced at request time (CheckLLMPolicyLimits), not baked into the synth
// proxy config, so no reconcile is needed.
func (m *managerImpl) CreateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
if err := m.requirePermission(ctx, rule.AccountID, userID, operations.Create); err != nil {
return nil, err
}
if rule.ID == "" {
fresh := types.NewAccountBudgetRule(rule.AccountID)
rule.ID = fresh.ID
rule.CreatedAt = fresh.CreatedAt
rule.UpdatedAt = fresh.UpdatedAt
}
if err := m.store.SaveAgentNetworkBudgetRule(ctx, rule); err != nil {
return nil, fmt.Errorf("save agent network budget rule: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, rule.ID, rule.AccountID, activity.AgentNetworkBudgetRuleCreated, rule.EventMeta())
return rule, nil
}
// UpdateBudgetRule updates an existing account-level budget rule.
func (m *managerImpl) UpdateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
if err := m.requirePermission(ctx, rule.AccountID, userID, operations.Update); err != nil {
return nil, err
}
existing, err := m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthUpdate, rule.AccountID, rule.ID)
if err != nil {
return nil, fmt.Errorf("get agent network budget rule: %w", err)
}
rule.CreatedAt = existing.CreatedAt
rule.UpdatedAt = time.Now().UTC()
if err := m.store.SaveAgentNetworkBudgetRule(ctx, rule); err != nil {
return nil, fmt.Errorf("save agent network budget rule: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, rule.ID, rule.AccountID, activity.AgentNetworkBudgetRuleUpdated, rule.EventMeta())
return rule, nil
}
// DeleteBudgetRule removes an account-level budget rule.
func (m *managerImpl) DeleteBudgetRule(ctx context.Context, accountID, userID, ruleID string) error {
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
return err
}
rule, err := m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthUpdate, accountID, ruleID)
if err != nil {
return fmt.Errorf("get agent network budget rule: %w", err)
}
if err := m.store.DeleteAgentNetworkBudgetRule(ctx, accountID, ruleID); err != nil {
return fmt.Errorf("delete agent network budget rule: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, ruleID, accountID, activity.AgentNetworkBudgetRuleDeleted, rule.EventMeta())
return nil
}
// UpdateSettings applies the mutable account-level settings — the collection
// toggles — onto the existing row. Cluster and Subdomain are immutable and are
// preserved from the persisted row regardless of the input. Because the
// collection toggles change the synthesised service config (prompt-capture
// gating, access-log emission), a reconcile is triggered so the proxy and peer
// network maps converge on the new state.
func (m *managerImpl) UpdateSettings(ctx context.Context, userID string, settings *types.Settings) (*types.Settings, error) {
if err := m.requirePermission(ctx, settings.AccountID, userID, operations.Update); err != nil {
return nil, err
}
existing, err := m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthUpdate, settings.AccountID)
if err != nil {
return nil, fmt.Errorf("get agent network settings: %w", err)
}
existing.EnableLogCollection = settings.EnableLogCollection
existing.EnablePromptCollection = settings.EnablePromptCollection
existing.RedactPii = settings.RedactPii
existing.AccessLogRetentionDays = settings.AccessLogRetentionDays
existing.UpdatedAt = time.Now().UTC()
if err := m.store.SaveAgentNetworkSettings(ctx, existing); err != nil {
return nil, fmt.Errorf("save agent network settings: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, settings.AccountID, settings.AccountID, activity.AgentNetworkSettingsUpdated, map[string]any{
"log_collection": existing.EnableLogCollection,
"prompt_collection": existing.EnablePromptCollection,
"redact_pii": existing.RedactPii,
})
m.reconcile(ctx, settings.AccountID)
return existing, nil
}
// validateProviderRefs ensures every destination provider id refers to a
// provider that exists in the same account.
func (m *managerImpl) validateProviderRefs(ctx context.Context, accountID string, providerIDs []string) error {
if len(providerIDs) == 0 {
return nil
}
for _, id := range providerIDs {
if _, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthNone, accountID, id); err != nil {
// Only a genuine not-found means the reference is invalid; a
// store/runtime error must propagate as-is rather than be
// masked as a client validation error.
var sErr *status.Error
if errors.As(err, &sErr) && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "destination_provider_ids: provider %s does not exist", id)
}
return fmt.Errorf("get destination provider %s: %w", id, err)
}
}
return nil
}
// GetSettings returns the agent-network settings row for the account.
// Returns the underlying status.NotFound when no row has been
// bootstrapped yet (i.e. the account has no providers).
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
return nil, err
}
return m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID)
}
// bootstrapSettingsIfNeeded creates the per-account agent-network
// settings row when missing. The cluster comes from the create-time
// hint the dashboard sends (auto-picked from the active cluster list);
// the subdomain is picked from the curated wordlist avoiding
// collisions on the same cluster. Idempotent: if a row already exists
// it is returned untouched and the hint is ignored.
func (m *managerImpl) bootstrapSettingsIfNeeded(ctx context.Context, accountID, providerCluster string) (*types.Settings, error) {
if accountID == "" {
return nil, fmt.Errorf("bootstrap settings: account id is required")
}
if strings.TrimSpace(providerCluster) == "" {
return nil, fmt.Errorf("bootstrap settings: provider cluster is required")
}
existing, err := m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID)
if err == nil {
return existing, nil
}
var sErr *status.Error
if !errors.As(err, &sErr) || sErr.Type() != status.NotFound {
return nil, fmt.Errorf("get agent network settings: %w", err)
}
siblings, err := m.store.GetAgentNetworkSettingsByCluster(ctx, store.LockingStrengthNone, providerCluster)
if err != nil {
return nil, fmt.Errorf("list agent network settings on cluster: %w", err)
}
taken := make(map[string]struct{}, len(siblings))
for _, s := range siblings {
taken[s.Subdomain] = struct{}{}
}
suffix := accountID
if len(suffix) > 4 {
suffix = suffix[:4]
}
m.labelRngMu.Lock()
subdomain := labelgen.PickUnique(m.labelRng, taken, suffix)
m.labelRngMu.Unlock()
now := time.Now().UTC()
settings := &types.Settings{
AccountID: accountID,
Cluster: providerCluster,
Subdomain: subdomain,
// Logs on by default; usage is collected regardless. Retention bounds
// how long full log rows are kept.
EnableLogCollection: true,
AccessLogRetentionDays: types.DefaultAccessLogRetentionDays,
CreatedAt: now,
UpdatedAt: now,
}
if err := m.store.SaveAgentNetworkSettings(ctx, settings); err != nil {
return nil, fmt.Errorf("save agent network settings: %w", err)
}
return settings, nil
}
// ListConsumption returns every consumption row recorded for the
// account, ordered window-newest-first. Backs the dashboard's basic
// counter view; permission gate is the same Read role that gates
// every other agent-network surface.
func (m *managerImpl) ListConsumption(ctx context.Context, accountID, userID string) ([]*types.Consumption, error) {
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
return nil, err
}
return m.store.ListAgentNetworkConsumption(ctx, store.LockingStrengthNone, accountID)
}
// ListAccessLogs returns a paginated, server-side-filtered page of
// agent-network access logs plus the total count matching the filter.
func (m *managerImpl) ListAccessLogs(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error) {
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
return nil, 0, err
}
return m.store.GetAgentNetworkAccessLogs(ctx, store.LockingStrengthNone, accountID, filter)
}
// GetUsageOverview returns the filtered usage rows aggregated into time buckets
// at the requested granularity, oldest-first.
func (m *managerImpl) GetUsageOverview(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter, granularity types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error) {
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
return nil, err
}
rows, err := m.store.GetAgentNetworkUsageRows(ctx, store.LockingStrengthNone, accountID, filter)
if err != nil {
return nil, err
}
return types.AggregateUsageByGranularity(rows, granularity), nil
}
// StartAccessLogCleanup launches a background sweep that periodically deletes
// each account's agent-network access-log rows older than that account's
// AccessLogRetentionDays. Usage records are never swept. A non-positive
// interval defaults to 24h.
func (m *managerImpl) StartAccessLogCleanup(ctx context.Context, cleanupIntervalHours int) {
if cleanupIntervalHours <= 0 {
cleanupIntervalHours = 24
}
interval := time.Duration(cleanupIntervalHours) * time.Hour
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
m.cleanupAccessLogsOnce(ctx) // run once on startup
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
m.cleanupAccessLogsOnce(ctx)
}
}
}()
}
// cleanupAccessLogsOnce sweeps every account's expired access-log rows against
// its configured retention. Best-effort: a per-account failure is logged and
// the sweep continues.
func (m *managerImpl) cleanupAccessLogsOnce(ctx context.Context) {
settings, err := m.store.GetAllAgentNetworkSettings(ctx, store.LockingStrengthNone)
if err != nil {
log.WithContext(ctx).Errorf("agent-network access-log cleanup: list settings: %v", err)
return
}
for _, s := range settings {
if s.AccessLogRetentionDays <= 0 {
continue // keep indefinitely
}
cutoff := time.Now().UTC().AddDate(0, 0, -s.AccessLogRetentionDays)
deleted, err := m.store.DeleteOldAgentNetworkAccessLogs(ctx, s.AccountID, cutoff)
if err != nil {
log.WithContext(ctx).Warnf("agent-network access-log cleanup for account %s: %v", s.AccountID, err)
continue
}
if deleted > 0 {
log.WithContext(ctx).Infof("agent-network access-log cleanup: deleted %d rows for account %s (retention %d days)", deleted, s.AccountID, s.AccessLogRetentionDays)
}
}
}
// RecordConsumption increments the (dim, window) counter by the
// supplied deltas. The window_start is computed from time.Now under
// the supplied window_seconds so callers don't have to pre-align —
// the proxy's post-flight path simply hands us tokens + cost and
// which dimension we're booking against.
func (m *managerImpl) RecordConsumption(ctx context.Context, accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds, tokensIn, tokensOut int64, costUSD float64) error {
if accountID == "" || dimID == "" || windowSeconds <= 0 {
return status.Errorf(status.InvalidArgument, "account_id, dim_id and window_seconds must be set")
}
windowStart := types.WindowStart(time.Now(), windowSeconds)
return m.store.IncrementAgentNetworkConsumption(ctx, accountID, kind, dimID, windowSeconds, windowStart, tokensIn, tokensOut, costUSD)
}
func (m *managerImpl) requirePermission(ctx context.Context, accountID, userID string, op operations.Operation) error {
ok, _, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.AgentNetwork, op)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
return nil
}
type mockManager struct{}
// NewManagerMock returns a no-op manager useful for tests.
func NewManagerMock() Manager {
return &mockManager{}
}
func (*mockManager) GetAllProviders(_ context.Context, _, _ string) ([]*types.Provider, error) {
return []*types.Provider{}, nil
}
func (*mockManager) GetProvider(_ context.Context, _, _, _ string) (*types.Provider, error) {
return &types.Provider{}, nil
}
func (*mockManager) CreateProvider(_ context.Context, _ string, p *types.Provider, _ string) (*types.Provider, error) {
return p, nil
}
func (*mockManager) UpdateProvider(_ context.Context, _ string, p *types.Provider) (*types.Provider, error) {
return p, nil
}
func (*mockManager) DeleteProvider(_ context.Context, _, _, _ string) error { return nil }
func (*mockManager) GetAllPolicies(_ context.Context, _, _ string) ([]*types.Policy, error) {
return []*types.Policy{}, nil
}
func (*mockManager) GetPolicy(_ context.Context, _, _, _ string) (*types.Policy, error) {
return &types.Policy{}, nil
}
func (*mockManager) CreatePolicy(_ context.Context, _ string, p *types.Policy) (*types.Policy, error) {
return p, nil
}
func (*mockManager) UpdatePolicy(_ context.Context, _ string, p *types.Policy) (*types.Policy, error) {
return p, nil
}
func (*mockManager) DeletePolicy(_ context.Context, _, _, _ string) error { return nil }
func (*mockManager) GetAllGuardrails(_ context.Context, _, _ string) ([]*types.Guardrail, error) {
return []*types.Guardrail{}, nil
}
func (*mockManager) GetGuardrail(_ context.Context, _, _, _ string) (*types.Guardrail, error) {
return &types.Guardrail{}, nil
}
func (*mockManager) CreateGuardrail(_ context.Context, _ string, g *types.Guardrail) (*types.Guardrail, error) {
return g, nil
}
func (*mockManager) UpdateGuardrail(_ context.Context, _ string, g *types.Guardrail) (*types.Guardrail, error) {
return g, nil
}
func (*mockManager) DeleteGuardrail(_ context.Context, _, _, _ string) error { return nil }
func (*mockManager) GetAllBudgetRules(_ context.Context, _, _ string) ([]*types.AccountBudgetRule, error) {
return []*types.AccountBudgetRule{}, nil
}
func (*mockManager) GetBudgetRule(_ context.Context, _, _, _ string) (*types.AccountBudgetRule, error) {
return &types.AccountBudgetRule{}, nil
}
func (*mockManager) CreateBudgetRule(_ context.Context, _ string, r *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
return r, nil
}
func (*mockManager) UpdateBudgetRule(_ context.Context, _ string, r *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
return r, nil
}
func (*mockManager) DeleteBudgetRule(_ context.Context, _, _, _ string) error { return nil }
func (*mockManager) GetSettings(_ context.Context, _, _ string) (*types.Settings, error) {
return nil, status.Errorf(status.NotFound, "agent network settings not found")
}
func (*mockManager) UpdateSettings(_ context.Context, _ string, s *types.Settings) (*types.Settings, error) {
return s, nil
}
func (*mockManager) ListConsumption(_ context.Context, _, _ string) ([]*types.Consumption, error) {
return nil, nil
}
func (*mockManager) ListAccessLogs(_ context.Context, _, _ string, _ types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error) {
return nil, 0, nil
}
func (*mockManager) GetUsageOverview(_ context.Context, _, _ string, _ types.AgentNetworkAccessLogFilter, _ types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error) {
return nil, nil
}
func (*mockManager) StartAccessLogCleanup(_ context.Context, _ int) {}
func (*mockManager) RecordConsumption(_ context.Context, _ string, _ types.ConsumptionDimension, _ string, _, _, _ int64, _ float64) error {
return nil
}
func (*mockManager) RecordAccountBudgetUsage(_ context.Context, _, _ string, _ []string, _, _ int64, _ float64) error {
return nil
}
func (*mockManager) RecordUsage(_ context.Context, _ RecordUsageInput) error {
return nil
}

View File

@@ -0,0 +1,660 @@
package agentnetwork
import (
"context"
"fmt"
"math"
"sort"
"time"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
// validateUsageDeltas rejects negative or non-finite usage counters before they
// reach the consumption store, so a bad delta can't decrement or poison totals.
// The store batch method enforces the same invariant; this is the manager-level
// guard so direct callers fail fast with a clear error.
func validateUsageDeltas(tokensIn, tokensOut int64, costUSD float64) error {
if tokensIn < 0 || tokensOut < 0 || costUSD < 0 || math.IsNaN(costUSD) || math.IsInf(costUSD, 0) {
return status.Errorf(status.InvalidArgument, "usage deltas must be non-negative and finite")
}
return nil
}
// Deny codes the proxy surfaces back to the caller when every
// applicable policy is exhausted. The proxy converts these into
// upstream-shaped error responses.
const (
//nolint:gosec // policy deny code label, not a credential
denyCodeTokenCapExceeded = "llm_policy.token_cap_exceeded"
//nolint:gosec // policy deny code label, not a credential
denyCodeBudgetCapExceeded = "llm_policy.budget_cap_exceeded"
//nolint:gosec // account deny code label, not a credential
denyCodeAccountTokenCapExceeded = "llm_account.token_cap_exceeded"
//nolint:gosec // account deny code label, not a credential
denyCodeAccountBudgetCapExceeded = "llm_account.budget_cap_exceeded"
)
// consumptionCache holds the consumption counters prefetched for one
// policy-selection request, keyed by ConsumptionKey. A miss returns a zero
// counter — the same contract the store's single-row getter uses for absent
// rows — so the eval logic is identical whether a counter exists yet or not.
type consumptionCache map[types.ConsumptionKey]*types.Consumption
func (c consumptionCache) get(accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds int64, windowStart time.Time) *types.Consumption {
key := types.ConsumptionKey{Kind: kind, DimID: dimID, WindowSeconds: windowSeconds, WindowStartUTC: windowStart.UTC()}
if row, ok := c[key]; ok && row != nil {
return row
}
return &types.Consumption{
AccountID: accountID,
DimensionKind: kind,
DimensionID: dimID,
WindowSeconds: windowSeconds,
WindowStartUTC: windowStart.UTC(),
}
}
// addLimitKeys records the user/group consumption keys a single enabled (token
// or budget) limit window reads for the given attribution group, into a dedup
// set. attrGroup may be empty (no group dimension applies).
func addLimitKeys(set map[types.ConsumptionKey]struct{}, userID, attrGroup string, windowSeconds int64, now time.Time) {
if windowSeconds <= 0 {
return
}
ws := types.WindowStart(now, windowSeconds)
if userID != "" {
set[types.ConsumptionKey{Kind: types.DimensionUser, DimID: userID, WindowSeconds: windowSeconds, WindowStartUTC: ws}] = struct{}{}
}
if attrGroup != "" {
set[types.ConsumptionKey{Kind: types.DimensionGroup, DimID: attrGroup, WindowSeconds: windowSeconds, WindowStartUTC: ws}] = struct{}{}
}
}
// prefetchConsumption loads, in one store round-trip, every consumption counter
// that the account-budget ceiling and the candidate policies will read while
// scoring this request. This replaces the per-cap point reads the selector
// previously issued one at a time (the N+1 on the hot path).
func (m *managerImpl) prefetchConsumption(ctx context.Context, in PolicySelectionInput, rules []*types.AccountBudgetRule, candidates []*types.Policy, now time.Time) (consumptionCache, error) {
set := make(map[types.ConsumptionKey]struct{})
for _, p := range candidates {
attr := lowestIntersect(p.SourceGroups, in.GroupIDs)
if p.Limits.TokenLimit.Enabled {
addLimitKeys(set, in.UserID, attr, p.Limits.TokenLimit.WindowSeconds, now)
}
if p.Limits.BudgetLimit.Enabled {
addLimitKeys(set, in.UserID, attr, p.Limits.BudgetLimit.WindowSeconds, now)
}
}
for _, r := range rules {
if r == nil || !r.Enabled || !budgetRuleApplies(r, in) {
continue
}
attr := lowestIntersect(r.TargetGroups, in.GroupIDs)
if r.Limits.TokenLimit.Enabled {
addLimitKeys(set, in.UserID, attr, r.Limits.TokenLimit.WindowSeconds, now)
}
if r.Limits.BudgetLimit.Enabled {
addLimitKeys(set, in.UserID, attr, r.Limits.BudgetLimit.WindowSeconds, now)
}
}
if len(set) == 0 {
return consumptionCache{}, nil
}
keys := make([]types.ConsumptionKey, 0, len(set))
for k := range set {
keys = append(keys, k)
}
rows, err := m.store.GetAgentNetworkConsumptionBatch(ctx, store.LockingStrengthNone, in.AccountID, keys)
if err != nil {
return nil, fmt.Errorf("batch read consumption: %w", err)
}
return consumptionCache(rows), nil
}
// SelectPolicyForRequest picks the policy that "pays" for the
// incoming request. The chosen policy is the one with the largest
// pool that still has headroom — drain the bigger bucket first,
// fall through to the next-biggest only when the current one's
// group cap or shared per-user cap is exhausted. This matches
// operator intuition for layered tiers ("privileged group has the
// 10k budget, regular group has 1k as the safety net") and avoids
// the load-balancer flapping that fraction-based scoring produces
// once any cap has been touched.
//
// Ordering across non-exhausted candidates:
// 1. Policies with NO enabled caps (catch-all-allow) win over any
// capped policy — operators who configure unlimited access
// expect requests to attribute there until they explicitly add
// caps.
// 2. Larger group token cap wins.
// 3. Larger group budget USD cap wins.
// 4. Larger user token cap wins.
// 5. Larger user budget USD cap wins.
// 6. Older created_at wins (deterministic final tiebreak so
// multi-node selection converges).
//
// Returns Allow=true with empty SelectedPolicyID when no policy in
// the account targets the (provider, caller-groups) combination —
// llm_router is the gate that owns "no policy authorises this
// request" semantics; this function trusts that authorisation has
// already happened upstream and only does the limit-aware
// attribution.
func (m *managerImpl) SelectPolicyForRequest(ctx context.Context, in PolicySelectionInput) (*PolicySelectionResult, error) {
if in.AccountID == "" {
return nil, status.Errorf(status.InvalidArgument, "account_id is required")
}
now := time.Now().UTC()
rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, in.AccountID)
if err != nil {
return nil, fmt.Errorf("list account budget rules: %w", err)
}
policies, err := m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, in.AccountID)
if err != nil {
return nil, fmt.Errorf("list account policies: %w", err)
}
candidates := filterApplicablePolicies(policies, in)
// Prefetch every consumption counter the ceiling + candidate policies will
// read, in a single store round-trip, then score against the cache.
cache, err := m.prefetchConsumption(ctx, in, rules, candidates, now)
if err != nil {
return nil, err
}
// Account-level budget rules are an always-on ceiling, evaluated
// independently of policy selection (they bind even for catch-all-allow
// policies or requests that match no policy). All applicable rules must
// pass — this is where min-wins lives.
if deny, code, reason := checkAccountBudget(in, rules, cache, now); deny {
return &PolicySelectionResult{Allow: false, DenyCode: code, DenyReason: reason}, nil
}
if len(candidates) == 0 {
return &PolicySelectionResult{Allow: true}, nil
}
scored, lastDenyCode, lastDenyReason := scoreCandidates(in, candidates, cache, now)
if len(scored) == 0 {
return &PolicySelectionResult{
Allow: false,
DenyCode: lastDenyCode,
DenyReason: lastDenyReason,
}, nil
}
sort.SliceStable(scored, func(i, j int) bool {
// Catch-all-allow (no caps configured) wins outright over
// any capped policy.
iNoCap := isUncapped(scored[i].policy)
jNoCap := isUncapped(scored[j].policy)
if iNoCap != jNoCap {
return iNoCap
}
// Bigger pool drains first. Group caps dominate (shared
// across the group) before individual caps.
if a, b := groupCapTokens(scored[i].policy), groupCapTokens(scored[j].policy); a != b {
return a > b
}
if a, b := groupCapBudgetUsd(scored[i].policy), groupCapBudgetUsd(scored[j].policy); a != b {
return a > b
}
if a, b := userCapTokens(scored[i].policy), userCapTokens(scored[j].policy); a != b {
return a > b
}
if a, b := userCapBudgetUsd(scored[i].policy), userCapBudgetUsd(scored[j].policy); a != b {
return a > b
}
return scored[i].policy.CreatedAt.Before(scored[j].policy.CreatedAt)
})
winner := scored[0]
return &PolicySelectionResult{
Allow: true,
SelectedPolicyID: winner.policy.ID,
AttributionGroupID: winner.attributionGroup,
WindowSeconds: winner.windowSeconds,
}, nil
}
// filterApplicablePolicies returns the enabled policies that target
// the requested provider and have at least one of the caller's groups
// in their source_groups. Caller's group set is matched
// case-sensitively against policy.SourceGroups.
func filterApplicablePolicies(policies []*types.Policy, in PolicySelectionInput) []*types.Policy {
if len(policies) == 0 {
return nil
}
groupSet := make(map[string]struct{}, len(in.GroupIDs))
for _, g := range in.GroupIDs {
if g != "" {
groupSet[g] = struct{}{}
}
}
out := make([]*types.Policy, 0, len(policies))
for _, p := range policies {
if p == nil || !p.Enabled {
continue
}
if !sliceContains(p.DestinationProviderIDs, in.ProviderID) {
continue
}
if !anyGroupMatches(p.SourceGroups, groupSet) {
continue
}
out = append(out, p)
}
return out
}
// candidate is the per-policy intermediate the selector ranks. A
// policy that's been exhausted on any enabled cap never makes it
// into this slice; the selector's deny envelope carries the latest
// exhaustion's reason out separately.
type candidate struct {
policy *types.Policy
attributionGroup string
windowSeconds int64
}
// scoreCandidates evaluates every applicable policy against the
// caller's current consumption. Exhausted policies are filtered out
// of the returned slice; the most recent exhaustion's deny code +
// human reason is returned alongside so the caller can surface it
// when no candidate survives.
func scoreCandidates(
in PolicySelectionInput,
candidates []*types.Policy,
cache consumptionCache,
now time.Time,
) ([]candidate, string, string) {
out := make([]candidate, 0, len(candidates))
var lastDenyCode, lastDenyReason string
for _, p := range candidates {
c, exhausted, denyCode, denyReason := scoreOne(in, p, cache, now)
if exhausted {
lastDenyCode = denyCode
lastDenyReason = denyReason
continue
}
out = append(out, c)
}
return out, lastDenyCode, lastDenyReason
}
// scoreOne checks a single policy for cap exhaustion. Returns the
// candidate envelope when the policy still has headroom on every
// enabled cap; reports exhausted=true with a deny code naming the
// offending cap kind otherwise.
func scoreOne(
in PolicySelectionInput,
p *types.Policy,
cache consumptionCache,
now time.Time,
) (candidate, bool, string, string) {
attrGroup := lowestIntersect(p.SourceGroups, in.GroupIDs)
c := candidate{
policy: p,
attributionGroup: attrGroup,
windowSeconds: effectiveWindowSeconds(p),
}
if p.Limits.TokenLimit.Enabled && p.Limits.TokenLimit.WindowSeconds > 0 {
if exhausted, reason := evalTokenCap(cache, in.AccountID, in.UserID, attrGroup, p.Limits.TokenLimit, now, "policy "+p.ID); exhausted {
return candidate{}, true, denyCodeTokenCapExceeded, reason
}
}
if p.Limits.BudgetLimit.Enabled && p.Limits.BudgetLimit.WindowSeconds > 0 {
if exhausted, reason := evalBudgetCap(cache, in.AccountID, in.UserID, attrGroup, p.Limits.BudgetLimit, now, "policy "+p.ID); exhausted {
return candidate{}, true, denyCodeBudgetCapExceeded, reason
}
}
return c, false, "", ""
}
// evalTokenCap reports whether the token limit is already exhausted for the
// caller in its own window. attrGroup may be empty (no group dimension applies).
// label identifies the cap source ("policy <id>" or "account rule <id>") for the
// deny reason. It is the shared primitive behind both policy and account-rule
// enforcement.
func evalTokenCap(
cache consumptionCache,
accountID, userID, attrGroup string,
tl types.PolicyTokenLimit,
now time.Time,
label string,
) (bool, string) {
windowStart := types.WindowStart(now, tl.WindowSeconds)
if tl.UserCap > 0 && userID != "" {
row := cache.get(accountID, types.DimensionUser, userID, tl.WindowSeconds, windowStart)
used := row.TokensInput + row.TokensOutput
if used >= tl.UserCap {
return true, fmt.Sprintf("user token cap exhausted on %s (used %d of %d)", label, used, tl.UserCap)
}
}
if tl.GroupCap > 0 && attrGroup != "" {
row := cache.get(accountID, types.DimensionGroup, attrGroup, tl.WindowSeconds, windowStart)
used := row.TokensInput + row.TokensOutput
if used >= tl.GroupCap {
return true, fmt.Sprintf("group token cap exhausted on %s (used %d of %d)", label, used, tl.GroupCap)
}
}
return false, ""
}
// evalBudgetCap is the budget (USD) counterpart of evalTokenCap.
func evalBudgetCap(
cache consumptionCache,
accountID, userID, attrGroup string,
bl types.PolicyBudgetLimit,
now time.Time,
label string,
) (bool, string) {
windowStart := types.WindowStart(now, bl.WindowSeconds)
if bl.UserCapUsd > 0 && userID != "" {
row := cache.get(accountID, types.DimensionUser, userID, bl.WindowSeconds, windowStart)
if row.CostUSD >= bl.UserCapUsd {
return true, fmt.Sprintf("user budget cap exhausted on %s (used $%.4f of $%.4f)", label, row.CostUSD, bl.UserCapUsd)
}
}
if bl.GroupCapUsd > 0 && attrGroup != "" {
row := cache.get(accountID, types.DimensionGroup, attrGroup, bl.WindowSeconds, windowStart)
if row.CostUSD >= bl.GroupCapUsd {
return true, fmt.Sprintf("group budget cap exhausted on %s (used $%.4f of $%.4f)", label, row.CostUSD, bl.GroupCapUsd)
}
}
return false, ""
}
// checkAccountBudget evaluates every applicable account-level budget rule as an
// all-must-pass ceiling. A rule applies when the caller is in its TargetUsers,
// one of its TargetGroups, or it has no targets at all (account-wide). Returns
// deny=true with an llm_account.* code on the first exhausted rule. Group caps
// attribute to the lowest intersecting group (the same model policies use), so
// multi-group behavior is unchanged.
func checkAccountBudget(in PolicySelectionInput, rules []*types.AccountBudgetRule, cache consumptionCache, now time.Time) (bool, string, string) {
for _, r := range rules {
if r == nil || !r.Enabled || !budgetRuleApplies(r, in) {
continue
}
attrGroup := lowestIntersect(r.TargetGroups, in.GroupIDs)
label := "account rule " + r.ID
if r.Limits.TokenLimit.Enabled && r.Limits.TokenLimit.WindowSeconds > 0 {
if exhausted, reason := evalTokenCap(cache, in.AccountID, in.UserID, attrGroup, r.Limits.TokenLimit, now, label); exhausted {
return true, denyCodeAccountTokenCapExceeded, reason
}
}
if r.Limits.BudgetLimit.Enabled && r.Limits.BudgetLimit.WindowSeconds > 0 {
if exhausted, reason := evalBudgetCap(cache, in.AccountID, in.UserID, attrGroup, r.Limits.BudgetLimit, now, label); exhausted {
return true, denyCodeAccountBudgetCapExceeded, reason
}
}
}
return false, "", ""
}
// budgetRuleApplies reports whether an account budget rule binds the caller:
// a direct user match, a group intersection, or an untargeted (account-wide)
// rule.
func budgetRuleApplies(r *types.AccountBudgetRule, in PolicySelectionInput) bool {
if len(r.TargetUsers) == 0 && len(r.TargetGroups) == 0 {
return true
}
if in.UserID != "" && sliceContains(r.TargetUsers, in.UserID) {
return true
}
groupSet := make(map[string]struct{}, len(in.GroupIDs))
for _, g := range in.GroupIDs {
if g != "" {
groupSet[g] = struct{}{}
}
}
return anyGroupMatches(r.TargetGroups, groupSet)
}
// RecordAccountBudgetUsage fans the served request's usage out to every
// applicable account budget rule's own (dimension, window) counter. The user
// dimension is always booked when a rule has a user-applicable cap; the group
// dimension books against the rule's lowest intersecting group. This runs
// alongside the policy-window record so account ceilings accumulate in their own
// windows (commonly monthly) independently of the per-policy window.
func (m *managerImpl) RecordAccountBudgetUsage(ctx context.Context, accountID, userID string, groupIDs []string, tokensIn, tokensOut int64, costUSD float64) error {
if accountID == "" {
return status.Errorf(status.InvalidArgument, "account_id is required")
}
if err := validateUsageDeltas(tokensIn, tokensOut, costUSD); err != nil {
return err
}
rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("list account budget rules: %w", err)
}
set := make(map[types.ConsumptionKey]struct{})
addAccountBudgetKeys(set, PolicySelectionInput{AccountID: accountID, UserID: userID, GroupIDs: groupIDs}, rules, time.Now().UTC())
if len(set) == 0 {
return nil
}
return m.store.IncrementAgentNetworkConsumptionBatch(ctx, accountID, keysSlice(set), tokensIn, tokensOut, costUSD)
}
// RecordUsageInput carries everything RecordUsage books for one served request.
type RecordUsageInput struct {
AccountID string
UserID string
AttributionGroupID string // selected policy's attribution group (policy window)
GroupIDs []string
WindowSeconds int64 // selected policy's window; 0 means no policy cap
TokensIn int64
TokensOut int64
CostUSD float64
}
// RecordUsage books a served request's usage against every counter it touches —
// the selected policy's per-(user, group) window plus every applicable account
// budget rule's own window — deduplicated and written in a single transaction.
// Two counters that collapse to the same (dimension, window) tuple are booked
// once, so a single request can never double-count against one cap.
func (m *managerImpl) RecordUsage(ctx context.Context, in RecordUsageInput) error {
if in.AccountID == "" {
return status.Errorf(status.InvalidArgument, "account_id is required")
}
if err := validateUsageDeltas(in.TokensIn, in.TokensOut, in.CostUSD); err != nil {
return err
}
now := time.Now().UTC()
set := make(map[types.ConsumptionKey]struct{})
// Policy-window dimensions are booked only when a policy cap bound this
// request (window > 0). A zero window means catch-all-allow / no policy cap;
// the account fan-out below still books against the budget rules' windows.
if in.WindowSeconds > 0 {
addLimitKeys(set, in.UserID, in.AttributionGroupID, in.WindowSeconds, now)
}
rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, in.AccountID)
if err != nil {
return fmt.Errorf("list account budget rules: %w", err)
}
addAccountBudgetKeys(set, PolicySelectionInput{AccountID: in.AccountID, UserID: in.UserID, GroupIDs: in.GroupIDs}, rules, now)
if len(set) == 0 {
return nil
}
return m.store.IncrementAgentNetworkConsumptionBatch(ctx, in.AccountID, keysSlice(set), in.TokensIn, in.TokensOut, in.CostUSD)
}
// addAccountBudgetKeys adds the (dimension, window) keys a served request books
// against every applicable account budget rule into the dedup set.
func addAccountBudgetKeys(set map[types.ConsumptionKey]struct{}, in PolicySelectionInput, rules []*types.AccountBudgetRule, now time.Time) {
for _, r := range rules {
if r == nil || !r.Enabled || !budgetRuleApplies(r, in) {
continue
}
attrGroup := lowestIntersect(r.TargetGroups, in.GroupIDs)
for _, window := range ruleWindows(r) {
addLimitKeys(set, in.UserID, attrGroup, window, now)
}
}
}
// keysSlice flattens a ConsumptionKey set into a slice.
func keysSlice(set map[types.ConsumptionKey]struct{}) []types.ConsumptionKey {
keys := make([]types.ConsumptionKey, 0, len(set))
for k := range set {
keys = append(keys, k)
}
return keys
}
// ruleWindows returns the distinct enabled window lengths a budget rule books
// against (token window and/or budget window, deduplicated).
func ruleWindows(r *types.AccountBudgetRule) []int64 {
var windows []int64
if r.Limits.TokenLimit.Enabled && r.Limits.TokenLimit.WindowSeconds > 0 {
windows = append(windows, r.Limits.TokenLimit.WindowSeconds)
}
if r.Limits.BudgetLimit.Enabled && r.Limits.BudgetLimit.WindowSeconds > 0 {
bw := r.Limits.BudgetLimit.WindowSeconds
if len(windows) == 0 || windows[0] != bw {
windows = append(windows, bw)
}
}
return windows
}
// effectiveWindowSeconds returns the window length the proxy should
// hand back to RecordLLMUsage. When both halves are enabled with
// different windows, token_limit wins (the more common config); when
// only one is enabled that one wins; when neither is enabled the
// returned value is 0 — RecordLLMUsage treats 0 as "no limit
// tracking" and skips the increment, which is the right pass-through
// for catch-all-allow policies with no caps configured.
func effectiveWindowSeconds(p *types.Policy) int64 {
if p.Limits.TokenLimit.Enabled && p.Limits.TokenLimit.WindowSeconds > 0 {
return p.Limits.TokenLimit.WindowSeconds
}
if p.Limits.BudgetLimit.Enabled && p.Limits.BudgetLimit.WindowSeconds > 0 {
return p.Limits.BudgetLimit.WindowSeconds
}
return 0
}
// lowestIntersect returns the lowest-by-string-sort element of
// callerGroups ∩ sourceGroups. Empty when the intersection is empty.
// Lowest is deterministic so multi-node selection converges.
func lowestIntersect(sourceGroups, callerGroups []string) string {
if len(sourceGroups) == 0 || len(callerGroups) == 0 {
return ""
}
srcSet := make(map[string]struct{}, len(sourceGroups))
for _, g := range sourceGroups {
srcSet[g] = struct{}{}
}
var best string
for _, g := range callerGroups {
if _, ok := srcSet[g]; !ok {
continue
}
if best == "" || g < best {
best = g
}
}
return best
}
func anyGroupMatches(sourceGroups []string, callerSet map[string]struct{}) bool {
for _, g := range sourceGroups {
if _, ok := callerSet[g]; ok {
return true
}
}
return false
}
// isUncapped reports whether a policy has any enabled cap with a
// positive limit value. Mirrors the eval functions' guards: a policy
// with token_limit.enabled=true but every cap value at 0 still
// counts as uncapped because the eval would query nothing and bind
// nothing.
func isUncapped(p *types.Policy) bool {
tl := p.Limits.TokenLimit
if tl.Enabled && tl.WindowSeconds > 0 && (tl.GroupCap > 0 || tl.UserCap > 0) {
return false
}
bl := p.Limits.BudgetLimit
if bl.Enabled && bl.WindowSeconds > 0 && (bl.GroupCapUsd > 0 || bl.UserCapUsd > 0) {
return false
}
return true
}
// groupCapTokens returns the policy's group-token cap when the token
// limit is enabled, zero otherwise. Drives the primary "bigger pool
// first" sort.
func groupCapTokens(p *types.Policy) int64 {
if p.Limits.TokenLimit.Enabled {
return p.Limits.TokenLimit.GroupCap
}
return 0
}
// groupCapBudgetUsd returns the policy's group-budget cap in USD
// when the budget limit is enabled, zero otherwise. Secondary sort
// key after token group cap so budget-only policies still order
// predictably.
func groupCapBudgetUsd(p *types.Policy) float64 {
if p.Limits.BudgetLimit.Enabled {
return p.Limits.BudgetLimit.GroupCapUsd
}
return 0
}
// userCapTokens returns the policy's per-user token cap when the
// token limit is enabled, zero otherwise. Tertiary sort key, used
// when group caps tie or are absent.
func userCapTokens(p *types.Policy) int64 {
if p.Limits.TokenLimit.Enabled {
return p.Limits.TokenLimit.UserCap
}
return 0
}
// userCapBudgetUsd returns the policy's per-user budget cap in USD
// when the budget limit is enabled, zero otherwise. Quaternary sort
// key for budget-only policies whose group caps tie or are absent.
func userCapBudgetUsd(p *types.Policy) float64 {
if p.Limits.BudgetLimit.Enabled {
return p.Limits.BudgetLimit.UserCapUsd
}
return 0
}
func sliceContains(haystack []string, needle string) bool {
for _, v := range haystack {
if v == needle {
return true
}
}
return false
}
// mockManager fallback so tests that don't care about selection still
// compile.
func (*mockManager) SelectPolicyForRequest(_ context.Context, _ PolicySelectionInput) (*PolicySelectionResult, error) {
return &PolicySelectionResult{Allow: true}, nil
}

View File

@@ -0,0 +1,181 @@
package agentnetwork
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
"github.com/netbirdio/netbird/management/server/store"
)
// GC-2 no-mock enforcement tests for the account-budget ceiling. They drive the
// real store + real consumption accounting through SelectPolicyForRequest and
// RecordAccountBudgetUsage, asserting min-wins (account binds independently of
// policy), targeting (groups + direct users), and the record fan-out.
func accountWideUserTokenRule(id string, userCap, window int64) *types.AccountBudgetRule {
r := types.NewAccountBudgetRule(realSelectAccount)
r.ID = id
r.Limits.TokenLimit = types.PolicyTokenLimit{Enabled: true, UserCap: userCap, WindowSeconds: window}
return r
}
// TestSelectPolicy_RealStore_AccountCeilingBindsEvenWithUncappedPolicy proves
// min-wins: the account user ceiling denies once exhausted even though a
// catch-all-allow (uncapped) policy would otherwise pass the request. The
// account gate runs independently of and ahead of policy selection.
func TestSelectPolicy_RealStore_AccountCeilingBindsEvenWithUncappedPolicy(t *testing.T) {
mgr, s := newRealSelectorMgr(t)
ctx := context.Background()
// An uncapped (catch-all-allow) policy: enabled token limit, zero caps.
uncapped := capPolicy("pol-open", realSelectAccount, []string{"grp-eng"}, "prov-1", 0, 86_400)
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, uncapped))
// Account-wide user ceiling of 100 tokens in an hourly window.
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-1", 100, 3_600)))
in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"}
// Fresh: account ceiling has headroom, uncapped policy wins.
res, err := mgr.SelectPolicyForRequest(ctx, in)
require.NoError(t, err)
assert.True(t, res.Allow, "fresh account ceiling must allow")
// Drain the account user ceiling via the fan-out path.
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", []string{"grp-eng"}, 100, 0, 0))
res, err = mgr.SelectPolicyForRequest(ctx, in)
require.NoError(t, err)
assert.False(t, res.Allow, "account ceiling must deny even though the policy is uncapped (min-wins)")
assert.Equal(t, denyCodeAccountTokenCapExceeded, res.DenyCode, "deny must carry the llm_account.* code")
}
// TestSelectPolicy_RealStore_AccountGroupCeiling proves a group-targeted rule
// binds the caller's group dimension.
func TestSelectPolicy_RealStore_AccountGroupCeiling(t *testing.T) {
mgr, s := newRealSelectorMgr(t)
ctx := context.Background()
rule := types.NewAccountBudgetRule(realSelectAccount)
rule.ID = "ainbud-grp"
rule.TargetGroups = []string{"grp-eng"}
rule.Limits.BudgetLimit = types.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 5.0, WindowSeconds: 2_592_000}
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, rule))
in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"}
res, err := mgr.SelectPolicyForRequest(ctx, in)
require.NoError(t, err)
assert.True(t, res.Allow, "fresh group ceiling must allow")
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", []string{"grp-eng"}, 0, 0, 5.0))
res, err = mgr.SelectPolicyForRequest(ctx, in)
require.NoError(t, err)
assert.False(t, res.Allow, "group budget ceiling must deny once spent")
assert.Equal(t, denyCodeAccountBudgetCapExceeded, res.DenyCode, "account budget deny code")
}
// TestSelectPolicy_RealStore_AccountTargetUsersBindsOnlyThatUser proves a
// TargetUsers rule tightens only the named user, leaving others unbound.
func TestSelectPolicy_RealStore_AccountTargetUsersBindsOnlyThatUser(t *testing.T) {
mgr, s := newRealSelectorMgr(t)
ctx := context.Background()
rule := types.NewAccountBudgetRule(realSelectAccount)
rule.ID = "ainbud-alice"
rule.TargetUsers = []string{"alice"}
rule.Limits.TokenLimit = types.PolicyTokenLimit{Enabled: true, UserCap: 100, WindowSeconds: 3_600}
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, rule))
// Record alice's usage to the rule window.
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "alice", nil, 100, 0, 0))
aliceIn := PolicySelectionInput{AccountID: realSelectAccount, UserID: "alice", ProviderID: "prov-1"}
res, err := mgr.SelectPolicyForRequest(ctx, aliceIn)
require.NoError(t, err)
assert.False(t, res.Allow, "alice is bound by the TargetUsers rule and is exhausted")
bobIn := PolicySelectionInput{AccountID: realSelectAccount, UserID: "bob", ProviderID: "prov-1"}
res, err = mgr.SelectPolicyForRequest(ctx, bobIn)
require.NoError(t, err)
assert.True(t, res.Allow, "bob is not in TargetUsers, so the rule must not bind him")
}
// TestSelectPolicy_RealStore_AccountRuleRecordsToOwnWindow proves the record
// fan-out books usage in the rule's own window (distinct from any policy
// window), so the account ceiling accumulates independently.
func TestSelectPolicy_RealStore_AccountRuleRecordsToOwnWindow(t *testing.T) {
mgr, s := newRealSelectorMgr(t)
ctx := context.Background()
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-w", 100, 3_600)))
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", nil, 60, 0, 0))
// Same user, a policy-style daily window must NOT see the account-window
// usage — windows are independent counters.
dailyRow, err := s.GetAgentNetworkConsumption(ctx, store.LockingStrengthNone, realSelectAccount, types.DimensionUser, "user-1", 86_400, types.WindowStart(time.Now().UTC(), 86_400))
require.NoError(t, err)
assert.Equal(t, int64(0), dailyRow.TokensInput+dailyRow.TokensOutput, "daily window must be untouched by the hourly account-rule record")
// A second record pushes the hourly account window to its cap → deny.
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", nil, 40, 0, 0))
res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", ProviderID: "prov-1"})
require.NoError(t, err)
assert.False(t, res.Allow, "100 tokens recorded in the rule's hourly window must exhaust the 100-token ceiling")
assert.Equal(t, denyCodeAccountTokenCapExceeded, res.DenyCode, "account token deny code")
}
// TestRecordUsage_RealStore_BooksPolicyAndAccountWindows proves the batched
// post-flight write books the selected policy's window AND every applicable
// account rule's (independent) window in a single call — the #6 batched-write
// path the proxy's RecordLLMUsage RPC now uses.
func TestRecordUsage_RealStore_BooksPolicyAndAccountWindows(t *testing.T) {
mgr, s := newRealSelectorMgr(t)
ctx := context.Background()
// Policy: 100-token group cap on a daily window. Account rule: 100-token
// user ceiling on an hourly window — an independent counter.
policy := capPolicy("pol-1", realSelectAccount, []string{"grp-eng"}, "prov-1", 100, 86_400)
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, policy))
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-1", 100, 3_600)))
in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"}
res, err := mgr.SelectPolicyForRequest(ctx, in)
require.NoError(t, err)
require.True(t, res.Allow)
require.Equal(t, "pol-1", res.SelectedPolicyID)
// One batched record books the policy window (group + user @86400) and the
// account rule window (user @3600) atomically.
require.NoError(t, mgr.RecordUsage(ctx, RecordUsageInput{
AccountID: realSelectAccount,
UserID: "user-1",
AttributionGroupID: res.AttributionGroupID,
GroupIDs: []string{"grp-eng"},
WindowSeconds: res.WindowSeconds,
TokensIn: 100,
}))
// The next selection denies — the account hourly ceiling binds first.
res, err = mgr.SelectPolicyForRequest(ctx, in)
require.NoError(t, err)
assert.False(t, res.Allow, "usage booked by RecordUsage must enforce on the next request")
// Prove BOTH windows were booked in the one call via a direct batch read.
now := time.Now().UTC()
userKey := types.ConsumptionKey{Kind: types.DimensionUser, DimID: "user-1", WindowSeconds: 3_600, WindowStartUTC: types.WindowStart(now, 3_600)}
groupKey := types.ConsumptionKey{Kind: types.DimensionGroup, DimID: "grp-eng", WindowSeconds: 86_400, WindowStartUTC: types.WindowStart(now, 86_400)}
rows, err := s.GetAgentNetworkConsumptionBatch(ctx, store.LockingStrengthNone, realSelectAccount, []types.ConsumptionKey{userKey, groupKey})
require.NoError(t, err)
require.Contains(t, rows, userKey, "account rule user/hourly window booked")
require.Contains(t, rows, groupKey, "policy group/daily window booked")
assert.Equal(t, int64(100), rows[userKey].TokensInput, "account hourly user counter")
assert.Equal(t, int64(100), rows[groupKey].TokensInput, "policy daily group counter")
}

View File

@@ -0,0 +1,214 @@
package agentnetwork
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
"github.com/netbirdio/netbird/management/server/store"
)
// This file is the no-mock regression guard for policy limit enforcement.
// policyselect_test.go pins the same behavior through a gomock store with
// explicit call-sequence expectations — brittle precisely where the upcoming
// account-budget work (GC-2) refactors the cap-eval primitive and adds an
// account-level gate. These tests drive the REAL sqlite store + REAL
// consumption accounting and assert observable behavior (allow / deny /
// selection / attribution), not which store methods get called. They must keep
// passing unchanged after GC-2 lands, which is what proves "current behavior is
// not changed."
const realSelectAccount = "acc-realselect-1"
// newRealSelectorMgr builds a managerImpl backed by a real sqlite test store.
func newRealSelectorMgr(t *testing.T) (*managerImpl, store.Store) {
t.Helper()
ctx := context.Background()
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err, "real sqlite test store must come up")
t.Cleanup(cleanup)
return &managerImpl{store: s}, s
}
// TestSelectPolicy_RealStore_NoApplicablePolicies pins the pass-through:
// nothing targets the (provider, groups) combination, so the selector allows
// without attribution or consumption tracking.
func TestSelectPolicy_RealStore_NoApplicablePolicies(t *testing.T) {
mgr, _ := newRealSelectorMgr(t)
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: realSelectAccount,
UserID: "user-1",
GroupIDs: []string{"grp-x"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.True(t, res.Allow, "no applicable policy must pass through as allow")
assert.Empty(t, res.SelectedPolicyID, "no selection when nothing applies")
}
// TestSelectPolicy_RealStore_AllowAndLowestGroupAttribution pins the v1
// attribution rule (lowest intersecting group by string sort) through the
// real store, with a fresh (zero) consumption row.
func TestSelectPolicy_RealStore_AllowAndLowestGroupAttribution(t *testing.T) {
mgr, s := newRealSelectorMgr(t)
ctx := context.Background()
p := capPolicy("pol-A", realSelectAccount, []string{"grp-zz", "grp-aa", "grp-mm"}, "prov-1", 10_000, 86_400)
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p))
res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{
AccountID: realSelectAccount,
UserID: "user-1",
GroupIDs: []string{"grp-zz", "grp-aa", "grp-mm"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.True(t, res.Allow, "fresh state under cap must allow")
assert.Equal(t, "pol-A", res.SelectedPolicyID, "only applicable policy must be selected")
assert.Equal(t, "grp-aa", res.AttributionGroupID, "lowest-by-sort intersecting group must win")
assert.Equal(t, int64(86_400), res.WindowSeconds, "selected policy's window must be returned")
}
// TestSelectPolicy_RealStore_LargerPoolWins_FallsThroughWhenExhausted pins the
// core selection behavior end to end. The two policies bind DISTINCT groups so
// they read separate counters — the only shape where fall-through actually
// yields headroom (policies on the same group share one counter, as
// policyselect_test.go notes). Larger pool wins fresh; after real consumption
// drains the larger group, selection falls through to the smaller; once both
// counters are exhausted the request is denied.
func TestSelectPolicy_RealStore_LargerPoolWins_FallsThroughWhenExhausted(t *testing.T) {
mgr, s := newRealSelectorMgr(t)
ctx := context.Background()
tight := capPolicy("pol-tight", realSelectAccount, []string{"grp-tight"}, "prov-1", 100, 86_400)
tight.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
wide := capPolicy("pol-wide", realSelectAccount, []string{"grp-wide"}, "prov-1", 10_000, 86_400)
wide.CreatedAt = time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, tight))
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, wide))
// Caller is in both groups, so both policies apply with independent counters.
in := PolicySelectionInput{
AccountID: realSelectAccount,
UserID: "user-1",
GroupIDs: []string{"grp-tight", "grp-wide"},
ProviderID: "prov-1",
}
// Fresh: larger pool wins.
res, err := mgr.SelectPolicyForRequest(ctx, in)
require.NoError(t, err)
assert.Equal(t, "pol-wide", res.SelectedPolicyID, "larger pool drains first")
// Drain only the wide group's counter to its cap.
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-wide", 86_400, 10_000, 0, 0))
// Wide exhausted, tight's separate counter is fresh → fall through to tight.
res, err = mgr.SelectPolicyForRequest(ctx, in)
require.NoError(t, err)
assert.True(t, res.Allow, "tight pool has its own untouched counter")
assert.Equal(t, "pol-tight", res.SelectedPolicyID, "selection falls through to the smaller pool once the larger is exhausted")
// Drain the tight group's counter too → both exhausted → deny.
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-tight", 86_400, 100, 0, 0))
res, err = mgr.SelectPolicyForRequest(ctx, in)
require.NoError(t, err)
assert.False(t, res.Allow, "both group counters exhausted must deny")
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode, "deny code names the offending cap kind")
}
// TestSelectPolicy_RealStore_BudgetCapDenies pins budget (USD) enforcement
// through the real store: once recorded cost reaches the cap, deny.
func TestSelectPolicy_RealStore_BudgetCapDenies(t *testing.T) {
mgr, s := newRealSelectorMgr(t)
ctx := context.Background()
p := &types.Policy{
ID: "pol-budget",
AccountID: realSelectAccount,
Enabled: true,
SourceGroups: []string{"grp-eng"},
DestinationProviderIDs: []string{"prov-1"},
Limits: types.PolicyLimits{
BudgetLimit: types.PolicyBudgetLimit{
Enabled: true,
GroupCapUsd: 5.0,
WindowSeconds: 86_400,
},
},
CreatedAt: time.Now().UTC(),
}
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p))
in := PolicySelectionInput{
AccountID: realSelectAccount,
UserID: "user-1",
GroupIDs: []string{"grp-eng"},
ProviderID: "prov-1",
}
res, err := mgr.SelectPolicyForRequest(ctx, in)
require.NoError(t, err)
assert.True(t, res.Allow, "fresh budget must allow")
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-eng", 86_400, 0, 0, 5.0))
res, err = mgr.SelectPolicyForRequest(ctx, in)
require.NoError(t, err)
assert.False(t, res.Allow, "cost at the cap must deny")
assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode, "budget deny code must be surfaced")
}
// TestSelectPolicy_RealStore_GroupCounterSharedAcrossPolicies pins that two
// policies on the same group+window read one shared consumption counter: usage
// recorded once is visible to both, so exhausting the group budget denies
// regardless of which policy would attribute.
func TestSelectPolicy_RealStore_GroupCounterSharedAcrossPolicies(t *testing.T) {
mgr, s := newRealSelectorMgr(t)
ctx := context.Background()
a := capPolicy("pol-a", realSelectAccount, []string{"grp-eng"}, "prov-1", 1_000, 86_400)
b := capPolicy("pol-b", realSelectAccount, []string{"grp-eng"}, "prov-1", 1_000, 86_400)
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, a))
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, b))
in := PolicySelectionInput{
AccountID: realSelectAccount,
UserID: "user-1",
GroupIDs: []string{"grp-eng"},
ProviderID: "prov-1",
}
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-eng", 86_400, 1_000, 0, 0))
res, err := mgr.SelectPolicyForRequest(ctx, in)
require.NoError(t, err)
assert.False(t, res.Allow, "shared group counter at cap denies both equal policies")
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode, "token deny code on the shared counter")
}
// TestSelectPolicy_RealStore_DisabledPolicyIgnored pins that a disabled policy
// is invisible to selection even when it otherwise matches.
func TestSelectPolicy_RealStore_DisabledPolicyIgnored(t *testing.T) {
mgr, s := newRealSelectorMgr(t)
ctx := context.Background()
p := capPolicy("pol-disabled", realSelectAccount, []string{"grp-eng"}, "prov-1", 10_000, 86_400)
p.Enabled = false
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p))
res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{
AccountID: realSelectAccount,
UserID: "user-1",
GroupIDs: []string{"grp-eng"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.True(t, res.Allow, "no enabled policy applies → pass-through allow")
assert.Empty(t, res.SelectedPolicyID, "disabled policy must not be selected")
}

View File

@@ -0,0 +1,641 @@
package agentnetwork
import (
"context"
"errors"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
"github.com/netbirdio/netbird/management/server/store"
nbstatus "github.com/netbirdio/netbird/shared/management/status"
)
func newSelectorMgr(t *testing.T, ctrl *gomock.Controller) (*managerImpl, *store.MockStore) {
t.Helper()
mockStore := store.NewMockStore(ctrl)
// SelectPolicyForRequest evaluates the account-budget ceiling before policy
// selection. These policy-selection tests don't exercise account rules, so
// default to "no rules" — the no-mock policyselect_realstore_test.go covers
// the account gate's behavior end to end.
mockStore.EXPECT().
GetAccountAgentNetworkBudgetRules(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil, nil).
AnyTimes()
return &managerImpl{store: mockStore}, mockStore
}
type usedKey struct {
kind types.ConsumptionDimension
dimID string
window int64
}
// expectConsumptionBatch stubs the batched consumption read to return the
// supplied per-(kind, dim, window) counters, filling each row's window start
// from the actual request keys so it always matches what the selector computed.
// Keys absent from used resolve to zero counters.
func expectConsumptionBatch(mockStore *store.MockStore, used map[usedKey]*types.Consumption) {
mockStore.EXPECT().
GetAgentNetworkConsumptionBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, _ store.LockingStrength, _ string, keys []types.ConsumptionKey) (map[types.ConsumptionKey]*types.Consumption, error) {
out := make(map[types.ConsumptionKey]*types.Consumption)
for _, k := range keys {
if row, ok := used[usedKey{k.Kind, k.DimID, k.WindowSeconds}]; ok {
rc := *row
rc.WindowStartUTC = k.WindowStartUTC
out[k] = &rc
}
}
return out, nil
}).
AnyTimes()
}
func capPolicy(id, account string, sourceGroups []string, providerID string, tokenCap int64, windowSec int64) *types.Policy {
return &types.Policy{
ID: id,
AccountID: account,
Enabled: true,
SourceGroups: sourceGroups,
DestinationProviderIDs: []string{providerID},
Limits: types.PolicyLimits{
TokenLimit: types.PolicyTokenLimit{
Enabled: true,
GroupCap: tokenCap,
WindowSeconds: windowSec,
},
},
CreatedAt: time.Now().UTC(),
}
}
// TestSelectPolicy_NoApplicablePolicies covers the pass-through path:
// llm_router authorisation is upstream of selection; when the
// selector finds no policy targeting the (provider, caller-groups)
// combination, it returns Allow with no attribution and lets the
// request continue without consumption tracking.
func TestSelectPolicy_NoApplicablePolicies(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{}, nil)
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
UserID: "user-1",
GroupIDs: []string{"grp-x"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.True(t, res.Allow, "no applicable policies = pass-through allow")
assert.Empty(t, res.SelectedPolicyID, "no selection when nothing applies")
}
// TestSelectPolicy_AllowWithLowestGroupAttribution proves the v1
// attribution rule: when the caller's groups intersect a policy's
// source_groups in multiple positions, the selector picks the lowest
// group id by string sort so multi-node selection converges.
func TestSelectPolicy_AllowWithLowestGroupAttribution(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
policy := capPolicy("pol-A", "acc-1", []string{"grp-zz", "grp-aa", "grp-mm"}, "prov-1", 10_000, 86_400)
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{policy}, nil)
// Fresh: zero consumption across the board.
expectConsumptionBatch(mockStore, nil)
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
UserID: "user-1",
GroupIDs: []string{"grp-zz", "grp-aa", "grp-mm"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.True(t, res.Allow)
assert.Equal(t, "pol-A", res.SelectedPolicyID)
assert.Equal(t, "grp-aa", res.AttributionGroupID,
"lowest-by-sort intersection wins so multi-node selection converges")
assert.Equal(t, int64(86_400), res.WindowSeconds)
}
// TestSelectPolicy_LargerPoolWinsAcrossUsageLevels proves the core
// selection rule: among multiple applicable policies with caps, the
// selector picks the one with the larger absolute pool — at every
// usage level, not just at fresh state. The smaller-pool policy is
// only reached when the larger one is exhausted. This is the
// "drain biggest first" semantic operators expect for layered
// tiers; a fraction-based score would flap between the two as
// soon as one is partially used.
func TestSelectPolicy_LargerPoolWinsAcrossUsageLevels(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
tight := capPolicy("pol-tight", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400)
tight.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
wide := capPolicy("pol-wide", "acc-1", []string{"grp-engineers"}, "prov-1", 10_000, 86_400)
wide.CreatedAt = time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{tight, wide}, nil)
// Both partially used. tight at 50/100 (50% used); wide at
// 50/10000 (0.5% used). Old fraction-based algo would pick wide
// here too — but for the wrong reason ("more relative slack").
// New algo picks wide because its initial group cap is bigger
// (10000 > 100), and that decision is stable as wide drains.
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 50},
})
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
UserID: "user-1",
GroupIDs: []string{"grp-engineers"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.Equal(t, "pol-wide", res.SelectedPolicyID,
"the policy with the bigger initial pool wins — operators expect 'drain the privileged tier first', not load-balance across tiers")
}
// TestSelectPolicy_StaysOnLargerPoolAfterPartialDrain locks the
// stickiness contract reported by operators: with two policies
// where A has a 200-token group cap and B has 150, the very first
// request goes to A AND every subsequent request continues to land
// on A until A's group cap is exhausted — at which point B becomes
// the only candidate. A fraction-based score would flap to B as
// soon as A had any consumption (B's 1.0 fraction beats A's 0.75)
// even though A still has more absolute headroom; that produced
// confusing per-policy attribution ledger entries and stranded
// A's remaining capacity behind B's exhaustion.
func TestSelectPolicy_StaysOnLargerPoolAfterPartialDrain(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
policyA := capPolicy("pol-A-200", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400)
policyB := capPolicy("pol-B-150", "acc-1", []string{"grp-engineers"}, "prov-1", 150, 86_400)
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{policyA, policyB}, nil)
// A is partially drained (50/200 used = 25% used; 75% headroom
// remaining). B is fresh (0/150). The old fraction-based score
// would pick B here (1.0 > 0.75 fraction); the new pool-size
// score sticks with A (200 > 150 absolute cap).
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 50},
})
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
UserID: "user-1",
GroupIDs: []string{"grp-engineers"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.Equal(t, "pol-A-200", res.SelectedPolicyID,
"once attribution lands on the bigger pool it must STAY there until exhausted — operators expect 'drain A then B', not 'flip to B as soon as A is touched'")
}
// TestSelectPolicy_FallsThroughToSmallerPoolWhenLargerExhausted
// proves the second half of the stickiness contract: once the
// larger-pool policy IS exhausted, the smaller one takes over.
// Without this we'd deny on requests the smaller policy is fully
// equipped to serve.
func TestSelectPolicy_FallsThroughToSmallerPoolWhenLargerExhausted(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
policyA := capPolicy("pol-A-200", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400)
// B uses a different window length so it has an INDEPENDENT counter — the
// realistic shape for fall-through. On the SAME (group, window) tuple the
// counter is shared, so A's cap of 200 being reached would also exhaust B's
// 150; independent counters are what let A exhaust while B retains headroom.
policyB := capPolicy("pol-B-150", "acc-1", []string{"grp-engineers"}, "prov-1", 150, 3_600)
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{policyA, policyB}, nil)
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 200}, // A: 200 >= 200 → exhausted
{types.DimensionGroup, "grp-engineers", 3_600}: {TokensInput: 100}, // B: 100 < 150 → headroom
})
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
UserID: "user-1",
GroupIDs: []string{"grp-engineers"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.Equal(t, "pol-B-150", res.SelectedPolicyID,
"once the bigger pool is exhausted, the smaller one must take over — denying when capacity remains would strand B's allowance")
}
// TestSelectPolicy_TiebreakByLargerGroupPool covers the user-reported
// bug: an admin in two groups (Users + Admins) where Users is bound
// by a smaller-group-cap policy (50 group, 100 user) and Admins is
// bound by a bigger-group-cap policy (100 group, 20 user) MUST get
// attributed to the Admins policy on the first request.
//
// Without this rule, the fresh-state fraction is 1.0 for both and
// the older policy wins by created_at. The first 24-token request
// then drains the shared user counter past Admins's tight 20-token
// user cap, locking Admins out of selection forever. The 100-token
// Admins group pool ends up stranded while requests pile onto the
// 50-token Users pool — the opposite of what the operator intended
// when they put the bigger pool on the privileged group.
func TestSelectPolicy_TiebreakByLargerGroupPool(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
// Policy A: Users group, smaller group pool, looser per-user cap.
policyA := &types.Policy{
ID: "pol-Users",
AccountID: "acc-1",
Enabled: true,
SourceGroups: []string{"grp-Users"},
DestinationProviderIDs: []string{"prov-1"},
Limits: types.PolicyLimits{
TokenLimit: types.PolicyTokenLimit{
Enabled: true, GroupCap: 50, UserCap: 100, WindowSeconds: 86_400,
},
},
// Older — would win the legacy created_at tiebreak.
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
}
// Policy B: Admins group, bigger group pool, tighter per-user cap.
policyB := &types.Policy{
ID: "pol-Admins",
AccountID: "acc-1",
Enabled: true,
SourceGroups: []string{"grp-Admins"},
DestinationProviderIDs: []string{"prov-1"},
Limits: types.PolicyLimits{
TokenLimit: types.PolicyTokenLimit{
Enabled: true, GroupCap: 100, UserCap: 20, WindowSeconds: 86_400,
},
},
CreatedAt: time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC),
}
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{policyA, policyB}, nil)
// Fresh state: every cap evaluation reads zero usage.
expectConsumptionBatch(mockStore, nil)
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
UserID: "user-1",
GroupIDs: []string{"grp-Users", "grp-Admins"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.Equal(t, "pol-Admins", res.SelectedPolicyID,
"the bigger group pool wins the fresh-state tiebreak — picking Users first would burn the shared user counter past Admins's tight user cap on the very first request and strand the bigger Admins pool")
assert.Equal(t, "grp-Admins", res.AttributionGroupID)
}
// TestSelectPolicy_TiebreakByCreatedAt proves the deterministic
// final tiebreak: when two applicable policies have the same
// headroom fraction AND the same group cap (so the larger-pool rule
// can't differentiate either), the older policy wins so attribution
// is stable across replays.
func TestSelectPolicy_TiebreakByCreatedAt(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
older := capPolicy("pol-old", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400)
older.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
newer := capPolicy("pol-new", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400)
newer.CreatedAt = time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{newer, older}, nil)
// Both at zero consumption → identical headroom fraction.
expectConsumptionBatch(mockStore, nil)
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
GroupIDs: []string{"grp-engineers"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.Equal(t, "pol-old", res.SelectedPolicyID,
"older policy wins on equal-headroom tiebreak so attribution is stable across replays")
}
// TestSelectPolicy_DeniesWhenAllExhausted proves the deny envelope:
// when every applicable policy has at least one cap fully exhausted,
// the selector returns Allow=false with the most-recent exhaustion's
// deny code + human reason. The proxy's middleware surfaces this as
// a 403 with the canonical llm_policy.* code.
func TestSelectPolicy_DeniesWhenAllExhausted(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
a := capPolicy("pol-a", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400)
b := capPolicy("pol-b", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400)
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{a, b}, nil)
// Shared group counter at 200: A (cap 100) and B (cap 200) both exhausted.
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 200},
})
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
GroupIDs: []string{"grp-engineers"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.False(t, res.Allow, "every applicable policy exhausted = deny")
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode)
assert.Contains(t, res.DenyReason, "token cap exhausted",
"deny reason must name the exhausted cap kind for operator debugging")
}
// TestSelectPolicy_UncappedPolicyAlwaysWinsAgainstCapped proves the
// catch-all-allow contract: a policy with NO enabled caps wins
// against any capped policy regardless of how much headroom the
// capped one has, because operators who configure unlimited access
// expect requests to attribute there until they explicitly add caps.
func TestSelectPolicy_UncappedPolicyAlwaysWinsAgainstCapped(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
uncapped := &types.Policy{
ID: "pol-uncapped",
AccountID: "acc-1",
Enabled: true,
SourceGroups: []string{"grp-engineers"},
DestinationProviderIDs: []string{"prov-1"},
// All Limits.*.Enabled = false (zero-value).
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
}
wide := capPolicy("pol-wide", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000_000, 86_400)
wide.CreatedAt = time.Date(2025, 12, 1, 0, 0, 0, 0, time.UTC) // older than uncapped
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{uncapped, wide}, nil)
// Only the wide policy reads consumption; uncapped doesn't query
// because it has no enabled caps.
expectConsumptionBatch(mockStore, nil)
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
GroupIDs: []string{"grp-engineers"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.Equal(t, "pol-uncapped", res.SelectedPolicyID,
"a no-caps policy must always win selection — that's how operators express 'unlimited access through this path'")
assert.Equal(t, int64(0), res.WindowSeconds, "no caps configured = WindowSeconds=0 so RecordLLMUsage skips counter writes")
}
// TestSelectPolicy_DisabledPolicyIgnored proves disabled policies
// don't count toward selection — even when they'd otherwise be the
// best match. Operators disable a policy to take it offline; the
// selector must respect that and route through whatever's left.
func TestSelectPolicy_DisabledPolicyIgnored(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
disabled := capPolicy("pol-disabled", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000_000, 86_400)
disabled.Enabled = false
enabled := capPolicy("pol-enabled", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400)
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{disabled, enabled}, nil)
expectConsumptionBatch(mockStore, nil)
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
GroupIDs: []string{"grp-engineers"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.Equal(t, "pol-enabled", res.SelectedPolicyID,
"disabled policies must be ignored at selection time")
}
// TestSelectPolicy_StoreErrorPropagates locks the no-fail-open
// contract: a transient store error must surface to the caller, not
// be silently treated as "no policies = allow". A false allow on the
// hot path would let a request slip past every cap.
func TestSelectPolicy_StoreErrorPropagates(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return(nil, errors.New("boom"))
_, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
})
require.Error(t, err, "store errors must surface — never fail open on the hot path")
}
// TestSelectPolicy_RejectsEmptyAccount is the input-validation guard:
// empty account_id is a programmer error and must surface as
// InvalidArgument, not as a silent zero-result lookup.
func TestSelectPolicy_RejectsEmptyAccount(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, _ := newSelectorMgr(t, ctrl)
_, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{})
require.Error(t, err)
var sErr *nbstatus.Error
require.True(t, errors.As(err, &sErr))
assert.Equal(t, nbstatus.InvalidArgument, sErr.Type())
}
// TestSelectPolicy_SharesGroupCounterAcrossPolicies locks the
// counter-keying design fork: counters are keyed on (account,
// dim_kind, dim_id, window_hours, window_start) — NOT on policy_id.
// Two policies that target the same group with the SAME window length
// share one bucket: spend booked under policy A is visible to policy
// B's headroom calculation and counts toward B's cap.
//
// This is what makes "operator's per-group enforcement" sane — caps
// describe how much a GROUP can use, not how much each policy owes.
func TestSelectPolicy_SharesGroupCounterAcrossPolicies(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
// Two policies, both targeting grp-engineers + prov-1, same 24h
// window length. Different cap sizes.
policyA := capPolicy("pol-A", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400)
policyB := capPolicy("pol-B", "acc-1", []string{"grp-engineers"}, "prov-1", 5_000, 86_400)
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{policyA, policyB}, nil)
// Both policies query the SAME consumption row — same dim_id,
// same window_hours, same window_start. The mock returns the
// same row for both calls, simulating the shared counter.
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 800},
})
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
GroupIDs: []string{"grp-engineers"},
ProviderID: "prov-1",
})
require.NoError(t, err)
// 800 used → policy A has 200 tokens left of 1000 (20% headroom);
// policy B has 4200 left of 5000 (84% headroom). B wins.
assert.Equal(t, "pol-B", res.SelectedPolicyID,
"the SAME 800 tokens count toward both policies — counters share the (group, window) key, caps differ per policy")
}
// TestSelectPolicy_AntiFallThroughOnLowestGroup locks the no-fall-
// through behaviour: when a caller is in multiple of a policy's
// source_groups and the lowest-by-sort group is exhausted, we DENY
// rather than fall through to a less-loaded sibling. Per-group caps
// are independent (each group has its own bucket), but attribution
// is one-shot — operators wanting fall-through must split into
// separate policies.
//
// This nails down semantics future contributors might "improve" into
// fall-through behaviour by accident.
func TestSelectPolicy_AntiFallThroughOnLowestGroup(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
// Policy targets two groups; caller is in both.
policy := capPolicy("pol-1", "acc-1", []string{"grp-aaa", "grp-bbb"}, "prov-1", 100, 86_400)
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{policy}, nil)
// grp-aaa is the lowest by sort → attribution picks it, and the
// prefetch only collects the attribution group's key. We exhaust
// grp-aaa (100/100); grp-bbb's counter is never requested because the
// selector attributes one-shot to the lowest group, so it can't fall
// through to a less-loaded sibling.
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
{types.DimensionGroup, "grp-aaa", 86_400}: {TokensInput: 100},
})
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
GroupIDs: []string{"grp-aaa", "grp-bbb"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.False(t, res.Allow,
"lowest-group-by-sort attribution does NOT fall through to a less-loaded sibling — operators wanting fall-through must split into separate policies")
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode)
assert.Contains(t, res.DenyReason, "pol-1",
"deny reason names the exhausted policy id so operators can grep it from the access log")
}
// TestSelectPolicy_BudgetOnlyExhaustionDenies covers the symmetric
// path to TestSelectPolicy_DeniesWhenAllExhausted but for the budget
// cap: a policy with token_limit DISABLED and budget_limit at-cap
// must deny with llm_policy.budget_cap_exceeded (not the token code).
//
// Without this, the budget evaluation path in evalBudgetCap could
// silently regress and we'd still pass DeniesWhenAllExhausted (which
// only exercises tokens).
func TestSelectPolicy_BudgetOnlyExhaustionDenies(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
policy := &types.Policy{
ID: "pol-budget",
AccountID: "acc-1",
Enabled: true,
SourceGroups: []string{"grp-engineers"},
DestinationProviderIDs: []string{"prov-1"},
Limits: types.PolicyLimits{
TokenLimit: types.PolicyTokenLimit{Enabled: false},
BudgetLimit: types.PolicyBudgetLimit{
Enabled: true,
GroupCapUsd: 10.00,
WindowSeconds: 86_400,
},
},
CreatedAt: time.Now().UTC(),
}
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{policy}, nil)
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
{types.DimensionGroup, "grp-engineers", 86_400}: {CostUSD: 10.50}, // over the $10 cap
})
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
GroupIDs: []string{"grp-engineers"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.False(t, res.Allow, "budget cap exhausted must deny independently of any token cap state")
assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode,
"deny code must be the budget code — token-only deny would silently regress the budget evaluation path")
assert.Contains(t, res.DenyReason, "budget", "deny reason names the budget cap kind for operator debugging")
}
// TestSelectPolicy_BudgetTighterThanTokenWins is the dual-cap headroom
// fork: when both Token and Budget are enabled on the same policy,
// the SMALLER remaining ratio gates the policy. A policy with
// abundant token headroom but near-zero budget headroom must deny on
// budget, not pass on tokens.
func TestSelectPolicy_BudgetTighterThanTokenWins(t *testing.T) {
ctrl := gomock.NewController(t)
mgr, mockStore := newSelectorMgr(t, ctrl)
policy := &types.Policy{
ID: "pol-dual",
AccountID: "acc-1",
Enabled: true,
SourceGroups: []string{"grp-engineers"},
DestinationProviderIDs: []string{"prov-1"},
Limits: types.PolicyLimits{
TokenLimit: types.PolicyTokenLimit{Enabled: true, GroupCap: 10_000_000, WindowSeconds: 86_400},
BudgetLimit: types.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 1.00, WindowSeconds: 86_400},
},
CreatedAt: time.Now().UTC(),
}
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
Return([]*types.Policy{policy}, nil)
// One shared counter carries both token usage (ample headroom) and cost
// (at the $1 budget cap); the tighter budget cap gates the policy.
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 100, CostUSD: 1.00},
})
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
AccountID: "acc-1",
GroupIDs: []string{"grp-engineers"},
ProviderID: "prov-1",
})
require.NoError(t, err)
assert.False(t, res.Allow,
"the tighter of (token, budget) wins — abundant token headroom must NOT mask an exhausted budget")
assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode)
}

View File

@@ -0,0 +1,131 @@
package agentnetwork
import (
"context"
log "github.com/sirupsen/logrus"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// reconcile recomputes the synthesised reverse-proxy services for an
// account, diffs them against the previously-synthesised set in the
// in-memory cache, and emits Create / Update / Delete proxy mappings
// to the affected clusters. Also triggers a peer-side network-map
// recompute via accountManager.UpdateAccountPeers so the
// private-service ACL injection picks up the new state immediately.
//
// Reconcile failures are logged and swallowed — the underlying CRUD
// has already completed, and the next mutation (or proxy reconnect)
// will re-converge the cluster's view.
func (m *managerImpl) reconcile(ctx context.Context, accountID string) {
if accountID == "" {
return
}
defer func() {
if m.accountManager != nil {
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{
Resource: types.UpdateResourceService,
Operation: types.UpdateOperationUpdate,
})
}
}()
if m.proxyController == nil {
return
}
services, err := SynthesizeServices(ctx, m.store, accountID)
if err != nil {
log.WithContext(ctx).WithError(err).Warnf("agent-network reconcile: synthesise services for account %s", accountID)
return
}
oidcCfg := m.proxyController.GetOIDCValidationConfig()
current := make(map[string]*proto.ProxyMapping, len(services))
for _, svc := range services {
if svc == nil || svc.ID == "" {
continue
}
current[svc.ID] = svc.ToProtoMapping(rpservice.Update, "", oidcCfg)
}
m.reconcileMu.Lock()
previous := m.reconcileCache[accountID]
if previous == nil {
previous = make(map[string]*proto.ProxyMapping)
}
creates, updates, deletes := diffMappings(previous, current)
if len(current) == 0 {
delete(m.reconcileCache, accountID)
} else {
m.reconcileCache[accountID] = current
}
m.reconcileMu.Unlock()
for _, mapping := range creates {
mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping))
}
for _, mapping := range updates {
mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping))
}
for _, mapping := range deletes {
mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping))
}
}
// diffMappings classifies the previous→current transition for a
// single account into Create / Update / Delete sets.
//
// Cluster moves (current.cluster != previous.cluster) are surfaced as
// a Delete on the old cluster + Create on the new — handled by
// emitting both a delete (on previous mapping) and a create (on the
// current mapping) for that service ID.
func diffMappings(previous, current map[string]*proto.ProxyMapping) (creates, updates, deletes []*proto.ProxyMapping) {
for id, cur := range current {
prev, existed := previous[id]
switch {
case !existed:
creates = append(creates, cur)
case prev.GetDomain() == "" || cur.GetAccountId() == prev.GetAccountId() && currentClusterChanged(prev, cur):
deletes = append(deletes, prev)
creates = append(creates, cur)
default:
updates = append(updates, cur)
}
}
for id, prev := range previous {
if _, stillThere := current[id]; !stillThere {
deletes = append(deletes, prev)
}
}
return creates, updates, deletes
}
func currentClusterChanged(prev, cur *proto.ProxyMapping) bool {
return clusterFromMapping(prev) != clusterFromMapping(cur)
}
// clusterFromMapping returns the cluster the mapping should be sent
// to. ProxyMapping doesn't carry the cluster directly, so we rely on
// the synthesised service's domain (`<slug>.<cluster>`) and split on
// the first '.'.
func clusterFromMapping(m *proto.ProxyMapping) string {
if m == nil {
return ""
}
domain := m.GetDomain()
for i := 0; i < len(domain); i++ {
if domain[i] == '.' {
return domain[i+1:]
}
}
return ""
}

View File

@@ -0,0 +1,232 @@
package agentnetwork
import (
"context"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/proto"
)
func newReconcileMgr(t *testing.T, ctrl *gomock.Controller) (*managerImpl, *store.MockStore, *proxy.MockController) {
t.Helper()
mockStore := store.NewMockStore(ctrl)
mockProxy := proxy.NewMockController(ctrl)
return &managerImpl{
store: mockStore,
proxyController: mockProxy,
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
}, mockStore, mockProxy
}
func newReconcileTestProvider() *types.Provider {
return &types.Provider{
ID: "prov-1",
AccountID: "acct-1",
ProviderID: "openai_api",
Name: "OpenAI",
UpstreamURL: "https://api.openai.com",
APIKey: "sk-test-key",
Enabled: true,
SessionPrivateKey: "test-priv-key",
SessionPublicKey: "test-pub-key",
}
}
func newReconcileTestPolicy(providerID, sourceGroupID string) *types.Policy {
return &types.Policy{
ID: "pol-1",
AccountID: "acct-1",
Name: "engineers",
Enabled: true,
SourceGroups: []string{sourceGroupID},
DestinationProviderIDs: []string{providerID},
}
}
func newReconcileTestSettings() *types.Settings {
return &types.Settings{
AccountID: "acct-1",
Cluster: "eu.proxy.netbird.io",
Subdomain: "violet",
}
}
func expectReconcileSynthInputs(mockStore *store.MockStore, ctx context.Context, providers []*types.Provider, policies []*types.Policy, guardrails []*types.Guardrail) {
mockStore.EXPECT().
GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").
Return(newReconcileTestSettings(), nil)
mockStore.EXPECT().
GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").
Return(providers, nil)
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").
Return(policies, nil)
mockStore.EXPECT().
GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1").
Return(guardrails, nil)
}
func TestReconcile_FirstSynth_EmitsCreate(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl)
provider := newReconcileTestProvider()
policy := newReconcileTestPolicy(provider.ID, "grp-eng")
expectReconcileSynthInputs(mockStore, ctx, []*types.Provider{provider}, []*types.Policy{policy}, []*types.Guardrail{})
mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{})
var sentMappings []*proto.ProxyMapping
mockProxy.EXPECT().
SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), "eu.proxy.netbird.io").
Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) {
sentMappings = append(sentMappings, m)
})
mgr.reconcile(ctx, "acct-1")
require.Len(t, sentMappings, 1, "first synth must emit one mapping")
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, sentMappings[0].Type, "first synth is a Create")
assert.Equal(t, "agent-net-svc-acct-1", sentMappings[0].Id, "stable account-scoped virtual service id")
assert.Equal(t, "violet.eu.proxy.netbird.io", sentMappings[0].Domain, "domain comes from settings (subdomain.cluster)")
mgr.reconcileMu.Lock()
cached := mgr.reconcileCache["acct-1"]
mgr.reconcileMu.Unlock()
require.Len(t, cached, 1, "cache must hold the synth result for next diff")
}
func TestReconcile_NoChange_EmitsNothingExtra(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl)
provider := newReconcileTestProvider()
policy := newReconcileTestPolicy(provider.ID, "grp-eng")
// Two identical synth runs.
mockStore.EXPECT().
GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").
Return(newReconcileTestSettings(), nil).Times(2)
mockStore.EXPECT().
GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").
Return([]*types.Provider{provider}, nil).Times(2)
mockStore.EXPECT().
GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").
Return([]*types.Policy{policy}, nil).Times(2)
mockStore.EXPECT().
GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1").
Return([]*types.Guardrail{}, nil).Times(2)
mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).Times(2)
createCalls := 0
updateCalls := 0
mockProxy.EXPECT().
SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), gomock.Any()).
Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) {
switch m.Type {
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
createCalls++
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
updateCalls++
}
}).
AnyTimes()
mgr.reconcile(ctx, "acct-1")
mgr.reconcile(ctx, "acct-1")
assert.Equal(t, 1, createCalls, "first reconcile creates")
assert.Equal(t, 1, updateCalls, "second reconcile re-pushes as Modified (no semantic change but mapping fields refresh)")
}
func TestReconcile_PolicyRemoved_EmitsDelete(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl)
provider := newReconcileTestProvider()
policy := newReconcileTestPolicy(provider.ID, "grp-eng")
gomock.InOrder(
// First reconcile: provider + policy, synthesised.
mockStore.EXPECT().GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").Return(newReconcileTestSettings(), nil),
mockStore.EXPECT().GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Provider{provider}, nil),
mockStore.EXPECT().GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Policy{policy}, nil),
mockStore.EXPECT().GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Guardrail{}, nil),
// Second reconcile: policy gone, provider stays but no longer referenced.
mockStore.EXPECT().GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").Return(newReconcileTestSettings(), nil),
mockStore.EXPECT().GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Provider{provider}, nil),
mockStore.EXPECT().GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Policy{}, nil),
)
mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()
var seenTypes []proto.ProxyMappingUpdateType
mockProxy.EXPECT().
SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), "eu.proxy.netbird.io").
Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) {
seenTypes = append(seenTypes, m.Type)
}).
AnyTimes()
mgr.reconcile(ctx, "acct-1")
mgr.reconcile(ctx, "acct-1")
require.Len(t, seenTypes, 2, "create then delete")
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, seenTypes[0])
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, seenTypes[1])
mgr.reconcileMu.Lock()
_, present := mgr.reconcileCache["acct-1"]
mgr.reconcileMu.Unlock()
assert.False(t, present, "cache for the account must be cleared once nothing is synthesised")
}
func TestReconcile_NilProxyController_NoOp(t *testing.T) {
ctx := context.Background()
mgr := &managerImpl{
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
}
// Must not panic; must not query the store.
mgr.reconcile(ctx, "acct-1")
}
func TestReconcile_EmptyAccountID_NoOp(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mgr, _, _ := newReconcileMgr(t, ctrl)
// Empty accountID short-circuits before any store call.
mgr.reconcile(ctx, "")
}
func TestClusterFromMapping(t *testing.T) {
tests := []struct {
name string
domain string
want string
}{
{"simple", "openai.eu.proxy.netbird.io", "eu.proxy.netbird.io"},
{"deeply nested", "a.b.c.d", "b.c.d"},
{"no dot", "openai", ""},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := clusterFromMapping(&proto.ProxyMapping{Domain: tt.domain})
assert.Equal(t, tt.want, got)
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,178 @@
package agentnetwork
import (
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/store"
)
// decodeServiceGuardrailConfig pulls the llm_guardrail middleware config off the
// synthesised service's single target.
func decodeServiceGuardrailConfig(t *testing.T, svc *rpservice.Service) guardrailConfig {
t.Helper()
require.NotEmpty(t, svc.Targets, "synth service must carry a target")
for _, mw := range svc.Targets[0].Options.Middlewares {
if mw.ID == middlewareIDLLMGuardrail {
var cfg guardrailConfig
require.NoError(t, json.Unmarshal(mw.ConfigJSON, &cfg), "guardrail config must decode")
return cfg
}
}
t.Fatal("llm_guardrail middleware not present on synthesised service")
return guardrailConfig{}
}
// decodeMiddlewareRawConfig returns the raw ConfigJSON bytes for the named
// middleware on the synth service's target, or fails the test.
func decodeMiddlewareRawConfig(t *testing.T, svc *rpservice.Service, id string) []byte {
t.Helper()
require.NotEmpty(t, svc.Targets, "synth service must carry a target")
for _, mw := range svc.Targets[0].Options.Middlewares {
if mw.ID == id {
return mw.ConfigJSON
}
}
t.Fatalf("middleware %q not present on synthesised service", id)
return nil
}
// saveGuardrailAndPolicy persists a guardrail with prompt capture + redact + a
// model allowlist, referenced by one enabled policy. Shared by the GC-3 tests.
func saveGuardrailAndPolicy(t *testing.T, ctx context.Context, s store.Store, provider *types.Provider) {
t.Helper()
guardrail := &types.Guardrail{
ID: "ainguard-1",
AccountID: testAccountID,
Name: "strict",
Checks: types.GuardrailChecks{
ModelAllowlist: types.GuardrailModelAllowlist{Enabled: true, Models: []string{"gpt-5.4"}},
PromptCapture: types.GuardrailPromptCapture{Enabled: true, RedactPii: true},
},
}
require.NoError(t, s.SaveAgentNetworkGuardrail(ctx, guardrail))
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", guardrail.ID)))
}
// TestSynthesizeServices_RealStore_PromptCaptureAccountIsSoleControl is the
// GC-3 contract: the account master switch (EnablePromptCollection) is the
// SOLE control for capture enablement. Policy-level guardrail prompt_capture is
// ignored for enablement — operators don't need to attach a capture guardrail
// to a policy just to turn capture on for the account. Off by default.
func TestSynthesizeServices_RealStore_PromptCaptureAccountIsSoleControl(t *testing.T) {
ctx := context.Background()
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err, "real sqlite test store must come up")
defer cleanup()
// Account collection master switch OFF (default).
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
saveGuardrailAndPolicy(t, ctx, s, newSynthTestProvider())
services, err := SynthesizeServices(ctx, s, testAccountID)
require.NoError(t, err)
require.Len(t, services, 1)
cfg := decodeServiceGuardrailConfig(t, services[0])
assert.Equal(t, []string{"gpt-5.4"}, cfg.ModelAllowlist,
"model allowlist is a pure policy guardrail and must always reach the config")
assert.False(t, cfg.PromptCapture.Enabled,
"prompt capture must be off when the account toggle is off, even with a capture-enabled guardrail")
}
// TestSynthesizeServices_RealStore_PromptCaptureFlowsWhenAccountOptsIn proves
// the account toggle is sufficient on its own — even with NO guardrail
// attached to the policy, capture fires when the account opts in. Redact is
// the OR of account + guardrail.
func TestSynthesizeServices_RealStore_PromptCaptureFlowsWhenAccountOptsIn(t *testing.T) {
ctx := context.Background()
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err, "real sqlite test store must come up")
defer cleanup()
settings := newSynthTestSettings()
settings.EnablePromptCollection = true
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
// Save a provider and a policy with NO guardrails attached — proves the
// account toggle is sufficient on its own.
provider := newSynthTestProvider()
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
services, err := SynthesizeServices(ctx, s, testAccountID)
require.NoError(t, err)
require.Len(t, services, 1)
cfg := decodeServiceGuardrailConfig(t, services[0])
assert.True(t, cfg.PromptCapture.Enabled,
"account toggle alone must enable capture; no guardrail attachment required")
}
// TestSynthesizeServices_RealStore_AccountRedactWithoutGuardrailRedact proves
// the redact OR-merge from the account side: account RedactPii on, guardrail
// redact off, capture on at both levels.
func TestSynthesizeServices_RealStore_AccountRedactWithoutGuardrailRedact(t *testing.T) {
ctx := context.Background()
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err, "real sqlite test store must come up")
defer cleanup()
settings := newSynthTestSettings()
settings.EnablePromptCollection = true
settings.RedactPii = true
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
provider := newSynthTestProvider()
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
guardrail := &types.Guardrail{
ID: "ainguard-noredact",
AccountID: testAccountID,
Name: "capture-only",
Checks: types.GuardrailChecks{
PromptCapture: types.GuardrailPromptCapture{Enabled: true, RedactPii: false},
},
}
require.NoError(t, s.SaveAgentNetworkGuardrail(ctx, guardrail))
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", guardrail.ID)))
services, err := SynthesizeServices(ctx, s, testAccountID)
require.NoError(t, err)
require.Len(t, services, 1)
cfg := decodeServiceGuardrailConfig(t, services[0])
assert.True(t, cfg.PromptCapture.Enabled, "capture on (account + guardrail)")
assert.True(t, cfg.PromptCapture.RedactPii, "account RedactPii must apply even when the guardrail leaves it off (OR)")
}
// TestSynthesizeServices_RealStore_NoGuardrail_CaptureOff pins the default:
// with no guardrail referenced, the synth service's guardrail config has prompt
// capture disabled and an empty allowlist. This is the "off by default" baseline
// the account switch must preserve.
func TestSynthesizeServices_RealStore_NoGuardrail_CaptureOff(t *testing.T) {
ctx := context.Background()
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err, "real sqlite test store must come up")
defer cleanup()
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
provider := newSynthTestProvider()
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
services, err := SynthesizeServices(ctx, s, testAccountID)
require.NoError(t, err)
require.Len(t, services, 1, "exactly one synth service expected")
cfg := decodeServiceGuardrailConfig(t, services[0])
assert.Empty(t, cfg.ModelAllowlist, "no guardrail → no allowlist")
assert.False(t, cfg.PromptCapture.Enabled, "no guardrail → prompt capture off by default")
assert.False(t, cfg.PromptCapture.RedactPii, "no guardrail → redact off by default")
}

View File

@@ -0,0 +1,70 @@
package agentnetwork
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/store"
)
// TestSynthesizeServices_RealStore_LogCollectionOff_SuppressesAccessLog drives the
// happy default: account settings ship with EnableLogCollection=false, so the
// synthesised target opts out of access-log emission (DisableAccessLog=true) and
// the proto mapping the proxy receives reflects that.
func TestSynthesizeServices_RealStore_LogCollectionOff_SuppressesAccessLog(t *testing.T) {
ctx := context.Background()
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err, "real sqlite test store must come up")
defer cleanup()
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
provider := newSynthTestProvider()
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
services, err := SynthesizeServices(ctx, s, testAccountID)
require.NoError(t, err)
require.Len(t, services, 1, "exactly one synth service expected")
require.NotEmpty(t, services[0].Targets, "synth service must carry a target")
assert.True(t, services[0].Targets[0].Options.DisableAccessLog,
"EnableLogCollection=false (default) must produce DisableAccessLog=true on the synth target")
mapping := services[0].ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{})
require.NotEmpty(t, mapping.GetPath(), "proto mapping must carry a path")
assert.True(t, mapping.GetPath()[0].GetOptions().GetDisableAccessLog(),
"proto mapping must propagate DisableAccessLog=true so the proxy suppresses access-log emission")
}
// TestSynthesizeServices_RealStore_LogCollectionOn_PermitsAccessLog asserts the
// inverse: once the account opts in, the synth target leaves DisableAccessLog
// at its default false and the proto wire stays unset.
func TestSynthesizeServices_RealStore_LogCollectionOn_PermitsAccessLog(t *testing.T) {
ctx := context.Background()
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err, "real sqlite test store must come up")
defer cleanup()
settings := newSynthTestSettings()
settings.EnableLogCollection = true
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
provider := newSynthTestProvider()
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
services, err := SynthesizeServices(ctx, s, testAccountID)
require.NoError(t, err)
require.Len(t, services, 1, "exactly one synth service expected")
require.NotEmpty(t, services[0].Targets, "synth service must carry a target")
assert.False(t, services[0].Targets[0].Options.DisableAccessLog,
"EnableLogCollection=true must leave DisableAccessLog=false on the synth target")
mapping := services[0].ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{})
require.NotEmpty(t, mapping.GetPath(), "proto mapping must carry a path")
assert.False(t, mapping.GetPath()[0].GetOptions().GetDisableAccessLog(),
"proto mapping must propagate DisableAccessLog=false so access-log emission stays on")
}

View File

@@ -0,0 +1,145 @@
package agentnetwork
import (
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/store"
)
// parserRedactConfig mirrors the on-wire shape of the redact + capture knobs
// that both llm_request_parser and llm_response_parser unmarshal. We don't
// import the proxy-side packages from a management test (cross-module), so we
// decode the JSON directly and assert on the fields that are part of the
// synth contract.
type parserRedactConfig struct {
RedactPii bool `json:"redact_pii,omitempty"`
CapturePrompt *bool `json:"capture_prompt,omitempty"` // present only on the request parser
CaptureCompletion *bool `json:"capture_completion,omitempty"` // present only on the response parser
}
// TestSynthesizeServices_RealStore_ParserConfigsCarryRedactPii is the
// management-side contract test for the request/response parser redaction
// wiring. When settings.RedactPii is true, the synthesised middleware chain
// MUST stamp redact_pii=true on both llm_request_parser and llm_response_parser
// configs — otherwise the parsers ship raw prompts / completions to the
// access log even though the account has opted in. This is exactly the live
// leak path that motivated the parser-side redaction in the first place.
func TestSynthesizeServices_RealStore_ParserConfigsCarryRedactPii(t *testing.T) {
ctx := context.Background()
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err, "real sqlite test store must come up")
defer cleanup()
settings := newSynthTestSettings()
settings.RedactPii = true
settings.EnablePromptCollection = true
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
provider := newSynthTestProvider()
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
services, err := SynthesizeServices(ctx, s, testAccountID)
require.NoError(t, err)
require.Len(t, services, 1, "exactly one synth service expected")
for _, parserID := range []string{middlewareIDLLMRequestParser, middlewareIDLLMResponseParser} {
raw := decodeMiddlewareRawConfig(t, services[0], parserID)
var cfg parserRedactConfig
require.NoError(t, json.Unmarshal(raw, &cfg), "%s config must be valid JSON", parserID)
assert.True(t, cfg.RedactPii, "%s config must carry redact_pii=true when settings.RedactPii is on (otherwise the parser ships raw prompts/completions to the access log)", parserID)
}
// The capture flag is set explicitly to enable_prompt_collection on each
// parser. With it on here, both must allow emission.
reqCfg := decodeParserConfig(t, services[0], middlewareIDLLMRequestParser)
require.NotNil(t, reqCfg.CapturePrompt, "request parser must carry an explicit capture_prompt")
assert.True(t, *reqCfg.CapturePrompt, "capture_prompt=true when EnablePromptCollection=true")
respCfg := decodeParserConfig(t, services[0], middlewareIDLLMResponseParser)
require.NotNil(t, respCfg.CaptureCompletion, "response parser must carry an explicit capture_completion")
assert.True(t, *respCfg.CaptureCompletion, "capture_completion=true when EnablePromptCollection=true")
}
// decodeParserConfig is a small helper around decodeMiddlewareRawConfig that
// also unmarshals into parserRedactConfig.
func decodeParserConfig(t *testing.T, svc *rpservice.Service, parserID string) parserRedactConfig {
t.Helper()
raw := decodeMiddlewareRawConfig(t, svc, parserID)
var cfg parserRedactConfig
require.NoError(t, json.Unmarshal(raw, &cfg), "%s config must be valid JSON", parserID)
return cfg
}
// TestSynthesizeServices_RealStore_ParserConfigsSuppressCaptureWhenLogCollectionOnly
// is the contract test for the bug: enable_log_collection=true with
// enable_prompt_collection=false MUST result in capture_prompt=false on the
// request parser AND capture_completion=false on the response parser, so the
// access-log row stays metadata-only (provider, model, tokens, cost) and
// carries NO prompt input nor response output. Without this, operators who
// want billing-style logs end up with raw user prompts and model outputs in
// every access-log entry.
func TestSynthesizeServices_RealStore_ParserConfigsSuppressCaptureWhenLogCollectionOnly(t *testing.T) {
ctx := context.Background()
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err, "real sqlite test store must come up")
defer cleanup()
settings := newSynthTestSettings()
settings.EnableLogCollection = true // operator wants logs ON
settings.EnablePromptCollection = false // but NOT content capture
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
provider := newSynthTestProvider()
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
services, err := SynthesizeServices(ctx, s, testAccountID)
require.NoError(t, err)
require.Len(t, services, 1)
reqCfg := decodeParserConfig(t, services[0], middlewareIDLLMRequestParser)
require.NotNil(t, reqCfg.CapturePrompt, "request parser must carry an explicit capture_prompt gate")
assert.False(t, *reqCfg.CapturePrompt, "capture_prompt MUST be false when EnablePromptCollection is off — otherwise llm.request_prompt_raw leaks user input into the access log")
respCfg := decodeParserConfig(t, services[0], middlewareIDLLMResponseParser)
require.NotNil(t, respCfg.CaptureCompletion, "response parser must carry an explicit capture_completion gate")
assert.False(t, *respCfg.CaptureCompletion, "capture_completion MUST be false when EnablePromptCollection is off — otherwise llm.response_completion leaks model output into the access log")
}
// TestSynthesizeServices_RealStore_ParserConfigsOmitRedactPiiWhenOff proves
// the inverse: with the account toggle off, the parser configs stay clean (no
// redact_pii field, which the parsers treat as zero / no redaction). This is
// the operator-opt-out path — the access log keeps raw prompts/completions
// for debugging until the operator opts in.
func TestSynthesizeServices_RealStore_ParserConfigsOmitRedactPiiWhenOff(t *testing.T) {
ctx := context.Background()
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
defer cleanup()
// Default settings: RedactPii = false.
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
provider := newSynthTestProvider()
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
services, err := SynthesizeServices(ctx, s, testAccountID)
require.NoError(t, err)
require.Len(t, services, 1)
for _, parserID := range []string{middlewareIDLLMRequestParser, middlewareIDLLMResponseParser} {
raw := decodeMiddlewareRawConfig(t, services[0], parserID)
// Inspect the decoded JSON directly: a struct decode would also pass
// if redact_pii were present-but-false. The contract is that the key
// is omitted entirely while the account toggle is off.
var rawCfg map[string]json.RawMessage
require.NoError(t, json.Unmarshal(raw, &rawCfg), "%s config must be valid JSON", parserID)
assert.NotContains(t, rawCfg, "redact_pii",
"%s config must omit redact_pii entirely while the account toggle is off", parserID)
}
}

View File

@@ -0,0 +1,174 @@
package agentnetwork
import (
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/store"
nbtypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// decodeServiceRouterConfig finds the llm_router middleware on the synthesised
// service's single target and decodes its config — the model→provider routing
// table the proxy authorises against.
func decodeServiceRouterConfig(t *testing.T, svc *rpservice.Service) routerConfig {
t.Helper()
require.NotEmpty(t, svc.Targets, "synth service must carry a target")
for _, mw := range svc.Targets[0].Options.Middlewares {
if mw.ID == middlewareIDLLMRouter {
var cfg routerConfig
require.NoError(t, json.Unmarshal(mw.ConfigJSON, &cfg), "router config must decode")
return cfg
}
}
t.Fatal("llm_router middleware not present on synthesised service")
return routerConfig{}
}
// decodeMappingRouterConfig is the proto-wire equivalent: it pulls the
// llm_router config off the ProxyMapping the proxy actually receives.
func decodeMappingRouterConfig(t *testing.T, m *proto.ProxyMapping) routerConfig {
t.Helper()
require.NotEmpty(t, m.GetPath(), "mapping must carry a path")
for _, mw := range m.GetPath()[0].GetOptions().GetMiddlewares() {
if mw.GetId() == middlewareIDLLMRouter {
var cfg routerConfig
require.NoError(t, json.Unmarshal(mw.GetConfigJson(), &cfg), "wire router config must decode")
return cfg
}
}
t.Fatal("llm_router middleware not present on proxy mapping")
return routerConfig{}
}
// TestSynthesizeServices_RealStore_SurvivesStatusToggle drives synthesis through
// a REAL sqlite store (Save → gorm/JSON serialize → reload → decrypt) instead of
// a MockStore, so it exercises the field round-trip that a provider/policy edit
// actually hits. Mock-based tests can't catch a field that dies in persistence;
// this one can. It then performs the exact operation that reproduced the live
// 403 — disable then re-enable the provider — and asserts the re-enabled state
// is fully routable again.
func TestSynthesizeServices_RealStore_SurvivesStatusToggle(t *testing.T) {
ctx := context.Background()
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err, "real sqlite test store must come up")
defer cleanup()
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
provider := newSynthTestProvider()
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
assertRoutable := func(t *testing.T, stage string) {
services, err := SynthesizeServices(ctx, s, testAccountID)
require.NoError(t, err, stage)
require.Len(t, services, 1, "%s: exactly one synth service expected", stage)
svc := services[0]
assert.True(t, svc.Private, "%s: synth service must be Private after store round-trip", stage)
assert.Equal(t, []string{"grp-eng"}, svc.AccessGroups, "%s: AccessGroups must survive the round-trip", stage)
m := svc.ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{})
assert.True(t, m.GetPrivate(), "%s: proto mapping Private must be true (proxy gates tunnel-peer auth on it)", stage)
cfg := decodeServiceRouterConfig(t, svc)
require.Len(t, cfg.Providers, 1, "%s: the enabled+linked provider must appear in the router config", stage)
assert.Equal(t, []string{"gpt-5.4"}, cfg.Providers[0].Models, "%s: provider models must reach the route", stage)
assert.Equal(t, []string{"grp-eng"}, cfg.Providers[0].AllowedGroupIDs, "%s: policy source groups must reach the route", stage)
}
assertRoutable(t, "initial")
provider.Enabled = false
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
disabled, err := SynthesizeServices(ctx, s, testAccountID)
require.NoError(t, err, "synthesis must not error with a disabled provider")
for _, svc := range disabled {
assert.Empty(t, decodeServiceRouterConfig(t, svc).Providers,
"a disabled provider must not appear in the router config (otherwise it would route while off)")
}
provider.Enabled = true
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
assertRoutable(t, "after disable->enable")
}
// captureController is a proxy.Controller that records the mappings reconcile
// pushes, so the test can inspect the exact wire payload — Private flag and
// router config included.
type captureController struct {
rpproxy.Controller
pushed []*proto.ProxyMapping
}
func (c *captureController) GetOIDCValidationConfig() rpproxy.OIDCValidationConfig {
return rpproxy.OIDCValidationConfig{}
}
func (c *captureController) SendServiceUpdateToCluster(_ context.Context, _ string, update *proto.ProxyMapping, _ string) {
c.pushed = append(c.pushed, update)
}
// noopAccountManager satisfies the reconcile path's accountManager dependency.
type noopAccountManager struct {
account.Manager
}
func (noopAccountManager) UpdateAccountPeers(context.Context, string, nbtypes.UpdateReason) {}
// TestReconcile_RealStore_PushesPrivateAfterStatusToggle reproduces the live
// path end-to-end below the gRPC boundary: a real store + the real
// managerImpl.reconcile + a capturing proxy controller. It runs the operation
// that broke in production — provider disable then re-enable — and asserts the
// mapping reconcile pushes to the cluster after re-enable is Private=true and
// carries the routable provider. If reconcile ever pushes private=false (the
// symptom that left UserGroups empty → no_authorised_provider), this fails.
func TestReconcile_RealStore_PushesPrivateAfterStatusToggle(t *testing.T) {
ctx := context.Background()
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
defer cleanup()
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
provider := newSynthTestProvider()
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
ctrl := &captureController{}
m := &managerImpl{
store: s,
accountManager: noopAccountManager{},
proxyController: ctrl,
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
}
m.reconcile(ctx, testAccountID) // initial, provider enabled
provider.Enabled = false
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
m.reconcile(ctx, testAccountID) // disabled
provider.Enabled = true
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
m.reconcile(ctx, testAccountID) // re-enabled — the reproduction step
require.NotEmpty(t, ctrl.pushed, "reconcile must push at least one mapping")
last := ctrl.pushed[len(ctrl.pushed)-1]
assert.Equal(t, newSynthTestSettings().Endpoint(), last.GetDomain(), "synth domain on the wire")
assert.True(t, last.GetPrivate(),
"reconcile-pushed mapping after re-enable MUST be Private=true; a false here is the exact bug — the proxy skips ValidateTunnelPeer, UserGroups stays empty, and llm_router denies no_authorised_provider")
cfg := decodeMappingRouterConfig(t, last)
require.Len(t, cfg.Providers, 1, "re-enabled provider must be back in the pushed router config")
assert.Equal(t, []string{"gpt-5.4"}, cfg.Providers[0].Models, "model must be routable again after re-enable")
assert.Equal(t, []string{"grp-eng"}, cfg.Providers[0].AllowedGroupIDs, "authorised groups must be present after re-enable")
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,117 @@
package types
import (
"time"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// AgentNetworkAccessLog is the dedicated, flattened agent-network access-log
// row. Unlike the shared reverse-proxy AccessLogEntry (which kept LLM data in
// an opaque metadata JSON blob), the LLM dimensions live in first-class,
// indexed columns so the access-log surface can filter server-side by
// user / group / provider / model / decision.
type AgentNetworkAccessLog struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
ServiceID string `gorm:"index"`
Timestamp time.Time `gorm:"index"`
UserID string `gorm:"index"`
SourceIP string
Method string
Host string
Path string `gorm:"type:text"`
Duration time.Duration
StatusCode int `gorm:"index"`
AuthMethod string
BytesUpload int64
BytesDownload int64
// Flattened LLM dimensions (queryable). Sourced from proxy metadata keys.
Provider string `gorm:"index"` // vendor, e.g. "openai" (llm.provider)
Model string `gorm:"index"` // llm.model
SessionID string `gorm:"index"` // llm.session_id — groups a conversation / coding session
ResolvedProviderID string `gorm:"index"` // llm.resolved_provider_id
SelectedPolicyID string `gorm:"index"` // llm.selected_policy_id
Decision string `gorm:"index"` // llm_policy.decision (allow/deny)
DenyReason string // llm_policy.reason (raw code, mapped in the UI)
InputTokens int64
OutputTokens int64
TotalTokens int64
CostUSD float64
Stream bool
// Prompt capture. Only populated when prompt collection is enabled
// (account master switch AND policy guardrail). Heavy free text.
RequestPrompt string `gorm:"type:text"`
ResponseCompletion string `gorm:"type:text"`
CreatedAt time.Time
// GroupIDs is the authorising group ids for this entry, hydrated from the
// group child table on read. Not a column.
GroupIDs []string `gorm:"-"`
}
// TableName keeps agent-network access logs in their own table, separate from
// the reverse-proxy AccessLogEntry table.
func (AgentNetworkAccessLog) TableName() string { return "agent_network_access_log" }
// ToAPIResponse renders the flattened entry as the API representation.
func (a *AgentNetworkAccessLog) ToAPIResponse() api.AgentNetworkAccessLog {
out := api.AgentNetworkAccessLog{
Id: a.ID,
ServiceId: a.ServiceID,
Timestamp: a.Timestamp,
StatusCode: a.StatusCode,
DurationMs: int(a.Duration.Milliseconds()),
InputTokens: a.InputTokens,
OutputTokens: a.OutputTokens,
TotalTokens: a.TotalTokens,
CostUsd: a.CostUSD,
Stream: &a.Stream,
}
out.UserId = strPtr(a.UserID)
out.SourceIp = strPtr(a.SourceIP)
out.Method = strPtr(a.Method)
out.Host = strPtr(a.Host)
out.Path = strPtr(a.Path)
out.Provider = strPtr(a.Provider)
out.Model = strPtr(a.Model)
out.SessionId = strPtr(a.SessionID)
out.ResolvedProviderId = strPtr(a.ResolvedProviderID)
out.SelectedPolicyId = strPtr(a.SelectedPolicyID)
out.Decision = strPtr(a.Decision)
out.DenyReason = strPtr(a.DenyReason)
out.RequestPrompt = strPtr(a.RequestPrompt)
out.ResponseCompletion = strPtr(a.ResponseCompletion)
if len(a.GroupIDs) > 0 {
groups := a.GroupIDs
out.GroupIds = &groups
}
return out
}
// strPtr returns a pointer to s, or nil when s is empty — so empty optional
// fields are omitted from the JSON rather than serialised as "".
func strPtr(s string) *string {
if s == "" {
return nil
}
return &s
}
// AgentNetworkAccessLogGroup is the normalised many-to-many row linking a log
// entry to one authorising group, so the access-log endpoint can filter by
// group with a simple `group_id IN (...)` join instead of substring-matching a
// CSV column.
type AgentNetworkAccessLogGroup struct {
LogID string `gorm:"primaryKey"`
GroupID string `gorm:"primaryKey;index"`
AccountID string `gorm:"index"`
}
// TableName names the access-log group child table.
func (AgentNetworkAccessLogGroup) TableName() string { return "agent_network_access_log_group" }

View File

@@ -0,0 +1,213 @@
package types
import (
"math"
"net/http"
"strconv"
"strings"
"time"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
// AccessLogDefaultPageSize is the default number of records per page.
AccessLogDefaultPageSize = 50
// AccessLogMaxPageSize is the maximum number of records allowed per page.
AccessLogMaxPageSize = 100
accessLogDefaultSortBy = "timestamp"
accessLogDefaultSortOrder = "desc"
// usageOverviewDefaultLookback bounds an unbounded usage-overview query so
// it never aggregates an account's entire history into memory.
usageOverviewDefaultLookback = 90 * 24 * time.Hour
// usageOverviewMaxRange caps how far back an explicit range may reach.
usageOverviewMaxRange = 366 * 24 * time.Hour
)
// ApplyUsageOverviewBounds bounds a missing or over-wide date range so the
// in-memory usage aggregation can't load an account's full usage history. An
// absent range defaults to the last usageOverviewDefaultLookback; a range wider
// than usageOverviewMaxRange is clamped from the (possibly defaulted) end.
func (f *AgentNetworkAccessLogFilter) ApplyUsageOverviewBounds(now time.Time) {
end := now
if f.EndDate != nil {
end = *f.EndDate
}
f.EndDate = &end
if f.StartDate == nil {
start := end.Add(-usageOverviewDefaultLookback)
f.StartDate = &start
return
}
if end.Sub(*f.StartDate) > usageOverviewMaxRange {
start := end.Add(-usageOverviewMaxRange)
f.StartDate = &start
}
}
// accessLogSortFields maps the API sort_by values to their database columns.
var accessLogSortFields = map[string]string{
"timestamp": "timestamp",
"model": "model",
"provider": "provider",
"status_code": "status_code",
"duration": "duration",
"cost_usd": "cost_usd",
"total_tokens": "total_tokens",
"user_id": "user_id",
"decision": "decision",
}
// AgentNetworkAccessLogFilter holds pagination, filtering and sorting
// parameters for the agent-network access-log listing. Group / provider /
// model are multi-valued (the UI uses multi-select; an entry matches when it
// matches any selected value).
type AgentNetworkAccessLogFilter struct {
Page int
PageSize int
SortBy string
SortOrder string
Search *string // log id, host, path, model, user email/name
UserID *string // exact user id (the dashboard sends the picked user's id)
SessionID *string // exact session id — groups one conversation / coding session
GroupIDs []string // authorising group ids (match any)
ProviderIDs []string // resolved provider ids (match any)
Models []string // models (match any)
Decision *string // policy decision (allow/deny)
PathPrefix *string // request path prefix (path LIKE 'prefix%')
StartDate *time.Time // timestamp >= start_date
EndDate *time.Time // timestamp <= end_date
}
// ParseFromRequest fills the filter from the request query parameters. It
// returns a validation error when a supplied start_date / end_date is present
// but not valid RFC3339: silently dropping a malformed date would broaden the
// query (and, for the usage overview, fall back to the default window).
func (f *AgentNetworkAccessLogFilter) ParseFromRequest(r *http.Request) error {
q := r.URL.Query()
f.Page = parseAccessLogPositiveInt(q.Get("page"), 1)
f.PageSize = min(parseAccessLogPositiveInt(q.Get("page_size"), AccessLogDefaultPageSize), AccessLogMaxPageSize)
f.SortBy = parseAccessLogSortField(q.Get("sort_by"))
f.SortOrder = parseAccessLogSortOrder(q.Get("sort_order"))
f.Search = parseAccessLogOptionalString(q.Get("search"))
f.UserID = parseAccessLogOptionalString(q.Get("user_id"))
f.SessionID = parseAccessLogOptionalString(q.Get("session_id"))
f.Decision = parseAccessLogOptionalString(q.Get("decision"))
f.PathPrefix = parseAccessLogOptionalString(q.Get("path"))
// Multi-value filters accept either repeated params (?group_id=a&group_id=b)
// or a single comma-separated value (?group_id=a,b) so both the OpenAPI
// array form and the dashboard's single-value query builder work.
f.GroupIDs = splitMultiValue(q["group_id"])
f.ProviderIDs = splitMultiValue(q["provider_id"])
f.Models = splitMultiValue(q["model"])
var err error
if f.StartDate, err = parseAccessLogOptionalRFC3339(q.Get("start_date")); err != nil {
return status.Errorf(status.InvalidArgument, "invalid start_date: %v", err)
}
if f.EndDate, err = parseAccessLogOptionalRFC3339(q.Get("end_date")); err != nil {
return status.Errorf(status.InvalidArgument, "invalid end_date: %v", err)
}
return nil
}
// GetSortColumn returns the database column for the active sort field.
func (f *AgentNetworkAccessLogFilter) GetSortColumn() string {
if col, ok := accessLogSortFields[f.SortBy]; ok {
return col
}
return accessLogSortFields[accessLogDefaultSortBy]
}
// GetSortOrder returns the normalised sort order ("ASC"/"DESC").
func (f *AgentNetworkAccessLogFilter) GetSortOrder() string {
if strings.EqualFold(f.SortOrder, "asc") {
return "ASC"
}
return "DESC"
}
// GetLimit returns the page size, defaulting/clamping when unset.
func (f *AgentNetworkAccessLogFilter) GetLimit() int {
if f.PageSize <= 0 {
return AccessLogDefaultPageSize
}
return min(f.PageSize, AccessLogMaxPageSize)
}
// GetOffset returns the zero-based row offset for the active page. Page is
// user-controlled, so the multiplication is guarded against int overflow.
func (f *AgentNetworkAccessLogFilter) GetOffset() int {
limit := f.GetLimit()
if f.Page <= 1 || limit <= 0 {
return 0
}
if f.Page-1 > math.MaxInt/limit {
return math.MaxInt - (math.MaxInt % limit)
}
return (f.Page - 1) * limit
}
func parseAccessLogPositiveInt(s string, def int) int {
if v, err := strconv.Atoi(strings.TrimSpace(s)); err == nil && v > 0 {
return v
}
return def
}
func parseAccessLogSortField(s string) string {
if _, ok := accessLogSortFields[s]; ok {
return s
}
return accessLogDefaultSortBy
}
func parseAccessLogSortOrder(s string) string {
if strings.EqualFold(s, "asc") {
return "asc"
}
return accessLogDefaultSortOrder
}
func parseAccessLogOptionalString(s string) *string {
if s = strings.TrimSpace(s); s != "" {
return &s
}
return nil
}
func parseAccessLogOptionalRFC3339(s string) (*time.Time, error) {
if s = strings.TrimSpace(s); s == "" {
return nil, nil //nolint:nilnil // not provided: no value and no error
}
t, err := time.Parse(time.RFC3339, s)
if err != nil {
return nil, err
}
return &t, nil
}
// splitMultiValue flattens repeated query params and comma-separated values
// into a single trimmed, blank-free list. Returns nil when nothing remains so
// callers can skip the filter entirely.
func splitMultiValue(values []string) []string {
out := make([]string, 0, len(values))
for _, raw := range values {
for _, v := range strings.Split(raw, ",") {
if v = strings.TrimSpace(v); v != "" {
out = append(out, v)
}
}
}
if len(out) == 0 {
return nil
}
return out
}

View File

@@ -0,0 +1,106 @@
package types
import (
"time"
"github.com/rs/xid"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// AccountBudgetRule is an account-level, limit-only rule bound to groups
// and/or users. It mirrors the policy budget experience without any routing:
// it carries the same cap shape as a policy (PolicyLimits) but never selects a
// provider. Rules apply across policies as an always-on ceiling — every
// applicable rule binds (min-wins), so a rule can only tighten a caller's
// effective limit, never loosen it.
//
// TargetGroups matches when it intersects the caller's groups; TargetUsers
// binds a specific user directly. Empty TargetGroups and TargetUsers means the
// rule applies to every caller (the account-wide default).
type AccountBudgetRule struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Enabled bool
TargetGroups []string `gorm:"serializer:json;column:target_groups"`
TargetUsers []string `gorm:"serializer:json;column:target_users"`
Limits PolicyLimits `gorm:"serializer:json;column:limits"`
CreatedAt time.Time
UpdatedAt time.Time
}
// TableName puts budget rules in their own table.
func (AccountBudgetRule) TableName() string { return "agent_network_budget_rules" }
// NewAccountBudgetRule returns a new rule with a freshly minted ID.
func NewAccountBudgetRule(accountID string) *AccountBudgetRule {
now := time.Now().UTC()
return &AccountBudgetRule{
ID: "ainbud_" + xid.New().String(),
AccountID: accountID,
Enabled: true,
CreatedAt: now,
UpdatedAt: now,
}
}
// Copy returns a deep copy of the rule, including its target slices.
func (r *AccountBudgetRule) Copy() *AccountBudgetRule {
c := *r
c.TargetGroups = append([]string(nil), r.TargetGroups...)
c.TargetUsers = append([]string(nil), r.TargetUsers...)
return &c
}
// EventMeta renders the rule for the activity log.
func (r *AccountBudgetRule) EventMeta() map[string]any {
return map[string]any{
"name": r.Name,
"enabled": r.Enabled,
}
}
// FromAPIRequest applies the request payload onto the receiver.
func (r *AccountBudgetRule) FromAPIRequest(req *api.AgentNetworkBudgetRuleRequest) {
r.Name = req.Name
if req.Enabled != nil {
r.Enabled = *req.Enabled
}
if req.TargetGroups != nil {
r.TargetGroups = append([]string(nil), (*req.TargetGroups)...)
} else {
r.TargetGroups = []string{}
}
if req.TargetUsers != nil {
r.TargetUsers = append([]string(nil), (*req.TargetUsers)...)
} else {
r.TargetUsers = []string{}
}
r.Limits = limitsFromAPI(req.Limits)
}
// ToAPIResponse renders the rule as the API representation.
func (r *AccountBudgetRule) ToAPIResponse() *api.AgentNetworkBudgetRule {
groups := r.TargetGroups
if groups == nil {
groups = []string{}
}
users := r.TargetUsers
if users == nil {
users = []string{}
}
created := r.CreatedAt
updated := r.UpdatedAt
return &api.AgentNetworkBudgetRule{
Id: r.ID,
Name: r.Name,
Enabled: r.Enabled,
TargetGroups: groups,
TargetUsers: users,
Limits: limitsToAPI(r.Limits),
CreatedAt: &created,
UpdatedAt: &updated,
}
}

View File

@@ -0,0 +1,69 @@
package types
import "time"
// ConsumptionDimension classifies which kind of identity a consumption
// row counts against. The proxy-side enforcement layer ticks one row
// per dimension per request — typically one user row plus one group
// row.
type ConsumptionDimension string
const (
// DimensionUser counts tokens / spend for a single end user. The
// dim_id column carries the netbird user id (or peer.ID when the
// caller is a tunnel-peer principal).
DimensionUser ConsumptionDimension = "user"
// DimensionGroup counts tokens / spend for a single source group
// across every member of that group. The dim_id column carries
// the netbird group id.
DimensionGroup ConsumptionDimension = "group"
)
// Consumption is a per-dimension token + USD counter for a fixed
// aligned window. The (account, dim_kind, dim_id, window_seconds,
// window_start) tuple is the primary key; rows are rolled forward by
// the proxy's post-flight RecordLLMUsage path on every request.
//
// The same dim_id (e.g. a group id) gets one row per distinct
// window_seconds length in scope across the account's policies,
// because two policies with different window lengths read independent
// counters even though they share the dimension. Two policies with
// identical window_seconds on the same dimension share one counter
// (correct: their caps are checked against the same shared bucket).
type Consumption struct {
AccountID string `gorm:"primaryKey;type:varchar(255)"`
DimensionKind ConsumptionDimension `gorm:"primaryKey;type:varchar(16);column:dim_kind"`
DimensionID string `gorm:"primaryKey;type:varchar(255);column:dim_id"`
WindowSeconds int64 `gorm:"primaryKey;column:window_seconds"`
WindowStartUTC time.Time `gorm:"primaryKey;column:window_start_utc"`
TokensInput int64 `gorm:"column:tokens_input"`
TokensOutput int64 `gorm:"column:tokens_output"`
CostUSD float64 `gorm:"column:cost_usd"`
UpdatedAt time.Time
}
// TableName forces a stable name independent of GORM's pluraliser.
func (Consumption) TableName() string { return "agent_network_consumption" }
// ConsumptionKey identifies a single consumption counter within an account:
// the (dim_kind, dim_id, window_seconds, window_start) part of the row's
// primary key. Used to batch-read and batch-increment many counters for one
// request in a single store round-trip / transaction.
type ConsumptionKey struct {
Kind ConsumptionDimension
DimID string
WindowSeconds int64
WindowStartUTC time.Time
}
// WindowStart returns the aligned UTC start of the window of length
// windowSeconds that contains t. Aligned to the unix epoch so the
// same bucket boundary is computed deterministically across processes.
func WindowStart(t time.Time, windowSeconds int64) time.Time {
if windowSeconds <= 0 {
return t.UTC()
}
step := windowSeconds * int64(time.Second)
bucketed := t.UTC().UnixNano() / step * step
return time.Unix(0, bucketed).UTC()
}

View File

@@ -0,0 +1,141 @@
package types
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestWindowStart_AlignedToUnixEpoch is the multi-node-convergence
// guarantee: any two proxies computing WindowStart(now, s) for the
// same s must land on the same boundary. The implementation aligns
// to the unix epoch (UTC) rather than local time, calendar weeks, or
// process start time — none of which are shared across nodes.
//
// Table covers the load-bearing window lengths (5m, 1h, 24h, 30d)
// plus a few odd values that still need to align cleanly.
func TestWindowStart_AlignedToUnixEpoch(t *testing.T) {
cases := []struct {
name string
instant time.Time
windowSeconds int64
want time.Time
}{
{
name: "5m window — drops seconds inside the bucket",
instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.UTC),
windowSeconds: 300,
want: time.Date(2026, 5, 6, 13, 45, 0, 0, time.UTC),
},
{
name: "1h window — drops minutes / seconds, keeps the hour",
instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.UTC),
windowSeconds: 3600,
want: time.Date(2026, 5, 6, 13, 0, 0, 0, time.UTC),
},
{
name: "24h window aligns to UTC midnight",
instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.UTC),
windowSeconds: 86_400,
want: time.Date(2026, 5, 6, 0, 0, 0, 0, time.UTC),
},
{
name: "30d (2_592_000s) window aligns to the 30d epoch grid, not month boundaries",
instant: time.Date(2026, 5, 6, 0, 0, 0, 0, time.UTC),
windowSeconds: 2_592_000,
// 2026-05-06 UTC = 1778025600s; 1778025600 / 2592000 = 685
// 685 * 2592000 = 1775520000s = 2026-04-07 00:00:00 UTC
want: time.Date(2026, 4, 7, 0, 0, 0, 0, time.UTC),
},
{
name: "non-UTC input still anchors on UTC epoch boundaries",
instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.FixedZone("CEST", 2*3600)),
windowSeconds: 86_400,
// 2026-05-06 13:47:23 CEST = 11:47:23 UTC → bucket 2026-05-06 00:00:00 UTC
want: time.Date(2026, 5, 6, 0, 0, 0, 0, time.UTC),
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := WindowStart(tc.instant, tc.windowSeconds)
assert.True(t, got.Equal(tc.want),
"WindowStart(%v, %ds) = %v, want %v", tc.instant, tc.windowSeconds, got, tc.want)
})
}
}
// TestWindowStart_WithinWindowConverges proves the determinism
// contract: any two timestamps inside the same window land on the
// exact same boundary. Two proxy nodes serving requests 7s apart
// must agree on which counter row to upsert.
func TestWindowStart_WithinWindowConverges(t *testing.T) {
t1 := time.Date(2026, 5, 6, 14, 0, 0, 0, time.UTC)
t2 := t1.Add(7 * time.Second)
t3 := t1.Add(59*time.Minute + 59*time.Second)
a := WindowStart(t1, 3600)
b := WindowStart(t2, 3600)
c := WindowStart(t3, 3600)
assert.True(t, a.Equal(b), "two timestamps 7s apart in the same 1h window must align to the same boundary")
assert.True(t, a.Equal(c), "the very last second of a 1h window still lands on the SAME bucket as the first second")
}
// TestWindowStart_AcrossWindowsDiverges is the symmetric guarantee:
// two timestamps separated by a window's worth of time MUST land on
// different boundaries. Without this, a 24h window's "rollover"
// would never reset the counter.
func TestWindowStart_AcrossWindowsDiverges(t *testing.T) {
t1 := time.Date(2026, 5, 6, 23, 59, 59, 0, time.UTC)
t2 := t1.Add(2 * time.Second) // 2026-05-07 00:00:01
a := WindowStart(t1, 86_400)
b := WindowStart(t2, 86_400)
assert.False(t, a.Equal(b),
"timestamps straddling a 24h-window boundary must land on different buckets — otherwise daily caps never reset")
}
// TestWindowStart_DifferentWindowsHaveDifferentBuckets locks the
// design fork "two policies with different window_seconds on the same
// group produce independent counters". A 24h boundary at noon is NOT
// the same as the 30d boundary that contains it.
func TestWindowStart_DifferentWindowsHaveDifferentBuckets(t *testing.T) {
now := time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC)
short := WindowStart(now, 86_400)
long := WindowStart(now, 2_592_000)
assert.False(t, short.Equal(long),
"the 24h bucket and 30d bucket containing the same instant must differ — independent counters require independent keys")
}
// TestWindowStart_SubMinuteAndMinuteAlignment locks sub-hour windows.
// A 5-minute window must align to multiples of 300s from the unix
// epoch — minute marks 0/5/10/.../55 within an hour, deterministic
// across nodes regardless of clock drift.
func TestWindowStart_SubMinuteAndMinuteAlignment(t *testing.T) {
t1 := time.Date(2026, 5, 6, 14, 12, 30, 0, time.UTC)
t2 := time.Date(2026, 5, 6, 14, 14, 59, 0, time.UTC)
t3 := time.Date(2026, 5, 6, 14, 15, 0, 0, time.UTC)
a := WindowStart(t1, 300)
b := WindowStart(t2, 300)
c := WindowStart(t3, 300)
assert.True(t, a.Equal(b),
"14:12:30 and 14:14:59 fall in the same 5m bucket starting at 14:10:00")
assert.True(t, a.Equal(time.Date(2026, 5, 6, 14, 10, 0, 0, time.UTC)),
"5m bucket containing 14:12 starts at 14:10 — aligned to multiples of 300s from unix epoch")
assert.False(t, a.Equal(c),
"14:15:00 is the start of the next 5m bucket — must not fold into the previous one")
}
// TestWindowStart_ZeroWindowReturnsInputUTC covers the defensive
// path: caller hands a zero / negative window (shouldn't happen, but
// might mid-refactor). The function returns the input as UTC rather
// than dividing by zero.
func TestWindowStart_ZeroWindowReturnsInputUTC(t *testing.T) {
now := time.Date(2026, 5, 6, 12, 30, 45, 0, time.FixedZone("CEST", 2*3600))
got := WindowStart(now, 0)
assert.True(t, got.Equal(now.UTC()), "zero window must not panic — return input as UTC")
}

View File

@@ -0,0 +1,120 @@
package types
import (
"time"
"github.com/rs/xid"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// GuardrailChecks is the configurable parameter set persisted with each
// guardrail. Stored as a JSON blob to keep the table flat.
type GuardrailChecks struct {
ModelAllowlist GuardrailModelAllowlist `json:"model_allowlist"`
PromptCapture GuardrailPromptCapture `json:"prompt_capture"`
}
type GuardrailModelAllowlist struct {
Enabled bool `json:"enabled"`
Models []string `json:"models"`
}
type GuardrailPromptCapture struct {
Enabled bool `json:"enabled"`
RedactPii bool `json:"redact_pii"`
}
// Guardrail is an Agent Network reusable guardrail set persisted per account.
type Guardrail struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Description string
Checks GuardrailChecks `gorm:"serializer:json"`
CreatedAt time.Time
UpdatedAt time.Time
}
// TableName uses an explicit name so guardrail rows live in their own
// table.
func (Guardrail) TableName() string { return "agent_network_guardrails" }
// NewGuardrail returns a new Guardrail with a freshly minted ID.
func NewGuardrail(accountID string) *Guardrail {
now := time.Now().UTC()
return &Guardrail{
ID: "ainguard_" + xid.New().String(),
AccountID: accountID,
Checks: GuardrailChecks{ModelAllowlist: GuardrailModelAllowlist{Models: []string{}}},
CreatedAt: now,
UpdatedAt: now,
}
}
// FromAPIRequest applies the request payload onto the receiver.
func (g *Guardrail) FromAPIRequest(req *api.AgentNetworkGuardrailRequest) {
g.Name = req.Name
if req.Description != nil {
g.Description = *req.Description
}
g.Checks = checksFromAPI(req.Checks)
}
// ToAPIResponse renders the guardrail as the API representation.
func (g *Guardrail) ToAPIResponse() *api.AgentNetworkGuardrail {
created := g.CreatedAt
updated := g.UpdatedAt
return &api.AgentNetworkGuardrail{
Id: g.ID,
Name: g.Name,
Description: g.Description,
Checks: checksToAPI(g.Checks),
CreatedAt: &created,
UpdatedAt: &updated,
}
}
// Copy returns a deep copy of the guardrail.
func (g *Guardrail) Copy() *Guardrail {
clone := *g
if g.Checks.ModelAllowlist.Models != nil {
clone.Checks.ModelAllowlist.Models = append([]string(nil), g.Checks.ModelAllowlist.Models...)
}
return &clone
}
// EventMeta is the audit-log payload for activity events.
func (g *Guardrail) EventMeta() map[string]any {
return map[string]any{"name": g.Name}
}
func checksFromAPI(c api.AgentNetworkGuardrailChecks) GuardrailChecks {
models := append([]string(nil), c.ModelAllowlist.Models...)
if models == nil {
models = []string{}
}
return GuardrailChecks{
ModelAllowlist: GuardrailModelAllowlist{
Enabled: c.ModelAllowlist.Enabled,
Models: models,
},
PromptCapture: GuardrailPromptCapture{
Enabled: c.PromptCapture.Enabled,
RedactPii: c.PromptCapture.RedactPii,
},
}
}
func checksToAPI(c GuardrailChecks) api.AgentNetworkGuardrailChecks {
models := c.ModelAllowlist.Models
if models == nil {
models = []string{}
}
out := api.AgentNetworkGuardrailChecks{}
out.ModelAllowlist.Enabled = c.ModelAllowlist.Enabled
out.ModelAllowlist.Models = models
out.PromptCapture.Enabled = c.PromptCapture.Enabled
out.PromptCapture.RedactPii = c.PromptCapture.RedactPii
return out
}

Some files were not shown because too many files have changed in this diff Show More