mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-27 02:10:00 +00:00
Compare commits
16 Commits
dependabot
...
agent-netw
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
668af0dc4f | ||
|
|
5f130959ea | ||
|
|
5644279888 | ||
|
|
f22ac6d271 | ||
|
|
9f485be2f9 | ||
|
|
c83e46fbe1 | ||
|
|
405607c584 | ||
|
|
29f55d4255 | ||
|
|
3993fa32e4 | ||
|
|
6ade3839aa | ||
|
|
d4d158a8f3 | ||
|
|
6613d194ef | ||
|
|
769e12840d | ||
|
|
350a96c640 | ||
|
|
615631567a | ||
|
|
f4daf59bcd |
6
.github/workflows/golang-test-linux.yml
vendored
6
.github/workflows/golang-test-linux.yml
vendored
@@ -579,10 +579,11 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags devcert -run=^$ -bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
|
||||
env:
|
||||
GIT_BRANCH: ${{ github.ref_name }}
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
@@ -673,12 +674,13 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags=benchmark \
|
||||
-run=^$ \
|
||||
-bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/server/http/...
|
||||
env:
|
||||
GIT_BRANCH: ${{ github.ref_name }}
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals,flate,recordin,unparseable
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
109
docs/agent-networks/00-overview.md
Normal file
109
docs/agent-networks/00-overview.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# Agent Networks — overview
|
||||
|
||||
Single-entry point. Feature scope, the module map, and the cross-cutting
|
||||
topics worth keeping in mind, with links into every per-module guide.
|
||||
|
||||
## TL;DR
|
||||
|
||||
Agent Networks introduces an **LLM-aware reverse-proxy middleware system**
|
||||
plus **account-level controls** (budget rules, log collection toggles,
|
||||
PII redaction). The management server synthesises a per-peer middleware
|
||||
chain that the proxy executes on every LLM request; the chain enforces
|
||||
quotas, injects identity, redacts PII, parses tokens/cost, and emits
|
||||
access-log entries. The dashboard exposes the surface as a single **AI
|
||||
Observability** page with four tabs.
|
||||
|
||||
- **Backend** lives in this repo, primarily under
|
||||
`management/server/agentnetwork`, `proxy/internal/middleware`, and
|
||||
`proxy/internal/llm`, with wire contracts in `shared/management`.
|
||||
- **Dashboard** lives in the dashboard repo under
|
||||
`src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`.
|
||||
|
||||
## Reading order
|
||||
|
||||
| # | Doc | Why |
|
||||
|---|-----|-----|
|
||||
| 1 | [01-end-to-end-flows.md](01-end-to-end-flows.md) | Get the three big diagrams in your head first. |
|
||||
| 2 | [modules/10-shared-api.md](modules/10-shared-api.md) | Wire contracts — every other module either produces or consumes these. |
|
||||
| 3 | [modules/21-management-agentnetwork.md](modules/21-management-agentnetwork.md) | The largest module; everything the proxy executes originates here. |
|
||||
| 4 | [modules/30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md) | The generic plugin system on the proxy side. |
|
||||
| 5 | [modules/31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md) | The 8 LLM middlewares that ride on the framework. |
|
||||
| 6 | Everything else in any order. | |
|
||||
|
||||
## Module map
|
||||
|
||||
11 modules. Each is described in detail in its own file under
|
||||
[`modules/`](modules/).
|
||||
|
||||
| # | Module | Risk | BC impact |
|
||||
|---|--------|------|-----------|
|
||||
| 10 | [shared/api](modules/10-shared-api.md) — proto + OpenAPI | Low | Additive only |
|
||||
| 20 | [management/store](modules/20-management-store.md) — SQL persistence | Medium | Auto-migrate (additive) |
|
||||
| 21 | [management/agentnetwork](modules/21-management-agentnetwork.md) — domain layer + synthesizer | **High** | Additive |
|
||||
| 22 | [management/handlers + wiring](modules/22-management-handlers-wiring.md) — HTTP API + gRPC delivery | Medium | Additive |
|
||||
| 30 | [proxy/middleware-framework](modules/30-proxy-middleware-framework.md) — generic plugin system | High | Additive |
|
||||
| 31 | [proxy/middleware-builtin](modules/31-proxy-middleware-builtin.md) — 8 LLM middlewares | High | Additive |
|
||||
| 32 | [proxy/llm-parsers](modules/32-proxy-llm-parsers.md) — SDK adapters + pricing | Medium | Additive |
|
||||
| 33 | [proxy/runtime](modules/33-proxy-runtime.md) — translate + serve + access-log | High | Additive (touches hot path) |
|
||||
| 40 | [dashboard](modules/40-dashboard.md) — UI for everything above | Medium | Sidebar reshape |
|
||||
| 50 | [path-routed-providers](modules/50-path-routed-providers.md) — Vertex AI + Bedrock | Medium | Additive (new catalog entries) |
|
||||
|
||||
The largest and highest-risk module is `management/agentnetwork`: it is
|
||||
the single writer of the middleware chain the proxy executes.
|
||||
|
||||
## Cross-cutting topics
|
||||
|
||||
These are the items most likely to bite production. Each is fully
|
||||
documented in the linked module guide.
|
||||
|
||||
1. **Capture-pointer semantics** (`*bool` for `capture_prompt` and
|
||||
`capture_completion`): nil = legacy emit, false = suppress, true =
|
||||
emit. nil-vs-false must be handled at every JSON hop. See
|
||||
[21-management-agentnetwork.md](modules/21-management-agentnetwork.md)
|
||||
and [31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md).
|
||||
2. **`ProxyMapping.Private` preservation** on per-proxy live updates.
|
||||
Failure mode: `auth` skips `ValidateTunnelPeer` →
|
||||
`CapturedData.UserGroups` empty → `llm_router` denies. See
|
||||
[33-proxy-runtime.md](modules/33-proxy-runtime.md).
|
||||
3. **respInput carrying `UserEmail`/`UserGroups`/`UserGroupNames` onto
|
||||
the response leg** in `reverseproxy.go`. Load-bearing wire that lets
|
||||
`llm_limit_record` ship non-empty `group_ids` on `RecordLLMUsage`. See
|
||||
[33-proxy-runtime.md](modules/33-proxy-runtime.md).
|
||||
4. **Min-wins all-must-pass budget rule semantics**. Every matching
|
||||
rule's remaining quota must be > 0 for the request to proceed; one
|
||||
exhausted rule blocks the whole call. Documented in
|
||||
[21-management-agentnetwork.md](modules/21-management-agentnetwork.md)
|
||||
and the `llm_limit_check` middleware in
|
||||
[31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md).
|
||||
5. **body-tap memory bounds**: per-direction 1 MiB cap, shared 256 MiB
|
||||
budget, `LimitReader(r.Body, limit+1)` for truncation detection with
|
||||
`replayReadCloser` fallback so upstream still sees the full body.
|
||||
`cloneInputFor` deep-copies the body up to 16 times per chain — a
|
||||
perf hot-spot. See
|
||||
[30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md).
|
||||
6. **UpstreamRewrite.AuthHeader bypasses the header denylist**
|
||||
deliberately. The runtime consumer only unpacks it via the
|
||||
trusted upstream-build path. See
|
||||
[30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md).
|
||||
7. **`disable_access_log` default-false semantics**: the synth target
|
||||
sets it true, all other targets leave it false. See
|
||||
[10-shared-api.md](modules/10-shared-api.md).
|
||||
8. **String-typed `decision` / `deny_code`** on
|
||||
`CheckLLMPolicyLimitsResponse` — would benefit from enum pinning
|
||||
before external consumers integrate. See
|
||||
[10-shared-api.md](modules/10-shared-api.md).
|
||||
|
||||
## Explicit non-goals
|
||||
|
||||
- **Reaper / GC pass over stale synth services** — designed but cut from
|
||||
scope.
|
||||
- **URL-sync for tab state on AI Observability** — read path is wired
|
||||
(`?tab=`) but write path isn't. Future work.
|
||||
- **CI golden-file regen-and-diff for `types.gen.go` /
|
||||
`proxy_service.pb.go`** — would catch codegen drift; not yet in place.
|
||||
|
||||
## Where to read the code
|
||||
|
||||
Per-module file scopes are listed in each module guide. Behaviour is
|
||||
covered by Go tests co-located with each package (and an end-to-end
|
||||
chain integration test under `proxy/internal/proxy`).
|
||||
217
docs/agent-networks/01-end-to-end-flows.md
Normal file
217
docs/agent-networks/01-end-to-end-flows.md
Normal file
@@ -0,0 +1,217 @@
|
||||
# End-to-end flows
|
||||
|
||||
Three cross-module mermaid diagrams. Each per-module guide repeats the
|
||||
slice that's relevant to its own scope — these are the canonical
|
||||
top-down views.
|
||||
|
||||
- [Flow A — Config → runtime (synth + deliver)](#flow-a--config--runtime-synth--deliver)
|
||||
- [Flow B — Request lifecycle through the LLM chain](#flow-b--request-lifecycle-through-the-llm-chain)
|
||||
- [Flow C — Budget rule feedback loop](#flow-c--budget-rule-feedback-loop)
|
||||
|
||||
---
|
||||
|
||||
## Flow A — Config → runtime (synth + deliver)
|
||||
|
||||
How an operator's change to a Provider, Policy, Guardrail, Budget Rule,
|
||||
or Settings record ends up as live middleware on a peer's proxy.
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
actor Op as Operator
|
||||
participant UI as Dashboard
|
||||
participant HTTP as management/handlers
|
||||
participant Mgr as agentnetwork.Manager
|
||||
participant Store as management/store (SQL)
|
||||
participant Ctl as network_map.Controller
|
||||
participant Synth as agentnetwork.SynthesizeServices
|
||||
participant Grpc as management gRPC
|
||||
participant Proxy as netbird-proxy
|
||||
participant Xlate as middleware_translate
|
||||
participant Chain as middleware.Chain
|
||||
|
||||
Op->>UI: edit provider/policy/budget/settings
|
||||
UI->>HTTP: REST PUT/POST /api/agent-network/*
|
||||
HTTP->>Mgr: SaveProvider / SavePolicy / SaveBudgetRule / SaveSettings
|
||||
Mgr->>Store: persist (gorm)
|
||||
Mgr-->>Ctl: account change event (Network-Map dirty)
|
||||
loop per connected peer
|
||||
Ctl->>Synth: SynthesizeServices(ctx, store, accountID)
|
||||
Synth->>Store: load providers, policies, guardrails, budget rules, settings
|
||||
Synth-->>Synth: build per-peer Service list
|
||||
Note over Synth: each Service has a middleware<br/>chain with capture_prompt /<br/>capture_completion / redact_pii<br/>baked from account settings
|
||||
Synth-->>Ctl: []rpservice.Service
|
||||
Ctl->>Grpc: NetworkMap push (services + middleware configs)
|
||||
end
|
||||
Grpc-->>Proxy: NetworkMap stream
|
||||
Proxy->>Xlate: translate proto MiddlewareConfig → runtime Spec
|
||||
Xlate->>Chain: register / replace per-service chain
|
||||
Note over Chain: chain replacement is live<br/>(no proxy restart, in-flight<br/>requests unaffected)
|
||||
```
|
||||
|
||||
**Notes on the diagram**
|
||||
|
||||
- The `network_map.Controller` synthesises on every push, not on a
|
||||
timer. A single config change costs O(connected peers × policies ×
|
||||
providers) per push. See [`modules/22-management-handlers-wiring.md`](modules/22-management-handlers-wiring.md).
|
||||
- `SynthesizeServices` is the single source of truth for the wire
|
||||
format the proxy executes. Anything the proxy does that the
|
||||
synthesiser didn't request is a bug. See
|
||||
[`modules/21-management-agentnetwork.md`](modules/21-management-agentnetwork.md).
|
||||
- The translate step (step 13) is the only place that knows the
|
||||
middleware-ID strings on the proxy side. It must reject unknown IDs;
|
||||
silently dropping middlewares would create a security gap (e.g.
|
||||
missing `llm_limit_check` ⇒ unbounded spend). See
|
||||
[`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md).
|
||||
|
||||
---
|
||||
|
||||
## Flow B — Request lifecycle through the LLM chain
|
||||
|
||||
What happens when an agent on the client peer sends a chat-completion /
|
||||
messages request through the synthesised reverse-proxy.
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
actor Agent as Agent (local)
|
||||
participant Px as netbird-proxy
|
||||
participant Auth as auth middleware
|
||||
participant Map as service-mapping
|
||||
participant Req as llm_request_parser
|
||||
participant Rt as llm_router
|
||||
participant Chk as llm_limit_check
|
||||
participant Inj as llm_identity_inject
|
||||
participant Grd as llm_guardrail
|
||||
participant Up as upstream LLM
|
||||
participant Resp as llm_response_parser
|
||||
participant Cost as cost_meter
|
||||
participant Rec as llm_limit_record
|
||||
participant Log as access-log
|
||||
participant MgmtGrpc as management gRPC
|
||||
|
||||
Agent->>Px: POST /v1/chat/completions (OpenAI / Anthropic)
|
||||
Px->>Auth: identify peer (user, groups)
|
||||
Auth->>Map: resolve service from Host + path
|
||||
Map-->>Req: dispatch chain in slot order
|
||||
|
||||
Req->>Req: parse body → provider, model, prompt, token estimate
|
||||
Note over Req: capture_prompt gates raw_prompt<br/>capture (nil = legacy emit,<br/>false = drop, true = emit)
|
||||
Req->>Rt: pass metadata
|
||||
Rt->>Chk: route to upstream candidate
|
||||
|
||||
Chk->>MgmtGrpc: CheckLLMPolicyLimits(provider, model, est_tokens, groups, user)
|
||||
MgmtGrpc-->>Chk: decision = allow / deny + deny_code
|
||||
alt decision == deny
|
||||
Chk-->>Log: emit access-log with deny_code<br/>(if EnableLogCollection)
|
||||
Chk-->>Agent: 429 (or 403 per deny_code)
|
||||
else decision == allow
|
||||
Chk->>Inj: continue
|
||||
Inj->>Inj: inject NetBird identity headers per provider config
|
||||
Inj->>Grd: continue
|
||||
Grd->>Grd: enforce model allowlist
|
||||
Grd->>Up: forward (over WireGuard)
|
||||
Up-->>Resp: response (JSON or SSE stream)
|
||||
Resp->>Resp: parse usage tokens, completion
|
||||
Note over Resp: capture_completion gates raw<br/>completion capture
|
||||
Resp->>Cost: tokens
|
||||
Cost->>Cost: lookup pricing.yaml + compute cost
|
||||
Cost->>Rec: tokens + cost
|
||||
Rec->>MgmtGrpc: RecordLLMUsage(provider, model, prompt_t, completion_t, cost, groups, user)
|
||||
Rec-->>Log: emit access-log entry<br/>(if EnableLogCollection)
|
||||
Log-->>Agent: 200 + body (streamed if SSE)
|
||||
end
|
||||
```
|
||||
|
||||
**Notes on the diagram**
|
||||
|
||||
- The chain runs in synth-defined order. Re-ordering middlewares
|
||||
changes invariants — `llm_limit_check` must precede `llm_router` so
|
||||
a denied request never hits upstream, and `llm_limit_record` must
|
||||
pair with `llm_limit_check` so a successful check is always recorded
|
||||
(or the rate-limit semantics break). See
|
||||
[`modules/31-proxy-middleware-builtin.md`](modules/31-proxy-middleware-builtin.md).
|
||||
- `llm_guardrail` is also where PII redaction happens
|
||||
(`redact_pii = settings.RedactPii`). Phones, emails, credit cards,
|
||||
PII names — see `redact.go` for the full set. See
|
||||
[`modules/31-proxy-middleware-builtin.md`](modules/31-proxy-middleware-builtin.md).
|
||||
- SSE streaming requires special handling on the response side; the
|
||||
parser must handle partial chunks without buffering the whole
|
||||
stream. See [`modules/32-proxy-llm-parsers.md`](modules/32-proxy-llm-parsers.md).
|
||||
- Access-log emission is gated on `settings.EnableLogCollection`. With
|
||||
it OFF, neither the deny nor the allow leg writes an entry — the
|
||||
chain still runs (budget rules are still enforced) but no audit trail
|
||||
is kept. See
|
||||
[`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md).
|
||||
|
||||
---
|
||||
|
||||
## Flow C — Budget rule feedback loop
|
||||
|
||||
How an account's budget rules tighten ceilings on every request and how
|
||||
consumption flows back into the dashboard.
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph Operator
|
||||
DashBud[Dashboard Budget Settings tab]
|
||||
end
|
||||
subgraph Mgmt[Management]
|
||||
Save[POST/PUT /api/agent-network/budget-rules]
|
||||
Store[(SQL store)]
|
||||
Synth[SynthesizeServices]
|
||||
Check[CheckLLMPolicyLimits RPC]
|
||||
Rec[RecordLLMUsage RPC]
|
||||
Cons[/api/agent-network/consumption]
|
||||
end
|
||||
subgraph Proxy[Proxy]
|
||||
Chk[llm_limit_check]
|
||||
RecMw[llm_limit_record]
|
||||
end
|
||||
subgraph DashView[Dashboard Budget Dashboard tab]
|
||||
Panel[AgentConsumptionPanel]
|
||||
end
|
||||
|
||||
DashBud -->|create / update rules| Save
|
||||
Save --> Store
|
||||
Store --> Synth
|
||||
Synth -->|push synth-services to peer| Proxy
|
||||
|
||||
Chk -->|per request| Check
|
||||
Check -->|aggregate matching rules<br/>min-wins all-must-pass| Store
|
||||
Check -->|allow / deny| Chk
|
||||
|
||||
RecMw -->|post-response| Rec
|
||||
Rec -->|tokens + cost + groups + user| Store
|
||||
|
||||
Store -->|read counters| Cons
|
||||
Cons --> Panel
|
||||
```
|
||||
|
||||
**Notes on the diagram**
|
||||
|
||||
- **min-wins all-must-pass** is the core semantic. A budget rule binds
|
||||
to (group set, user set) with a (window, ceiling). At check time,
|
||||
every rule that matches the caller is evaluated; if ANY rule has
|
||||
zero remaining quota the request is denied. This is the most
|
||||
surprising semantic for operators — see the invariants section of
|
||||
[`modules/21-management-agentnetwork.md`](modules/21-management-agentnetwork.md).
|
||||
- The proxy never makes its own budget decisions. It always asks
|
||||
management via `CheckLLMPolicyLimits` and reports back via
|
||||
`RecordLLMUsage`. This keeps account-wide accounting in one place
|
||||
and avoids per-proxy drift.
|
||||
- `RecordLLMUsage` must carry `group_ids` and `user_id` so the
|
||||
decrement hits the right rule(s). The wire that carries those
|
||||
fields onto the response leg is `respInput` in `reverseproxy.go`. See
|
||||
[`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md).
|
||||
- The dashboard's Budget Dashboard tab polls
|
||||
`/api/agent-network/consumption` — not gRPC, not WebSocket. Poll
|
||||
interval lives in `AgentConsumptionPanel.tsx`. See
|
||||
[`modules/40-dashboard.md`](modules/40-dashboard.md).
|
||||
|
||||
---
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Per-module guides: [`modules/`](modules/)
|
||||
- Overview + module map: [`00-overview.md`](00-overview.md)
|
||||
66
docs/agent-networks/README.md
Normal file
66
docs/agent-networks/README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# Agent Networks — architecture documentation
|
||||
|
||||
A self-contained set of documents describing the agent-networks feature:
|
||||
an LLM-aware reverse-proxy middleware system plus account-level controls
|
||||
(budget rules, log collection toggles, PII redaction). The management
|
||||
server synthesises a per-peer middleware chain that the proxy executes on
|
||||
every LLM request.
|
||||
|
||||
## What to read first
|
||||
|
||||
1. **[00-overview.md](00-overview.md)** — the single entry point. Feature
|
||||
scope, the module map, and the cross-cutting topics worth keeping in
|
||||
mind, with links to every per-module guide.
|
||||
2. **[01-end-to-end-flows.md](01-end-to-end-flows.md)** — three
|
||||
high-level mermaid diagrams: config-to-runtime synth/delivery,
|
||||
per-request lifecycle through the LLM chain, and the budget-rule
|
||||
feedback loop.
|
||||
3. **Per-module guides** under `modules/` — one file per package. Each
|
||||
describes the module boundary, the file-level layout, its own flow
|
||||
diagrams, the public contracts, the invariants it relies on, and the
|
||||
areas worth the closest attention.
|
||||
|
||||
## Directory layout
|
||||
|
||||
```
|
||||
docs/agent-networks/
|
||||
├── README.md # you are here
|
||||
├── 00-overview.md # feature summary + module map
|
||||
├── 01-end-to-end-flows.md # cross-module mermaid diagrams
|
||||
└── modules/
|
||||
├── 10-shared-api.md # proto + OpenAPI wire contracts
|
||||
├── 20-management-store.md # SQL persistence layer
|
||||
├── 21-management-agentnetwork.md # domain layer + synthesizer (largest)
|
||||
├── 22-management-handlers-wiring.md # HTTP API + gRPC delivery
|
||||
├── 30-proxy-middleware-framework.md # generic plugin system
|
||||
├── 31-proxy-middleware-builtin.md # 8 LLM-aware middlewares
|
||||
├── 32-proxy-llm-parsers.md # OpenAI/Anthropic/Bedrock SDKs + pricing
|
||||
├── 33-proxy-runtime.md # translate + serve + access-log
|
||||
├── 40-dashboard.md # UI for everything above (lives in the dashboard repo)
|
||||
└── 50-path-routed-providers.md # Vertex AI + Bedrock (path-routed, keyfile:: creds, /bedrock prefix)
|
||||
```
|
||||
|
||||
The `40-dashboard.md` module documents code that lives in the **dashboard
|
||||
repo**, not in this repo. The guide is co-located here so backend readers
|
||||
see the full picture in one place.
|
||||
|
||||
## How the per-module guides are structured
|
||||
|
||||
Every `modules/*.md` follows the same template so the docs are easy to
|
||||
scan:
|
||||
|
||||
- **Module boundary** — what this package owns; where it sits in the stack.
|
||||
- **Files** — path / role.
|
||||
- **Architecture & flow** — one or more mermaid diagrams.
|
||||
- **Public contracts** — function signatures, gRPC messages, JSON shapes.
|
||||
- **Invariants** — semantic guarantees the module relies on or enforces.
|
||||
- **Things to scrutinize** — split by correctness / security /
|
||||
concurrency / backward-compat / performance / observability.
|
||||
- **Test coverage** — the test files that lock down behaviour in this
|
||||
module.
|
||||
- **Known limitations / non-goals** — what is intentionally out of scope.
|
||||
- **Cross-references** — upstream/downstream module links + the
|
||||
end-to-end flow + the overview.
|
||||
|
||||
See [00-overview.md](00-overview.md) for the module map and the
|
||||
cross-cutting topics.
|
||||
105
docs/agent-networks/modules/10-shared-api.md
Normal file
105
docs/agent-networks/modules/10-shared-api.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# shared/api — wire contracts (proto + OpenAPI)
|
||||
|
||||
> **Risk level:** Medium — wire-format surface that every other module pins against; backward-compat hinges on field-number discipline more than on logic correctness.
|
||||
> **Backward-compat impact:** Additive only (new proto fields use unallocated numbers, new RPCs default to `Unimplemented`, new OpenAPI schemas/paths are append-only; no existing field/RPC/schema removed or renumbered).
|
||||
|
||||
## Module boundary
|
||||
This module owns the cross-process contract surface between management, proxy, and dashboard. Two artefacts: `shared/management/proto/proxy_service.proto` (management↔proxy gRPC) and `shared/management/http/api/openapi.yml` (dashboard/CLI↔management REST). Both have generated companions checked in (`proxy_service.pb.go`, `proxy_service_grpc.pb.go`, `types.gen.go`) which must travel in lockstep with their sources. `shared/management/status/error.go` is in scope only for the four new typed `NotFound` constructors that the new HTTP handlers return.
|
||||
|
||||
Everything downstream — `management/agentnetwork`, `management/server/http/handlers/*`, `proxy/internal/*`, the dashboard SDK — consumes these types verbatim. The concern here is wire stability and codegen reproducibility, not behaviour: behaviour is covered in the management and proxy module guides.
|
||||
|
||||
`management.proto` and `signalexchange.proto` are unchanged. `status/error.go` only receives four additive constructors (lines 208-227); no existing error types are reshaped.
|
||||
|
||||
## Files
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `shared/management/proto/proxy_service.proto` | Source of truth: 2 new RPCs, 1 new message group (`MiddlewareConfig` + slot enum), additive fields on `PathTargetOptions`, `AccessLog`, `RecordLLMUsageRequest` |
|
||||
| `shared/management/proto/proxy_service.pb.go` | Generated (protoc-gen-go) |
|
||||
| `shared/management/proto/proxy_service_grpc.pb.go` | Generated; adds `CheckLLMPolicyLimits` + `RecordLLMUsage` client/server stubs and `UnimplementedProxyServiceServer` defaults |
|
||||
| `shared/management/http/api/openapi.yml` | 15 new `AgentNetwork*` schemas, 9 new path groups under `/api/agent-network/*` |
|
||||
| `shared/management/http/api/types.gen.go` | Generated (oapi-codegen; see codegen note below) |
|
||||
| `shared/management/status/error.go` | Four `NotFound` constructors for the new resource kinds (lines 208-227) |
|
||||
|
||||
## Architecture & flow
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant Dash as Dashboard / CLI
|
||||
participant Mgmt as management (HTTP+gRPC)
|
||||
participant Px as proxy
|
||||
|
||||
Note over Dash,Mgmt: REST (OpenAPI / types.gen.go)
|
||||
Dash->>Mgmt: PUT /api/agent-network/providers (AgentNetworkProviderRequest)
|
||||
Dash->>Mgmt: PUT /api/agent-network/settings (AgentNetworkSettingsRequest)
|
||||
Dash->>Mgmt: GET /api/agent-network/consumption -> [AgentNetworkConsumption]
|
||||
|
||||
Note over Mgmt,Px: gRPC ProxyService (proxy_service.proto)
|
||||
Mgmt-->>Px: SyncMappingsResponse{ ProxyMapping.path[*].options.middlewares,<br/>agent_network, disable_access_log, capture_* }
|
||||
Px->>Mgmt: CheckLLMPolicyLimits(account, user, groups, provider, model)
|
||||
Mgmt-->>Px: decision=allow|deny + selected_policy_id + attribution_group_id + window_seconds
|
||||
Px->>Mgmt: RecordLLMUsage(account, user, group_id, group_ids, window_seconds, tokens, cost)
|
||||
Px->>Mgmt: SendAccessLog(AccessLog{ agent_network=true })
|
||||
```
|
||||
|
||||
The proto changes split into three independent slices: (1) **mapping enrichment** — `PathTargetOptions` grows fields 8-13 so management can ship middleware configs, capture limits, and the agent-network / log-suppression flags down to the proxy without a second RPC; (2) **two new request/response RPCs** (`CheckLLMPolicyLimits`, `RecordLLMUsage`) for per-LLM-request budget arbitration; (3) **observability tag** — `AccessLog.agent_network` so management can route logs to the right surface.
|
||||
|
||||
The OpenAPI side is a thin CRUD surface — every resource (`Provider`, `Policy`, `Guardrail`, `BudgetRule`, `Settings`) follows the same `GET-list / POST / GET / PUT / DELETE` pattern, plus a read-only `/consumption` listing and a catalog endpoint. The `*Request` variants drop server-controlled fields (id, timestamps). `AgentNetworkBudgetRule` deliberately reuses `AgentNetworkPolicyLimits` to keep wire-shape parity with policies.
|
||||
|
||||
## Public contracts added
|
||||
- gRPC RPCs (`proxy_service.proto:52-57`): `CheckLLMPolicyLimits(CheckLLMPolicyLimitsRequest) → CheckLLMPolicyLimitsResponse`, `RecordLLMUsage(RecordLLMUsageRequest) → RecordLLMUsageResponse`. Both unary; default `UnimplementedProxyServiceServer` returns `codes.Unimplemented` (`proxy_service_grpc.pb.go:283-289`).
|
||||
- New messages (`proxy_service.proto:145-175,448-502`): `MiddlewareConfig`, `MiddlewareSlot` enum, `CheckLLMPolicyLimitsRequest`/`Response`, `RecordLLMUsageRequest`/`Response`.
|
||||
- New `PathTargetOptions` fields 8-13 (`proxy_service.proto:124-140`): `capture_max_request_bytes`, `capture_max_response_bytes`, `capture_content_types`, `middlewares`, `agent_network`, `disable_access_log`. All default-false / zero; pre-existing fields 1-7 byte-for-byte unchanged.
|
||||
- `AccessLog.agent_network = 18` (`proxy_service.proto:258-261`).
|
||||
- `RecordLLMUsageRequest.group_ids = 8` (`proxy_service.proto:496-498`) — so the record path can fan out to every applicable budget rule's window without a re-lookup.
|
||||
- 15 new OpenAPI component schemas (`openapi.yml:5072-5829`): `AgentNetworkProvider[Request|Model]`, `AgentNetworkCatalog{Model,Provider,IdentityInjection,HeaderPairInjection,JSONMetadataInjection,ExtraHeader}`, `AgentNetworkPolicy[Request|TokenLimit|BudgetLimit|Limits]`, `AgentNetworkGuardrail[Checks|Request]`, `AgentNetworkConsumption`, `AgentNetworkSettings[Request]`, `AgentNetworkBudgetRule[Request]`.
|
||||
- 9 new path groups (`openapi.yml:12797-13460`): `/api/agent-network/{consumption,settings,budget-rules,budget-rules/{ruleId},catalog/providers,providers,providers/{providerId},policies,policies/{policyId},guardrails,guardrails/{guardrailId}}`.
|
||||
- Four typed NotFound errors (`shared/management/status/error.go:208-227`).
|
||||
|
||||
## Invariants
|
||||
- **Field-number monotonicity.** Every new proto field uses a previously-unallocated number in its message: `PathTargetOptions` 8-13 (was 1-7), `AccessLog` 18 (was 1-17), `RecordLLMUsageRequest` 8. `SendStatusUpdateRequest.inbound_listener = 50` (pre-existing) reserves 50+ for observability extensions, so 8 on `RecordLLMUsageRequest` doesn't conflict.
|
||||
- **Old proxies stay compatible.** Old management never sends `disable_access_log`/`middlewares`/`agent_network` (zero value → existing behaviour); old proxies that don't decode these fields just drop them silently (proto3 unknown-field semantics) — log emission stays on. No pre-existing field number changed: the proto change is insertions only.
|
||||
- **Old management stays compatible.** The two new RPCs are registered on the same `management.ProxyService` descriptor; old proxies hitting them get `codes.Unimplemented` from the unimplemented embed (`proxy_service_grpc.pb.go:283-289`), which is the same fallback pattern `SyncMappings` already documents (`proxy_service.proto:20-21`).
|
||||
- **OpenAPI shapes are append-only.** New schemas are placed at the end of `components.schemas` (line 5072+); new paths at the end of `paths` (line 12797+). No existing schema's `required` list, enum, or property type was changed.
|
||||
- **`*Request` vs response asymmetry.** Read shapes (`AgentNetworkProvider`, `AgentNetworkPolicy`, `AgentNetworkGuardrail`, `AgentNetworkSettings`, `AgentNetworkBudgetRule`) require `created_at`/`updated_at`; the matching `*Request` shapes do not — server fills them. `AgentNetworkProviderRequest.api_key` is write-only (`openapi.yml:5158-5161` "never returned in responses"); reviewers should confirm the response schema (5072-5138) actually omits `api_key`.
|
||||
|
||||
## Things to scrutinize
|
||||
### Correctness
|
||||
- `RecordLLMUsageRequest` carries both `group_id` (singular, the attribution group — field 3) and `group_ids` (plural, full membership — field 8). `b22d5a181` adds field 8 to drive account-budget fan-out; double-check that consumers can't accidentally key counters on the wrong one. Field comments at `proxy_service.proto:489-491` and `496-498` distinguish them but it's the kind of subtle thing a follow-up commit might collapse.
|
||||
- `PathTargetOptions.disable_access_log` is the only field whose default-false meaning **changes semantics** on the proxy side: false → log (status quo), true → suppress. Synthesizer sets `DisableAccessLog = !settings.EnableLogCollection`, so a missing/default settings row yields `EnableLogCollection=false → DisableAccessLog=true → suppressed`. Worth confirming downstream (`agentnetwork.synthesizer`) that operator-defined private services never inherit this flag — the proto field default protects them, but only if synth code is explicit.
|
||||
- `CheckLLMPolicyLimitsResponse.decision` is a free-form `string` (`proxy_service.proto:471`) rather than an enum. Only documented values are "allow" / "deny". An enum would prevent typo drift; consider before this RPC ships to external consumers.
|
||||
- `deny_code` (`proxy_service.proto:478-481`) is documented as "a stable label" but is also a free string. Pin the allowed set somewhere observable to the proxy.
|
||||
|
||||
### Security
|
||||
- `AgentNetworkProvider.api_key` MUST be write-only. Schema split (request has it at line 5158; response omits it) looks correct, but a regression here leaks the upstream provider credential to every dashboard reader. Check that the handler explicitly zeros it on the response path.
|
||||
- `extra_values` / `identity_header_*` headers on `AgentNetworkProvider` get stamped onto upstream requests. Description at `openapi.yml:5099` says "values not declared by the catalog are ignored at synth time" — a contract this module documents but the synthesizer must enforce. Confirm the synth module honours it.
|
||||
- Cluster + subdomain on `AgentNetworkSettings` are documented immutable (`openapi.yml:5686-5694`) and the `AgentNetworkSettingsRequest` (lines 5733-5752) doesn't accept them. Verify the `PUT /api/agent-network/settings` handler can't be tricked by extra JSON keys (oapi-codegen's `additionalProperties: false` is not declared here; spec defaults to permissive).
|
||||
|
||||
### Backward compatibility
|
||||
- The proto change is field-number additive: every previously numbered field keeps the same name + type, and the change is insertions only (no deletions in `proxy_service.proto`), so this holds at the source-text level.
|
||||
- `proxy_service_grpc.pb.go` adds two RPC handlers and registers them in `ProxyService_ServiceDesc.Methods` (lines 543-552). The existing entries are unchanged and order-preserving — gRPC method dispatch is name-keyed, so order doesn't matter, but reviewing the diff (no method renamed/dropped) is still worth a glance.
|
||||
- OpenAPI 3.0 doesn't have a built-in deprecation flow for paths; if any client tooling iterates `paths.*`, the additive routes shouldn't break it, but generated SDKs (especially the dashboard's) need a regen to gain access to `AgentNetwork*`.
|
||||
|
||||
### Codegen pinning
|
||||
- `generate.sh` (`shared/management/http/api/generate.sh:14`) installs `oapi-codegen@latest` rather than a pinned version. **This is a reproducibility gap** — re-running the script later may produce a different `types.gen.go`. Either pin the version in `generate.sh` (e.g. `@v2.7.0`) or document the pin in a `tools.go`.
|
||||
- proto codegen has the protoc / protoc-gen-go version stamped in the generated file header (`proxy_service.pb.go:3-4`).
|
||||
- Regenerate locally and confirm zero diff against the committed `types.gen.go` / `proxy_service.pb.go`.
|
||||
|
||||
## Test coverage
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| None in this scope | The proto and OpenAPI sources are tested transitively by the handler tests (`shared/management/http/handlers/agentnetwork/...`) and by the synthesizer/manager tests (`management/server/agentnetwork/...`). No round-trip serialisation test exists in the `proto/` or `api/` packages themselves. |
|
||||
| `shared/management/proto/*_test.go` | (absent) |
|
||||
| `shared/management/http/api/*_test.go` | (absent) |
|
||||
|
||||
Acceptable for codegen artefacts, but a single golden-file test that re-runs `oapi-codegen` and `protoc` in CI and diffs against the checked-in files would close the reproducibility gap noted above.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
- **No deprecation surface.** Old fields/RPCs are kept silently; there is no `[deprecated = true]` annotation on anything. Acceptable here because nothing is being removed.
|
||||
- **No proto-side validation.** Numeric ranges (e.g. `window_seconds >= 60`, `cost_usd >= 0`, capture-byte clamps) are enforced in the OpenAPI schema via `minimum:` and inside Go code by the proxy/management, but `proto3` itself can't express them; downstream is expected to validate every message.
|
||||
- **`MiddlewareConfig.config_json` is `bytes`** (`proxy_service.proto:163`) — opaque to the proto layer. Schema validity is the middleware factory's problem. This is a deliberate tradeoff (per the comment at 161-162) but worth flagging: a corrupted/malicious config_json can only fail at proxy apply time, not at the wire-decode step.
|
||||
- **No catalog endpoint schema for the catalog itself** — the catalog data ships as a `GET /api/agent-network/catalog/providers` returning `[AgentNetworkCatalogProvider]` (`openapi.yml:13024`), but the catalog source-of-truth lives in `management/server/agentnetwork/catalog`, not here.
|
||||
- The reaper / GC design was cut from scope; no reaper-related types appear here.
|
||||
|
||||
## Cross-references
|
||||
- Downstream: [management/store](20-management-store.md), [management/agentnetwork](21-management-agentnetwork.md), [management/handlers + wiring](22-management-handlers-wiring.md), [proxy/runtime](33-proxy-runtime.md)
|
||||
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level: [../00-overview.md](../00-overview.md)
|
||||
112
docs/agent-networks/modules/20-management-store.md
Normal file
112
docs/agent-networks/modules/20-management-store.md
Normal file
@@ -0,0 +1,112 @@
|
||||
# management/store — persistence for agent-network entities
|
||||
|
||||
> **Risk level:** Medium — six brand-new tables behind AutoMigrate, one upsert-counter table that runs on the request hot path, and one column carrying an encrypted secret.
|
||||
> **Backward-compat impact:** Additive (six new tables created by AutoMigrate; the `Store` interface gains 23 methods, but no existing column/index is touched).
|
||||
|
||||
## Module boundary
|
||||
|
||||
This module is the persistence layer for the Agent Network feature. Everything the management server stores about LLM proxying — providers, policies, guardrails, the per-account settings row, a usage-counter table written on every proxied LLM request, and the account-budget rules — flows through the methods added to `store.Store`. The module owns six tables, six entity types from `management/server/agentnetwork/types`, and a single hot-path upsert (`IncrementAgentNetworkConsumption`) consumed by the proxy fleet.
|
||||
|
||||
Out of scope here: the catalog of provider definitions (compiled-in, no DB), the synthesizer/manager built on top of these CRUDs (covered in [21-management-agentnetwork.md](21-management-agentnetwork.md)), and the HTTP handlers that translate API requests into Save/Delete calls.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `management/server/store/sql_store_agentnetwork.go` | gorm implementations of all 23 store methods |
|
||||
| `management/server/store/sql_store_agentnetwork_budgetrule_test.go` | round-trip + account-scoping coverage against a real sqlite store |
|
||||
| `management/server/store/sql_store.go` | one import, six entities appended to the `AutoMigrate` slice (sql_store.go:40, sql_store.go:141-142) |
|
||||
| `management/server/store/store.go` | 23 methods added to the `Store` interface (store.go:328-354) |
|
||||
| `management/server/store/store_mock_agentnetwork.go` | mockgen output for the new interface surface |
|
||||
|
||||
## Tables added / migrations
|
||||
|
||||
All six tables are created by `db.AutoMigrate` invoked from `NewSqlStore` at sql_store.go:133-143. There is no hand-rolled SQL migration script — the schema is whatever GORM derives from the struct tags.
|
||||
|
||||
- `agent_network_providers` — `Provider.TableName()` at provider.go:76. PK `id`, index on `account_id`, named index `idx_agent_network_provider` on `provider_id`. Carries an at-rest-encrypted `api_key` and ed25519 `session_private_key` (provider.go:35,56). `extra_values` and `models` are JSON blobs (`serializer:json`).
|
||||
- `agent_network_policies` — `Policy.TableName()` at policy.go:70. PK `id`, index on `account_id`. JSON columns: `source_groups`, `destination_provider_ids`, `guardrail_ids`, `limits`.
|
||||
- `agent_network_guardrails` — `Guardrail.TableName()` at guardrail.go:41. PK `id`, index on `account_id`. JSON `checks`.
|
||||
- `agent_network_settings` — `Settings.TableName()` at settings.go:33. PK `account_id` (one row per account), named index `idx_agent_network_settings_cluster_subdomain` on `subdomain` only — the index name implies a composite, but only one column is tagged.
|
||||
- `agent_network_consumption` — `Consumption.TableName()` at consumption.go:46. Composite PK across `(account_id, dim_kind, dim_id, window_seconds, window_start_utc)` — the same tuple the upsert keys on.
|
||||
- `agent_network_budget_rules` — `AccountBudgetRule.TableName()` at budgetrule.go:35. PK `id`, index on `account_id`. JSON `target_groups`, `target_users`, `limits`.
|
||||
|
||||
## CRUD surface added
|
||||
|
||||
Provider, Policy, Guardrail, BudgetRule follow the same pattern: `Get<Kind>ByID`, `GetAccount<Kind>` (list), `Save<Kind>` (upsert), `Delete<Kind>`, with account-scoping enforced by the existing `accountAndIDQueryCondition` / `accountIDCondition` constants (sql_store.go:59-62). Provider additionally exposes `GetAllAgentNetworkProviders` (cross-account, used by the synthesizer). Settings exposes `Get`/`GetByCluster`/`Save` (no delete — one row per account, created on first save). Consumption exposes the upsert `Increment`, a point `Get`, and a cross-window `List`.
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
handlers["HTTP handlers<br/>(management/server/agentnetwork)"] -->|Save/Delete| iface["Store interface<br/>store.go:328-354"]
|
||||
manager["agentnetwork.Manager"] -->|Get*| iface
|
||||
synth["synthesizer<br/>(global)"] -->|GetAllAgentNetworkProviders| iface
|
||||
proxy["proxy fleet<br/>(hot path)"] -->|IncrementAgentNetworkConsumption| iface
|
||||
iface --> sql["SqlStore methods<br/>sql_store_agentnetwork.go"]
|
||||
iface -.gomock.-> mock["MockStore<br/>store_mock_agentnetwork.go"]
|
||||
sql --> gorm["gorm.DB"]
|
||||
gorm --> tables[("6 tables<br/>agent_network_*")]
|
||||
sql --> enc["crypt.FieldEncrypt<br/>(provider only)"]
|
||||
```
|
||||
|
||||
Reads decrypt provider secrets in-place; writes do `provider.Copy().EncryptSensitiveData(...)` before `db.Save` so the caller's in-memory object keeps the plaintext `api_key` (sql_store_agentnetwork.go:88-102). Every list/get takes a `LockingStrength` and applies `clause.Locking{Strength: ...}` when non-`None` — matching the rest of the store. The upsert path uses `clause.OnConflict` with `gorm.Expr` server-side increments so concurrent proxy nodes converge without read-modify-write races (sql_store_agentnetwork.go:321-335).
|
||||
|
||||
## Invariants enforced at the store layer
|
||||
|
||||
- **Account scoping.** Every entity-by-ID method keys on `account_id = ? and id = ?`; no cross-tenant leak path through the API is reachable as long as callers always pass the auth'd `accountID` (sql_store_agentnetwork.go:70,141,201,429).
|
||||
- **NotFound mapping.** `gorm.ErrRecordNotFound` is translated to typed `status.NewAgentNetwork*NotFoundError`; `Delete*` returns NotFound when `RowsAffected == 0` (sql_store_agentnetwork.go:111-113,171-173,231-233,461-463).
|
||||
- **Provider secret encryption at rest.** `SaveAgentNetworkProvider` always encrypts before persist; `Get*` always decrypts after read. The plaintext `api_key` never reaches the DB through this layer (sql_store_agentnetwork.go:31,54,80,90).
|
||||
- **Consumption monotonicity.** The upsert only ever issues `col = col + ?` for the three counter columns — no decrement path exists (sql_store_agentnetwork.go:330-332).
|
||||
- **Window alignment is the caller's responsibility.** The store stamps `WindowStartUTC` as-passed; alignment to epoch happens in `types.WindowStart` at consumption.go:51-58.
|
||||
- **Settings has no Delete.** Intentional — one row per account, created on first save; the row sticks around for the account lifetime.
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
- `SaveAgentNetworkProvider` saves the copy (sql_store_agentnetwork.go:95). The caller's in-memory pointer therefore keeps plaintext `api_key` and any `CreatedAt`/`UpdatedAt` gorm autofills land on the copy, not the original. Callers that need synced timestamps must re-fetch.
|
||||
- `IncrementAgentNetworkConsumption`'s `Create` provides initial counter values (`TokensInput: tokensIn`, etc.) in the row, and on conflict the assignments add the same deltas to the existing values. The insert-vs-update arithmetic is consistent. Cross-check that no engine in use (sqlite, postgres, mysql) silently rejects the `OnConflict` clause — GORM emits engine-specific SQL but `ON DUPLICATE KEY UPDATE` (mysql) vs `ON CONFLICT (...)` (sqlite/postgres) need their unique constraint to match the composite PK on `agent_network_consumption`; it does, by construction.
|
||||
- `IncrementAgentNetworkConsumption` writes `updated_at: time.Now().UTC()` literally inside the assignments map (sql_store_agentnetwork.go:333) — fine, but it's a Go-side timestamp captured at call time, not a DB-side `now()`. Acceptable for an audit field.
|
||||
- `GetAgentNetworkConsumption` returns a zero-valued non-nil row on `ErrRecordNotFound` (sql_store_agentnetwork.go:364-371). Document or rename — a typed sentinel error would be more orthodox; callers must know not to error-check.
|
||||
|
||||
### Concurrency / transactions
|
||||
- Hot-path `IncrementAgentNetworkConsumption` runs outside any explicit transaction; concurrency safety relies entirely on the DB serialising the `ON CONFLICT` upsert against the composite PK. This is correct for postgres and mysql; for sqlite it serialises behind the single writer.
|
||||
- `SaveAgentNetworkSettings` is a blind upsert with no version/etag — concurrent writes from two operators last-write-wins on the collection-toggle flags (settings.go:23-25). Acceptable for admin-curated state but worth flagging.
|
||||
- `Save*Provider` uses `db.Save` on a struct with a PK already set — GORM emits UPDATE or INSERT based on row existence. No upsert clause is attached, so a race between two creates with the same generated `xid` (vanishingly unlikely) would surface as a PK violation.
|
||||
|
||||
### Migration safety
|
||||
- All six tables ride `AutoMigrate` (sql_store.go:141-142). AutoMigrate is additive: new columns get added, but it never drops columns nor narrows types. Three `bool` columns on `agent_network_settings` (`EnableLogCollection`, `EnablePromptCollection`, `RedactPii`) default to false at the GORM/DDL layer for existing rows; the test at sql_store_agentnetwork_budgetrule_test.go:83-112 locks that down on a fresh sqlite. Verify postgres/mysql produce the same default.
|
||||
- The named index `idx_agent_network_settings_cluster_subdomain` on settings.go:15 is declared on only `subdomain`. Either the cluster column also needs `gorm:"index:idx_agent_network_settings_cluster_subdomain"` to make it composite, or the name is misleading.
|
||||
- The named index `idx_agent_network_provider` on `Provider.ProviderID` (provider.go:30) is *not* unique and not scoped to account — two providers in the same account with the same `provider_id` are permitted at the DB layer; uniqueness, if any, must live above the store.
|
||||
|
||||
### Backward compatibility
|
||||
- Net additive. No removed methods, no renamed columns, no schema change to existing tables. Existing deployments running a prior binary continue to work; the first boot of the new binary creates the six tables.
|
||||
- The `Store` interface grows by 23 methods (store.go:330-354); any non-mock external implementer of `store.Store` will fail to compile. The repo only has `SqlStore` + `MockStore`, both updated.
|
||||
|
||||
### Performance (indexes, N+1)
|
||||
- All by-account list queries hit the `idx_account_id` per-table index. No N+1: list methods return the full slice in one query.
|
||||
- `GetAgentNetworkSettingsByCluster` (sql_store_agentnetwork.go:263-277) does a tablescan on `cluster` — no index. Tolerable for the bootstrap label generator (one-shot at provisioning) but worth noting if the call moves onto a hot path.
|
||||
- `ListAgentNetworkConsumption` returns every row ever recorded for the account (sql_store_agentnetwork.go:382-400) — unbounded growth, no `LIMIT`, no time filter. With one row per (dim, window) per request burst, this table grows fastest of the six; a retention job + a paginated list method are obvious follow-ups.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkBudgetRule_RealStore_RoundTrip` | full save → reload of `AccountBudgetRule` including the JSON-serialised `PolicyLimits`, target slices, double-delete returns NotFound (lines 18-59) |
|
||||
| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkBudgetRule_RealStore_ScopedByAccount` | cross-account isolation for budget rules (lines 63-78) |
|
||||
| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkSettings_RealStore_CollectionTogglesRoundTrip` | collection toggles default off, survive save/reload at the set values (lines 83-112) |
|
||||
|
||||
Gap: there is no store-level test for providers (encryption round-trip), policies, guardrails, or `IncrementAgentNetworkConsumption` (concurrent upsert, window-key uniqueness). The consumption upsert is the most performance-sensitive method in this module and the only one without a real-sqlite test.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- No retention / GC for `agent_network_consumption`.
|
||||
- No `Delete` for `Settings` (one row per account, cleared with the account).
|
||||
- No DB-engine-specific tuning — the same struct tags drive sqlite, mysql, postgres.
|
||||
- Provider `extra_values` and `models` are JSON blobs; querying inside them is not supported by design.
|
||||
- `GetAgentNetworkConsumption` "not-found = zero row" contract is convenient but unconventional.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Upstream: [shared/api](10-shared-api.md), [management/agentnetwork](21-management-agentnetwork.md)
|
||||
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level: [../00-overview.md](../00-overview.md)
|
||||
225
docs/agent-networks/modules/21-management-agentnetwork.md
Normal file
225
docs/agent-networks/modules/21-management-agentnetwork.md
Normal file
@@ -0,0 +1,225 @@
|
||||
# management/agentnetwork — domain layer + synth pipeline
|
||||
|
||||
> **Risk level:** High — central business logic + budget enforcement + the source of every middleware-chain change the proxy executes.
|
||||
> **Backward-compat impact:** Additive within the agent-network surface; one **behavioural difference for opted-out accounts** in parser capture (the capture flag is stamped explicitly false instead of being absent — see capture-pointer semantics below). Non-agent-network proxy services are untouched (the synth chain only ships on `agent-net-svc-*` targets).
|
||||
|
||||
## Module boundary
|
||||
|
||||
`management/server/agentnetwork` owns every agent-network entity (providers, policies, guardrails, account budget rules, per-account settings, consumption rows) and **translates them into the in-memory `*rpservice.Service` that the reverse-proxy controller turns into `proto.ProxyMapping`s and pushes to clusters**. It is the *only* writer of the agent-network middleware chain.
|
||||
|
||||
Inside the package: `manager.go` is the CRUD + permissions-gated facade; `synthesizer.go` walks settings + providers + policies + guardrails and emits the per-account service plus every middleware's JSON config; `policyselect.go` runs per-request attribution (min-wins account ceiling, then "drain bigger pool first"); `reconcile.go` diffs successive synth outputs and emits precise Create/Update/Delete proxy-mapping updates plus a peer-map refresh. `labelgen/` mints DNS-safe subdomain labels; `catalog/` is the static provider catalogue; `types/` carries gorm entity structs. The `_realstack_test.go` files in the parent `management/server/` directory exercise the manager + network-map controller end-to-end with no mocks.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `agentnetwork/manager.go` | Manager interface + CRUD + permission gates + bootstrap-settings + reconcile trigger |
|
||||
| `agentnetwork/synthesizer.go` | Settings/policy → wire-format synthesis; sole writer of the proxy middleware chain |
|
||||
| `agentnetwork/policyselect.go` | Per-request policy attribution + account-budget ceiling (min-wins) |
|
||||
| `agentnetwork/reconcile.go` | Per-account synth diff vs in-memory cache → Create/Update/Delete |
|
||||
| `agentnetwork/catalog/catalog.go` | Static provider catalogue (auth headers, identity-injection shapes) |
|
||||
| `agentnetwork/labelgen/{labelgen,words}.go` | DNS-safe subdomain picker + curated wordlist |
|
||||
| `agentnetwork/types/provider.go` | Provider entity + APIKey + Models + ExtraValues + SessionKeys |
|
||||
| `agentnetwork/types/policy.go` | Policy entity + `PolicyLimits` (token + budget) |
|
||||
| `agentnetwork/types/guardrail.go` | Guardrail entity (`ModelAllowlist`, `PromptCapture`) |
|
||||
| `agentnetwork/types/budgetrule.go` | `AccountBudgetRule` (reuses `PolicyLimits`) |
|
||||
| `agentnetwork/types/settings.go` | Per-account `Settings` (Cluster, Subdomain, 3 toggles) |
|
||||
| `agentnetwork/types/consumption.go` | `Consumption` row + `WindowStart` aligner |
|
||||
| `agentnetwork/{synthesizer,policyselect,reconcile,wire_shape}_*test.go` | See test coverage table |
|
||||
| `agentnetwork/types/consumption_test.go` | `WindowStart` alignment proofs |
|
||||
| `agentnetwork/labelgen/labelgen_test.go` | Deterministic picks + exhaustion + fallback |
|
||||
| `management/server/agentnetwork_realstack_test.go` | No-mock provider CRUD → network-map fan-out |
|
||||
| `management/server/agentnetwork_budgetrule_realstack_test.go` | No-mock budget-rule CRUD + settings preserve-immutable |
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### Synthesis (settings/policy → wire format)
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Mutation: provider/policy/guardrail/settings] --> B[managerImpl.reconcile accountID]
|
||||
B --> C{proxyController nil?}
|
||||
C -- yes --> D[accountManager.UpdateAccountPeers only]
|
||||
C -- no --> E[SynthesizeServices]
|
||||
E --> F[loadSettings — NotFound returns ok=false, no synth]
|
||||
F --> G[filterEnabledProviders sorted by CreatedAt]
|
||||
G --> H[filterEnabledPolicies]
|
||||
H --> I[backfillProviderSessionKeys if missing]
|
||||
I --> J[indexProviderGroups: providerID -> sorted source groups]
|
||||
J --> K[buildRouterConfigJSON drops orphan providers]
|
||||
J --> L[buildIdentityInjectConfigJSON per catalog entry]
|
||||
H --> M[mergeGuardrails: union allowlist, OR redact]
|
||||
M --> N[applyAccountCollectionControls account toggle = SOLE capture control]
|
||||
N --> O[marshalGuardrailConfig]
|
||||
K --> P[buildMiddlewareChain 8 middleware entries]
|
||||
L --> P
|
||||
O --> P
|
||||
P --> Q[buildAccountService: AccessGroups=union source groups, noop.invalid target]
|
||||
Q --> R[reconcile.diffMappings vs cache]
|
||||
R --> S[SendServiceUpdateToCluster CREATE/MODIFY/REMOVE]
|
||||
R --> T[accountManager.UpdateAccountPeers — fans synth ACLs into network map]
|
||||
```
|
||||
|
||||
### Budget rule resolution (min-wins, group+user bound)
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[SelectPolicyForRequest in] --> B[checkAccountBudget — runs FIRST, independent of policies]
|
||||
B --> C[GetAccountAgentNetworkBudgetRules]
|
||||
C --> D{for each enabled rule}
|
||||
D --> E{budgetRuleApplies?}
|
||||
E -- no --> D
|
||||
E -- yes --> F[attrGroup = lowestIntersect TargetGroups, in.GroupIDs]
|
||||
F --> G{Token cap enabled?}
|
||||
G -- yes --> H[evalTokenCap user dim + group dim]
|
||||
H --> I{exhausted?}
|
||||
I -- yes --> J[DENY: llm_account.token_cap_exceeded - STOP]
|
||||
I -- no --> K{Budget cap enabled?}
|
||||
G -- no --> K
|
||||
K -- yes --> L[evalBudgetCap user dim + group dim]
|
||||
L --> M{exhausted?}
|
||||
M -- yes --> N[DENY: llm_account.budget_cap_exceeded - STOP]
|
||||
M -- no --> D
|
||||
K -- no --> D
|
||||
D --> O[All rules passed -> fall through to per-policy selection]
|
||||
```
|
||||
|
||||
Key invariant: **rules are checked sequentially and ANY exhausted rule denies (all-must-pass / min-wins).** Untargeted rules (`len(TargetGroups)==0 && len(TargetUsers)==0`) apply to every caller (`policyselect.go:393`).
|
||||
|
||||
### Policy selection (per-peer, per-request)
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Account-budget gate passed] --> B[GetAccountAgentNetworkPolicies]
|
||||
B --> C[filterApplicablePolicies enabled + provider match + group intersect]
|
||||
C --> D{candidates empty?}
|
||||
D -- yes --> E[Allow, empty SelectedPolicyID]
|
||||
D -- no --> F[scoreCandidates -> scoreOne per policy]
|
||||
F --> G[scoreOne: attrGroup + window]
|
||||
G --> H{any cap exhausted?}
|
||||
H -- yes --> I[Drop policy; record last deny code]
|
||||
H -- no --> K[Keep as live candidate]
|
||||
F --> L{live candidates exist?}
|
||||
L -- no --> M[Deny with last exhaustion code]
|
||||
L -- yes --> N[Sort: uncapped wins -> larger group token -> group budget -> user token -> user budget -> oldest CreatedAt]
|
||||
N --> O[winner = scored 0]
|
||||
O --> P[Allow + SelectedPolicyID + AttributionGroupID + WindowSeconds]
|
||||
```
|
||||
|
||||
End-to-end: a mutation calls `managerImpl.reconcile(ctx, accountID)` (`manager.go:205,239,...`). Reconcile defers an `accountManager.UpdateAccountPeers` so the network-map controller re-runs and `injectAllProxyPolicies` picks up the new access groups; with a `proxyController` wired, it re-synthesizes the service, diffs against `reconcileCache[accountID]` (guarded by `reconcileMu`), and emits proto mappings to the cluster derived from the mapping's domain (`reconcile.go:120`). Synthesis is stateless and idempotent. Sole persistent side effect: `backfillProviderSessionKeys` (`synthesizer.go:249`) mints ed25519 keys on legacy provider rows and writes them back.
|
||||
|
||||
At request time the path is independent: the proxy calls `SelectPolicyForRequest` (`policyselect.go:56`); account-budget ceiling first, then per-policy scoring. Token + budget caps share `evalTokenCap` / `evalBudgetCap` — same primitive for account rules and policy limits, `label` differentiates the deny reason. After a served request, `RecordAccountBudgetUsage` (`policyselect.go:415`) fans deltas to every applicable rule's distinct `(dim_kind, dim_id, window)` tuple, deduplicating to prevent double-count when two rules share target+window.
|
||||
|
||||
## Public contracts
|
||||
|
||||
- **Manager interface** (`manager.go:48-80`): CRUD for `Providers/Policies/Guardrails/BudgetRules`; `GetSettings/UpdateSettings` (cluster + subdomain immutable, only the three toggles mutate); `ListConsumption/RecordConsumption(account, kind, dimID, windowSec, in, out, USD)`; `RecordAccountBudgetUsage(account, user, groups, in, out, USD)`; `SelectPolicyForRequest(ctx, PolicySelectionInput) → *PolicySelectionResult{Allow, SelectedPolicyID, AttributionGroupID, WindowSeconds, DenyCode, DenyReason}`.
|
||||
- **`PolicySelectionInput`** (`manager.go:85-90`): `{AccountID, UserID, GroupIDs, ProviderID}` — populated by the proxy from CapturedData + `llm_router` resolution.
|
||||
- **Synthesized middleware chain** (`synthesizer.go:576-657`), order load-bearing — response slot runs reverse-of-slice:
|
||||
|
||||
| Slot | Idx | ID | ConfigJSON shape | CanMutate |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| on_request | 0 | `llm_request_parser` | `{"capture_prompt": <bool>, "redact_pii"?: true}` | – |
|
||||
| on_request | 1 | `llm_router` | `{"providers":[{id, models[], upstream_*, auth_header_*, allowed_group_ids[]}]}` | **true** |
|
||||
| on_request | 2 | `llm_limit_check` | `{}` | – |
|
||||
| on_request | 3 | `llm_identity_inject` | `{"providers":[{provider_id, header_pair?, json_metadata?, extra_headers?}]}` | **true** |
|
||||
| on_request | 4 | `llm_guardrail` | `{"model_allowlist"?, "prompt_capture":{enabled,redact_pii}}` | – |
|
||||
| on_response | 5 | `llm_limit_record` | `{}` (runs LAST at runtime) | – |
|
||||
| on_response | 6 | `cost_meter` | `{}` | – |
|
||||
| on_response | 7 | `llm_response_parser` | `{"capture_completion": <bool>, "redact_pii"?: true}` | – |
|
||||
- **Synthesized service shape** (`synthesizer.go:739`): `Mode=HTTP`, `Private=true`, `Domain=<subdomain>.<cluster>`, `AccessGroups=unionSourceGroups(enabledPolicies)`, one `TargetTypeCluster` target with `Host=noop.invalid:443` (router rewrites per request), `Options.{DirectUpstream,AgentNetwork}=true`, `DisableAccessLog=!settings.EnableLogCollection`, `CaptureMax{Req,Resp}Bytes=1<<20`, `CaptureContentTypes=["application/json","text/event-stream"]`.
|
||||
|
||||
## Invariants
|
||||
|
||||
- **Min-wins / all-must-pass for account budget rules** (`checkAccountBudget`, `policyselect.go:353`): every applicable enabled rule is checked; first exhausted cap denies. Untargeted rules bind every caller.
|
||||
- **Account toggle is the SOLE control for capture enablement.** `applyAccountCollectionControls` (`synthesizer.go:701`) sets `merged.PromptCapture.Enabled = settings.EnablePromptCollection` *unconditionally*.
|
||||
- **Capture-pointer semantics on parser configs** — see "Things to scrutinize" below.
|
||||
- **`EnableLogCollection` ↔ `DisableAccessLog` is the only access-log toggle** (`synthesizer.go:770`). Default off ⇒ access log suppressed.
|
||||
- **`RedactPii` flows verbatim to BOTH parsers** (`synthesizer.go:584-585`) and is OR'd into the merged guardrail (`synthesizer.go:706`).
|
||||
- **Cluster and Subdomain are immutable on Settings.** `UpdateSettings` reloads existing row and overlays only the three toggles (`manager.go:558-561`).
|
||||
- **Orphan providers (no enabled policy authorises them) NEVER reach the router** (`synthesizer.go:351-357`); skipped from `identity_inject` for symmetry.
|
||||
- **Provider creation refuses empty `api_key`** (`manager.go:175`); **deletion refuses while any policy still references it** (`manager.go:265-273`).
|
||||
- **Session keypair stability across provider edits** (`manager.go:226-228`) — server-managed, copied through every `UpdateProvider`, never API-surfaced.
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
|
||||
- **Capture-pointer semantics — `*bool` vs `bool`.** Three states, owned by separate sides:
|
||||
- **Wire JSON this module emits:** `buildParserConfigJSON` (`synthesizer.go:678-693`) *always* stamps the capture field. Agent-network targets ship `"capture_prompt": false` or `"capture_prompt": true` — never absent. Same for `"capture_completion"`. The happy-path test pins `{"capture_prompt":false}` (`synthesizer_test.go:174`).
|
||||
- **Proxy-side parser config (consumer):** parsers decode into `*bool`. Matrix:
|
||||
- `nil` (field absent) → **legacy default = emit**. Preserved for non-agent-network callers and pre-existing tests (the backward-compat hook).
|
||||
- `false` (field present, value false) → **suppress emission entirely**. The behaviour for opted-out agent-network accounts. Without this, `enable_log_collection=true` + `enable_prompt_collection=false` would leak raw user input AND raw model output to the access log.
|
||||
- `true` → emit normally.
|
||||
- **Why the synth always stamps a value:** an agent-network mapping omitting the field would hit legacy "always emit" and re-introduce the leak. The `json.Marshal` error fallback at `synthesizer.go:687` degrades to `{}` — comment-claimed unreachable, but if ever fired re-introduces the leak. Consider fail-closed (return literal `{"capture_prompt":false}`) instead.
|
||||
- **`scoreCandidates` non-cumulative deny code.** Only the *last* exhausted policy's deny code survives (`policyselect.go:188-190`). Iteration order is store's natural order. Auth signal is `len(scored)==0`, so this is informational only — verify no UI depends on "first exhausted policy" semantics.
|
||||
- **`effectiveWindowSeconds` token-wins tiebreak.** When both halves are enabled with different windows, token's window wins (`policyselect.go:482`). Verify `RecordLLMUsage` increments against the winning window only.
|
||||
- **`RecordAccountBudgetUsage` dedup.** Two rules with the same `(kind, dim_id, window)` would double-count without the `tuples` map (`policyselect.go:434-449`). Key includes all three dimensions — correct.
|
||||
- **Fail-closed on bad provider:** unknown catalog id (`synthesizer.go:794-796`) or empty API key (`synthesizer.go:801-803`) drops the **entire** account's synth, not just the bad provider. Confirm matches operator UX.
|
||||
|
||||
### Security
|
||||
|
||||
- **Redact OR-merge:** merged `RedactPii` = account OR guardrail (`synthesizer.go:706`). **Parser-side flag is `settings.RedactPii` only, NOT the OR** — a guardrail-only opt-in does not propagate to parsers. Correct because the account toggle gates capture, but worth noting on the proxy side.
|
||||
- **Group resolution must not leak across accounts.** Every store call carries `accountID` (`policyselect.go:73, 286, 298, 322, 334, 354`); `lowestIntersect` uses caller's claimed groups only (`policyselect.go:494`). Risk surface is upstream (handler populates `in.GroupIDs`).
|
||||
- **`UpdateSettings` preserves immutable Cluster + Subdomain** (`manager.go:558`). A client can't rebind the cluster.
|
||||
- **Provider session keypair backfill writes through `SaveAgentNetworkProvider`** (`synthesizer.go:256`) from a read-shaped call. Idempotent → worst case is a wasted write under concurrent reconcile + snapshot.
|
||||
|
||||
### Concurrency
|
||||
|
||||
- **`reconcileMu`** guards `reconcileCache`. Lock window is narrow — compute diff inside, send outside (`reconcile.go:56-68`).
|
||||
- **`labelRngMu`** guards `labelRng` because `math/rand.Source` is unsafe for concurrent use (`manager.go:638-640`).
|
||||
- **Real-store tests** use `store.NewTestStoreFromSQL` with `t.TempDir()` per test — no shared state, no `t.Parallel()`.
|
||||
- **`RecordAccountBudgetUsage` dedup `tuples` map is per-call;** concurrent calls fan out fully — correct (each request's tokens book once per applicable rule).
|
||||
- **Deferred `UpdateAccountPeers` runs inline after the proxy push** (`reconcile.go:28-35`); a slow call stretches CRUD response time.
|
||||
|
||||
### Backward compatibility
|
||||
|
||||
- **Capture-pointer semantics (restated):** non-agent-network callers see no field → legacy nil-default emit, identical to pre-PR. Agent-network targets always carry an explicit `capture_*` value.
|
||||
- **`TestSynthesizeServices_HappyPath` was updated:** request-parser config moved from `{}` to `{"capture_prompt":false}` (`synthesizer_test.go:174`). External snapshot tests against synth output need updating.
|
||||
- **`MergedGuardrails` retains zeroed `TokenLimits`/`Budget`/`Retention`** even though `Policy.Limits` carries the real values now; `llm_limit_check` is the authoritative enforcement. Comment at `synthesizer.go:940-948` calls this out.
|
||||
|
||||
### Performance
|
||||
|
||||
- **`SynthesizeServices` runs on every controller tick / mutation reconcile.** Cost: 4 store reads + optional per-provider keypair backfill. Sort + index + merge are O(N log N) / O(P × G); dominant cost is JSON marshalling. No nested loops escape these dimensions.
|
||||
- **`reconcile.diffMappings` is O(N + M)** with N=M=1 per account today — effectively constant.
|
||||
- **`SynthesizeServicesForCluster`** (`synthesizer.go:71`) walks every account on a cluster; per-account failures are **swallowed** (`synthesizer.go:91-93`) so a single misconfigured account doesn't drop the cluster. Runs per proxy reconnect.
|
||||
|
||||
### Observability
|
||||
|
||||
- **Activity codes:** `AgentNetwork{Provider,Policy,Guardrail,BudgetRule}{Created,Updated,Deleted}`; `AgentNetworkSettingsUpdated` with `log_collection/prompt_collection/redact_pii` payload (`manager.go:567-571`). **No activity code for `SelectPolicyForRequest` denies** — surfaced via proxy access log only (likely intentional given volume).
|
||||
- **Deny codes** namespaced: `llm_policy.{token,budget}_cap_exceeded`, `llm_account.{token,budget}_cap_exceeded` (`policyselect.go:18-26`).
|
||||
- **Reconcile failures are logged at warn and swallowed** (`reconcile.go:42-44`). Persistent synth failures (e.g. unknown catalog id) silently keep the proxy out of sync — consider a manager-level synth-health surface if this becomes a support burden.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| `synthesizer_test.go` | Mock-store: `HappyPath` (8-mw chain ordering, `{"capture_prompt":false}` baseline); `No{Settings,Providers}`; `Disabled{Provider,Policy}_NoService`; `RouterConfigOrdering`; `PolicyCheckConfig_UnionsSourceGroups`; `OrphanProvider_HasEmptyAllowedGroups`; identity-inject for LiteLLM / Bifrost (overrides + partial disable) / Cloudflare / Portkey / Vercel / OpenRouter / generic non-customizable; `GuardrailMerge_AllowlistUnion_LimitsRestrictive`; `BackfillsMissingSessionKeys`; `HTTPUpstream_KeepsExplicitPort`; `UpstreamURLPath_FlowsToRouter`; `UnknownProviderID_FailsClosed`; `EmptyAPIKey_FailsClosed`. |
|
||||
| `synthesizer_realstore_test.go` | Real-sqlite: `SurvivesStatusToggle` reproduces the disable/re-enable 403 regression; `Reconcile_RealStore_PushesPrivateAfterStatusToggle` extends through reconcile push. |
|
||||
| `synthesizer_guardrail_realstore_test.go` | `PromptCaptureAccountIsSoleControl`; `PromptCaptureFlowsWhenAccountOptsIn`; `AccountRedactWithoutGuardrailRedact`; `NoGuardrail_CaptureOff`. |
|
||||
| `synthesizer_log_collection_realstore_test.go` | `LogCollection{Off_SuppressesAccessLog,On_PermitsAccessLog}` — verifies `DisableAccessLog` propagation through `ToProtoMapping`. |
|
||||
| `synthesizer_parser_redact_realstore_test.go` | **Capture-pointer regression suite:** `ParserConfigsCarryRedactPii`; `ParserConfigsSuppressCaptureWhenLogCollectionOnly` (log=on/prompt=off ⇒ both capture flags false); `ParserConfigsOmitRedactPiiWhenOff`. |
|
||||
| `policyselect_test.go` | Mock-store: `NoApplicablePolicies`; `AllowWithLowestGroupAttribution`; `LargerPoolWinsAcrossUsageLevels`; `StaysOnLargerPoolAfterPartialDrain`; `FallsThroughToSmallerPoolWhenLargerExhausted`; `TiebreakBy{LargerGroupPool,CreatedAt}`; `DeniesWhenAllExhausted`; `UncappedPolicyAlwaysWinsAgainstCapped`; `DisabledPolicyIgnored`; `StoreErrorPropagates`; `RejectsEmptyAccount`; `SharesGroupCounterAcrossPolicies`; `AntiFallThroughOnLowestGroup`; `BudgetOnlyExhaustionDenies`; `BudgetTighterThanTokenWins`. |
|
||||
| `policyselect_realstore_test.go` | Real-sqlite regression guard: `NoApplicablePolicies`; `AllowAndLowestGroupAttribution`; `LargerPoolWins_FallsThroughWhenExhausted`; `BudgetCapDenies`; `GroupCounterSharedAcrossPolicies`; `DisabledPolicyIgnored`. |
|
||||
| `policyselect_account_realstore_test.go` | Account budget rules: `AccountCeilingBindsEvenWithUncappedPolicy` (min-wins); `AccountGroupCeiling`; `AccountTargetUsersBindsOnlyThatUser`; `AccountRuleRecordsToOwnWindow`. |
|
||||
| `reconcile_test.go` | `FirstSynth_EmitsCreate`; `NoChange_EmitsNothingExtra` (re-push as Modified — verify desired); `PolicyRemoved_EmitsDelete`; `NilProxyController_NoOp`; `EmptyAccountID_NoOp`; `ClusterFromMapping`. |
|
||||
| `wire_shape_test.go` | `TestSynthesizedService_WireShape` — proto-shape lockdown via `ToProtoMapping`. Catches "service not matching" (mapping reaches proxy but no SNI/HTTP route). Asserts ID, Domain, Mode, AuthToken, `Private`, `Auth.Oidc=false`, one path `/` + `https://noop.invalid/`, 8 middlewares with correct slot enums, router config `auth_header_value="Bearer sk-test-key"`. |
|
||||
| `labelgen/labelgen_test.go` | `PickUnique_{DeterministicWithSeededRng,AvoidsTakenWordsWhenMostAreReserved,FallsBackWhenAllReserved}`; `UniqueWords_DropsDuplicates`. |
|
||||
| `types/consumption_test.go` | `WindowStart_{AlignedToUnixEpoch,WithinWindowConverges,AcrossWindowsDiverges,DifferentWindowsHaveDifferentBuckets,SubMinuteAndMinuteAlignment,ZeroWindowReturnsInputUTC}`. Bucket alignment so multi-node reads converge. |
|
||||
| `agentnetwork_realstack_test.go` | `ProviderCRUD_FansOutToProxyAndClientPeers` — no-mock end-to-end through real account manager + network-map + agentnetwork: provider create propagates the updated map to both proxy peer and client peer with the synth DNS surface. |
|
||||
| `agentnetwork_budgetrule_realstack_test.go` | `BudgetRuleCRUD_RealManager`; `UpdateSettings_PreservesImmutableAndTogglesCollection`. |
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- **`MergedGuardrails.TokenLimits/Budget/Retention` emit at zero** (`synthesizer.go:940-948`); real enforcement is `Policy.Limits` via `llm_limit_check`. Future cleanup implied.
|
||||
- **Session keys picked from first enabled provider by created_at** (`pickServiceSessionKeys`, `synthesizer.go:270`). Existing session cookies survive provider edits only while the first-by-CreatedAt provider stays in place. Document for operators.
|
||||
- **Reconcile failures silently swallowed** (`reconcile.go:42-44`). Persistent failures keep the proxy out of sync until the next reconcile.
|
||||
- **`scoreCandidates` exposes only the LAST exhaustion's deny code** when multiple policies are exhausted.
|
||||
- **`bootstrapSettingsIfNeeded` failure is non-fatal to provider create** (`manager.go:200`): provider lands, synth is no-op until the next provider create retries the bootstrap.
|
||||
- **Budget rules do not trigger a reconcile** (`manager.go:476-477`). Request-time evaluation only; new rules take effect on the next request without a proxy push.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- **Upstream:** [shared/api](10-shared-api.md), [management/store](20-management-store.md), reverseproxy `service`/`proxy`/`sessionkey` packages, `management/server/permissions` + `activity`.
|
||||
- **Downstream:** [management/handlers (HTTP wiring)](22-management-handlers-wiring.md), [proxy/middleware-builtin](31-proxy-middleware-builtin.md), network-map controller (`injectAllProxyPolicies` fan-out).
|
||||
- **End-to-end flow:** [../01-end-to-end-flows.md](../01-end-to-end-flows.md) — "Provider create → reconcile → proxy push → peer map refresh" and "request → policy select → record" diagrams.
|
||||
- **Top-level:** [../00-overview.md](../00-overview.md)
|
||||
203
docs/agent-networks/modules/22-management-handlers-wiring.md
Normal file
203
docs/agent-networks/modules/22-management-handlers-wiring.md
Normal file
@@ -0,0 +1,203 @@
|
||||
# management/handlers + wiring — HTTP API + gRPC delivery
|
||||
|
||||
> **Risk level:** Medium — the surface is mostly additive, but two changes are load-bearing: `injectAllProxyPolicies` runs on every per-peer compute, and `shallowCloneMapping` must round-trip `Private` (a missed field silently breaks every MODIFIED).
|
||||
> **Backward-compat impact:** Additive on the wire (new routes, new RPCs, new proto fields, new gorm column on `AccessLogEntry`). One management-internal break: `nbhttp.NewAPIHandler` gains a trailing `agentNetworkManager` parameter; `nil` is tolerated and silently skips route registration.
|
||||
|
||||
## Module boundary
|
||||
|
||||
This module is the seam between the public Agent Network HTTP API and the proxy fleet that serves agent traffic. North side: a `/api/agent-network/*` surface (providers, policies, guardrails, budget rules, settings, consumption) on the existing gorilla router, delegating to `agentnetwork.Manager`. Handlers are thin — they translate `api.*` ↔ `types.*`, validate shape, forward. RBAC and event emission stay inside the manager (`manager.go:680-682`).
|
||||
|
||||
South side: `ProxyServiceServer` (`proxy.go`) learns to (a) ship synth services to a proxy on initial snapshot, (b) resolve agent-network domains in `getServiceByDomain` for OIDC/session/tunnel-peer flows, (c) gate LLM requests via `CheckLLMPolicyLimits` + `RecordLLMUsage`, (d) preserve `Private` through `shallowCloneMapping` so per-proxy live updates don't silently flip services public. The network_map controller prepends synth services to `account.Services` on every per-peer compute; `accesslogentry.go` gains an indexed `AgentNetwork` column so the dashboard can filter cheaply.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `handlers/agentnetwork/providers_handler.go` | Catalog + provider CRUD + central `AddEndpoints` |
|
||||
| `handlers/agentnetwork/policies_handler.go` | Policy CRUD + shared `validatePolicy*` |
|
||||
| `handlers/agentnetwork/guardrails_handler.go` | Guardrail CRUD |
|
||||
| `handlers/agentnetwork/budget_handler.go` | Account-level budget rule CRUD |
|
||||
| `handlers/agentnetwork/settings_handler.go` | GET (200+`null` if unbootstrapped) + PUT toggles |
|
||||
| `handlers/agentnetwork/consumption_handler.go` | Read-only consumption rows |
|
||||
| `handlers/agentnetwork/handlers_test.go` | Real-store fixture; wire round-trip + validation |
|
||||
| `handlers/agentnetwork/budget_handler_test.go` | Budget-rule + settings toggles |
|
||||
| `server/http/handler.go` | New `agentNetworkManager` arg; conditional `AddEndpoints` |
|
||||
| `server/permissions/modules/module.go` | New `AgentNetwork` module key |
|
||||
| `internals/server/boot.go` | Wires synthesiser adapter + limits service into proxy server |
|
||||
| `internals/server/modules.go` | `AgentNetworkManager()` lazy-create node |
|
||||
| `internals/controllers/network_map/controller/controller.go` | `injectAllProxyPolicies` replaces 4 `InjectProxyPolicies` calls |
|
||||
| `internals/controllers/network_map/controller/repository.go` | `SynthesizeAgentNetworkServices` repo method |
|
||||
| `internals/modules/reverseproxy/service/service.go` | `MiddlewareConfig`, capture limits, `AgentNetwork`, `DisableAccessLog` + proto |
|
||||
| `internals/modules/reverseproxy/accesslogs/accesslogentry.go` | Indexed `AgentNetwork bool` from proto |
|
||||
| `internals/shared/grpc/proxy.go` | Synth wiring, 2 RPCs, domain fallback, `Private` in clone |
|
||||
| `internals/shared/grpc/proxy_clone_test.go` | Locks every `ProxyMapping` field minus `AuthToken` |
|
||||
| `server/activity/codes.go` | 13 new activity codes (125-137) |
|
||||
|
||||
## HTTP routes added
|
||||
|
||||
All routes inherit the platform's auth middleware. Perms enforced inside `agentnetwork.Manager.requirePermission` (`manager.go:680-682`) on `modules.AgentNetwork`. Permission column shows the `op` passed to `requirePermission` — read = `Read`, etc.
|
||||
|
||||
| Method | Path | Perm | Handler |
|
||||
| ------ | ---- | ---- | ------- |
|
||||
| GET | `/agent-network/catalog/providers` | authn only | `providers_handler.go:43` |
|
||||
| GET | `/agent-network/providers` | read | `providers_handler.go:57` |
|
||||
| POST | `/agent-network/providers` | create | `providers_handler.go:97` |
|
||||
| GET | `/agent-network/providers/{providerId}` | read | `providers_handler.go:77` |
|
||||
| PUT | `/agent-network/providers/{providerId}` | update | `providers_handler.go:132` |
|
||||
| DELETE | `/agent-network/providers/{providerId}` | delete | `providers_handler.go:172` |
|
||||
| GET | `/agent-network/policies` | read | `policies_handler.go:32` |
|
||||
| POST | `/agent-network/policies` | create | `policies_handler.go:72` |
|
||||
| GET | `/agent-network/policies/{policyId}` | read | `policies_handler.go:52` |
|
||||
| PUT | `/agent-network/policies/{policyId}` | update | `policies_handler.go:102` |
|
||||
| DELETE | `/agent-network/policies/{policyId}` | delete | `policies_handler.go:142` |
|
||||
| GET | `/agent-network/guardrails` | read | `guardrails_handler.go:25` |
|
||||
| POST | `/agent-network/guardrails` | create | `guardrails_handler.go:65` |
|
||||
| GET | `/agent-network/guardrails/{guardrailId}` | read | `guardrails_handler.go:45` |
|
||||
| PUT | `/agent-network/guardrails/{guardrailId}` | update | `guardrails_handler.go:95` |
|
||||
| DELETE | `/agent-network/guardrails/{guardrailId}` | delete | `guardrails_handler.go:135` |
|
||||
| GET | `/agent-network/budget-rules` | read | `budget_handler.go:24` |
|
||||
| POST | `/agent-network/budget-rules` | create | `budget_handler.go:64` |
|
||||
| GET | `/agent-network/budget-rules/{ruleId}` | read | `budget_handler.go:44` |
|
||||
| PUT | `/agent-network/budget-rules/{ruleId}` | update | `budget_handler.go:95` |
|
||||
| DELETE | `/agent-network/budget-rules/{ruleId}` | delete | `budget_handler.go:135` |
|
||||
| GET | `/agent-network/settings` | read | `settings_handler.go:53` (200+`null` if no row) |
|
||||
| PUT | `/agent-network/settings` | update | `settings_handler.go:27` |
|
||||
| GET | `/agent-network/consumption` | read | `consumption_handler.go:21` |
|
||||
|
||||
## gRPC RPCs added (or modified)
|
||||
|
||||
| RPC | Direction | Trigger |
|
||||
| --- | --------- | ------- |
|
||||
| `CheckLLMPolicyLimits` | proxy→mgmt unary | Pre-flight gate; returns allow/deny, selected policy, attribution group, window, deny code+reason (`proxy.go:259-301`). `Unimplemented` when limits service is nil. |
|
||||
| `RecordLLMUsage` | proxy→mgmt unary | Post-flight write of tokens+cost against policy-window dimensions + every applicable account budget rule (`proxy.go:303-349`). `window_seconds==0` ⇒ no policy cap, only account fan-out runs. |
|
||||
| `GetMappingUpdate`/`SendServiceUpdate` (stream) | mgmt→proxy | Snapshot (`proxy.go:752-780`) now appends `SynthesizeServicesForCluster`. Live updates use `SendServiceUpdateToCluster` + `shallowCloneMapping`. |
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### HTTP request lifecycle
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant DB as Dashboard
|
||||
participant R as gorilla.Router (/api)
|
||||
participant H as handler (agentnetwork)
|
||||
participant M as agentnetwork.Manager
|
||||
participant S as store.Store
|
||||
participant AM as accountManager (StoreEvent)
|
||||
|
||||
DB->>R: POST /api/agent-network/providers
|
||||
R->>H: createProvider (auth mw sets UserAuth)
|
||||
H->>H: GetUserAuthFromContext + validate(req)
|
||||
H->>M: CreateProvider(userID, provider, bootstrapCluster)
|
||||
M->>M: requirePermission(AgentNetwork, Create)
|
||||
M->>S: SaveAgentNetworkProvider
|
||||
M->>AM: StoreEvent(AgentNetworkProviderCreated)
|
||||
M-->>H: created provider
|
||||
H-->>DB: 200 + api.AgentNetworkProvider JSON
|
||||
```
|
||||
|
||||
### Synth-service delivery via gRPC
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant P as Proxy
|
||||
participant G as ProxyServiceServer
|
||||
participant SM as service.Manager (persisted)
|
||||
participant SA as synthesizerAdapter
|
||||
participant AN as SynthesizeServicesForCluster
|
||||
participant ST as store.Store
|
||||
|
||||
Note over P,G: Initial snapshot
|
||||
P->>G: GetMappingUpdate (stream open)
|
||||
G->>SM: GetServicesForCluster(conn.address)
|
||||
SM-->>G: persisted []*Service
|
||||
G->>SA: SynthesizeServicesForCluster(conn.address)
|
||||
SA->>AN: SynthesizeServicesForCluster(store, clusterAddr)
|
||||
AN->>ST: walk every account; read providers/policies/settings
|
||||
AN-->>SA: in-memory []*Service
|
||||
SA-->>G: []*Service
|
||||
G->>P: response (persisted + synth)
|
||||
|
||||
Note over G,P: Per-request live update
|
||||
G->>G: SendServiceUpdateToCluster(update, clusterAddr)
|
||||
G->>G: shallowCloneMapping(update) %% Private MUST survive
|
||||
G->>P: response with single mapping
|
||||
```
|
||||
|
||||
End-to-end: HTTP write persists rows and emits an activity event; the manager then triggers `proxyController.SendServiceUpdate` so proxies re-render. **The snapshot path is the only one that calls into the synthesiser** — on stream open it pulls persisted services then appends synth services for the cluster. Synth services are never persisted. For OIDC/session/tunnel-peer flows, `getServiceByDomain` falls back to `SynthesizeServicesForCluster(clusterFromDomain(domain))` when persisted lookup misses (`proxy.go:1763-1793`). The network_map contribution is orthogonal: per-peer compute prepends the same synth services to `account.Services` before `InjectProxyPolicies`.
|
||||
|
||||
## Permissions model added
|
||||
|
||||
- `permissions/modules/module.go:22` adds `AgentNetwork Module = "agent_network"`, registered in `All` (`module.go:42`). Standard `operations.{Read,Create,Update,Delete}` matrix.
|
||||
- Handlers don't call `permissionsManager` directly — they extract `UserAuth` and delegate to `agentnetwork.Manager`, which gates every mutation through `requirePermission` (`manager.go:168, 308, 549`, etc.). Confirm your role-set provider has `agent_network` rows for owner/admin/user/billing-admin before merging.
|
||||
- `getCatalogProviders` (`providers_handler.go:43`) intentionally skips RBAC — catalog is global static data.
|
||||
|
||||
## Activity codes added
|
||||
|
||||
`activity/codes.go:244-274` adds Activities 125-137 + string/code mappings (`codes.go:428-444`), following `<domain>.<resource>.<action>` (e.g., `agent_network.provider.create`). Audit-log exporters / SIEM forwarders need to know the new codes.
|
||||
|
||||
## Invariants
|
||||
|
||||
- **Synth services are never persisted.** Snapshot appends after `serviceManager.GetServicesForCluster` (`proxy.go:761-770`); network_map prepends before `InjectProxyPolicies` (`controller.go:117-126`).
|
||||
- **`shallowCloneMapping` must round-trip every `ProxyMapping` field except `AuthToken`** — `proxy_clone_test.go:50-58` enforces via `gproto.Equal`. The bug it guards: a missing `Private` made every MODIFIED arrive `private=false`, the proxy skipped `ValidateTunnelPeer`, `UserGroups` stayed empty, `llm_router` denied `no_authorised_provider`; a restart "fixed" it because the snapshot uses the original mapping.
|
||||
- **Limit-window floor is 60s** (`policies_handler.go:189-220`); enabled cap with both per-group and per-user at zero is rejected. Budget rules reuse the same validator (`budget_handler.go:170`).
|
||||
- **Manager is optional at boot.** `NewAPIHandler` registers routes only when non-nil (`handler.go:129`); `ProxyServiceServer` returns `Unimplemented` from both RPCs when limits service is unwired (`proxy.go:262-265, 306-309`).
|
||||
- **Settings GET on an unbootstrapped account returns 200 + `null`** (`settings_handler.go:65-72`) — not 404.
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
- **`injectAllProxyPolicies` runs on every per-peer compute**: `controller.go:163, 309, 415, 681`. `sendUpdateAccountPeers` is the target of the buffered fan-out — synth runs once per debounced account-update tick **and** once per direct `UpdateAccountPeer`. Cost is O(providers + policies × users-per-group) per account under `LockingStrengthNone`. No per-account synth cache — verify it fits the buffer interval for your largest tenant.
|
||||
- **`clusterFromDomain` strips at the first `.`** (`proxy.go:1784-1792`). A zero-dot domain returns `""` and the synth call walks every account. Confirm no path reaches this with a malformed/internal domain.
|
||||
- **Account-budget `RecordConsumption` fans out even when `window_seconds == 0`** (`proxy.go:341-348`) — intentional. Verify the proxy never sends `RecordLLMUsage` for a request that wasn't actually allowed.
|
||||
|
||||
### Security
|
||||
- Every handler extracts `UserAuth` via `nbcontext.GetUserAuthFromContext` before any work. Routes live behind the standard `/api` mux; bypass list is not extended.
|
||||
- `CheckLLMPolicyLimits` / `RecordLLMUsage` ride the existing **proxy → mgmt** gRPC connection auth. No additional token check inside the RPCs — they trust the connection. Confirm the proxy-side token-verification interceptor in this package gates both.
|
||||
- `RecordLLMUsage` only validates `account_id != ""` (`proxy.go:317-319`). A compromised proxy can attribute cost to any account in its cluster — was already true for prior RPCs but is louder now that data drives denials.
|
||||
|
||||
### Concurrency
|
||||
- `SetAgentNetworkSynthesizer` / `SetAgentNetworkLimitsService` write under `s.mu.Lock`; read paths copy the interface under read lock (`proxy.go:236-247, 260-263, 304-307`). Same pattern as existing `serviceManager`/`proxyController` setters.
|
||||
- Manager writes use `LockingStrengthUpdate`; synth reads use `LockingStrengthNone` — read-after-write via the proxy snapshot can observe a stale view by up to one fan-out tick.
|
||||
- Network_map controller is single-threaded per account; cross-account is parallel.
|
||||
|
||||
### Backward compatibility
|
||||
- `proxy_clone_test.go` is the regression net; any new `ProxyMapping` field must be cloned or explicitly nulled in the test.
|
||||
- `AccessLogEntry` adds indexed `AgentNetwork bool` — implicit AutoMigrate; deploy story must handle table-rewrite cost on high-volume access-log tables.
|
||||
- `TargetOptions` gains seven `omitempty` JSON fields (`service.go:69-94`); on-wire shape stays compatible. `targetOptionsToProto` tests all fields when deciding nil (`service.go:551-556`).
|
||||
- `NewAPIHandler` signature changes — every caller must pass `agentNetworkManager`; `nil` is supported.
|
||||
|
||||
### Observability
|
||||
- 13 new activity codes via `accountManager.StoreEvent` in the manager — confirm dashboard's audit-log UI maps them.
|
||||
- `AccessLogEntry.AgentNetwork` is indexed for the dashboard's agent-network log filter.
|
||||
- New RPCs log at error level on store/selector failures (`proxy.go:284, 327, 332, 348`). Snapshot synth failures degrade to warnings — stream is not aborted (`proxy.go:765`).
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test | Locks down |
|
||||
| ---- | ---------- |
|
||||
| `handlers_test.go::TestPolicyHandler_WindowSecondsRoundTrip` | GET carries `window_seconds`; legacy `window_hours`/`window_days` absent. |
|
||||
| `handlers_test.go::TestPolicyHandler_RejectsSubMinuteWindow` | POST `<60s` returns 4xx. |
|
||||
| `handlers_test.go::TestConsumptionHandler_EmptyAccountReturnsArray` | `/consumption` returns `[]` — never null. |
|
||||
| `handlers_test.go::TestConsumptionHandler_PopulatedAccountListsRows` | RecordConsumption×2 surfaces both with correct tokens/cost/window. |
|
||||
| `budget_handler_test.go::TestBudgetRuleHandler_RoundTrip` | Targets + PolicyLimits shape round-trip. |
|
||||
| `budget_handler_test.go::TestBudgetRuleHandler_ListReturnsArray` | Empty-list shape. |
|
||||
| `budget_handler_test.go::TestBudgetRuleHandler_{RejectsMissingName,RejectsSubMinuteWindow}` | Validation rejections are 4xx. |
|
||||
| `budget_handler_test.go::TestSettingsHandler_GetExposesCollectionToggles` | All four toggles + computed `Endpoint`. |
|
||||
| `proxy_clone_test.go::TestShallowCloneMapping_PreservesAllFieldsExceptAuthToken` | Future-proofs clone; every field round-trips, `AuthToken` dropped. |
|
||||
|
||||
Handler tests use a real sqlite store + real manager + always-allow permissions mock (`handlers_test.go:53-75`). Create/update/delete success paths flow through `accountManager.StoreEvent` which the fixture doesn't wire — covered by manager-level no-mock tests outside this module.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- No pagination on any list endpoint; no bulk endpoints.
|
||||
- Synth result is not cached — every snapshot and every per-peer compute repeats the store walk.
|
||||
- `getSettings` returning `200 + null` is a deliberate dashboard concession.
|
||||
- No rate-limiting beyond the global `/api` rate limiter.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Upstream: [shared/api](10-shared-api.md), [management/agentnetwork](21-management-agentnetwork.md), [management/store](20-management-store.md)
|
||||
- Downstream: [proxy/runtime](33-proxy-runtime.md)
|
||||
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level: [../00-overview.md](../00-overview.md)
|
||||
215
docs/agent-networks/modules/30-proxy-middleware-framework.md
Normal file
215
docs/agent-networks/modules/30-proxy-middleware-framework.md
Normal file
@@ -0,0 +1,215 @@
|
||||
# proxy/middleware-framework — generic plugin system
|
||||
|
||||
> **Risk level:** **High** — every proxied request transits this chain. Budget exhaustion, panic recovery, or chain-close bugs hit the hot path for all targets, not just agent-network ones.
|
||||
> **Backward-compat impact:** Additive at the proxy. The `middleware` and `bodytap` packages are new (`proxy/internal/middleware/middleware.go:1`, `proxy/internal/middleware/bodytap/request.go:13`); existing proxy targets keep working until a chain is bound to them via `Manager.Rebuild`.
|
||||
|
||||
This module is the **framework only** — no LLM/agent-network domain knowledge is required, since every example built into it is generic.
|
||||
|
||||
## Module boundary
|
||||
|
||||
This module is the **framework only**: slots, chains, registry, dispatcher, accumulator, body-tap, output filters. No middleware *implementation* lives here — those land in `proxy/internal/middleware/builtin/*` (covered in module 31). The package contract is:
|
||||
|
||||
1. The proxy hands a `Manager` to its config-apply path. The synth pushes per-path `PathTargetBinding` lists (`proxy/internal/middleware/manager.go:26`) into `Manager.Rebuild`, which resolves each spec via the `Registry`/`Resolver` (`proxy/internal/middleware/registry.go:81-121`) and produces an immutable `Chain` keyed by `serviceID|pathID` (`proxy/internal/middleware/manager.go:410-412`).
|
||||
2. The reverse-proxy handler captures the request body via `bodytap.CaptureRequest`, calls `Chain.RunRequest`, applies returned mutations (already filtered by `chain.applyMutations`), forwards to the upstream behind a `bodytap.CapturingResponseWriter`, then calls `Chain.RunResponse` and `Chain.RunTerminal`.
|
||||
3. Middlewares are inert plugins that receive a deep-cloned `Input` and return an `Output` whose decision/mutations are clamped by the dispatcher's `filterOutput` (`proxy/internal/middleware/dispatcher.go:149-172`).
|
||||
|
||||
Everything that crosses the framework boundary in either direction is value-typed and deep-copied — middlewares cannot mutate the live request directly, and the framework cannot inadvertently leak middleware-owned slices into the request hot path.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| `proxy/internal/middleware/middleware.go` | `Middleware` + `Factory` interfaces. |
|
||||
| `proxy/internal/middleware/types.go` | `Slot`, `FailMode`, `Decision`, all limit constants, `Input`/`Output`/`Mutations`/`UpstreamRewrite`/`AuthHeader` value types. |
|
||||
| `proxy/internal/middleware/spec.go` | Apply-time `Spec` (validated wire shape + runtime-injected fields) and `Clone`. |
|
||||
| `proxy/internal/middleware/registry.go` | `Registry` (factory map, RWMutex) and `Resolver` (Spec → bound `Middleware`). |
|
||||
| `proxy/internal/middleware/manager.go` | `Manager`, `chainTable` reverse index, `Rebuild`/`Invalidate*`, async chain close. |
|
||||
| `proxy/internal/middleware/chain.go` | `Chain.RunRequest`/`RunResponse`/`RunTerminal`, mutation gating, `cloneInputFor`. |
|
||||
| `proxy/internal/middleware/chain_test.go` | Metadata threading, LIFO response order, rewrite gating, UserGroups propagation, terminal accumulation. |
|
||||
| `proxy/internal/middleware/dispatcher.go` | Timeout/panic recovery, fail-mode, error classification, `filterOutput`. |
|
||||
| `proxy/internal/middleware/decision.go` | `RenderDenyResponse`, deny-code regex, status clamp. |
|
||||
| `proxy/internal/middleware/headerpolicy.go` | Compile-in header denylist + `FilterHeaderMutations`. |
|
||||
| `proxy/internal/middleware/bodypolicy.go` | `ValidateBodyReplace` / `ApplyBodyReplace` smuggling guards. |
|
||||
| `proxy/internal/middleware/keys.go` | Metadata key namespace constants. |
|
||||
| `proxy/internal/middleware/metadata.go` | `Accumulator` — allowlist, per-mw/per-request byte caps, redaction. |
|
||||
| `proxy/internal/middleware/metrics.go` | OTel instrument bundle (`proxy.middleware.*`). |
|
||||
| `proxy/internal/middleware/redaction.go` | `Scan` — PEM/JWT/AWS/bearer/Luhn-validated CC patterns. |
|
||||
| `proxy/internal/middleware/bodytap/request.go` | Capture + replay reader, `Budget` semaphore, bypass reason codes. |
|
||||
| `proxy/internal/middleware/bodytap/response.go` | `CapturingResponseWriter` (tee with `PassthroughWriter` for Flusher/Hijacker preservation). |
|
||||
|
||||
## Slot model
|
||||
|
||||
Three slots, declared per-middleware exactly once (`proxy/internal/middleware/types.go:27-41`):
|
||||
|
||||
- **`SlotOnRequest`** (`Slot=1`) — runs **before** the upstream call, in registration order. May `DecisionDeny`, may emit `Mutations` (header add/remove, body replace, `UpstreamRewrite`) when both `Spec.CanMutate` and `Middleware.MutationsSupported()` are true. May emit metadata. Each middleware in the slot sees metadata that earlier ones in the same slot just emitted (`proxy/internal/middleware/chain.go:144-178`) — this is how the framework gives middlewares an intra-slot side channel without a global bag.
|
||||
- **`SlotOnResponse`** (`Slot=2`) — runs **after** the upstream returns, in **reverse** registration order. Cannot deny (clamped in `dispatcher.filterOutput`, `proxy/internal/middleware/dispatcher.go:153-157`). May still mutate response headers in principle, but the current chain only forwards `RewriteUpstream` from on_request, so on_response mutations are observe-only in practice. Threads the same per-slot metadata view as on_request.
|
||||
- **`SlotTerminal`** (`Slot=3`) — runs **after** every on_response middleware has emitted, in registration order. Sees the full accumulated bag plus prior terminal emissions (`chain.go:221-245`). Cannot deny, cannot mutate (`dispatcher.go:168-170`). Designed for sinks (access log, metrics push, audit emitter).
|
||||
|
||||
Splitting a feature across slots (e.g. "parse on the way out, ship on terminal") is the explicit architectural choice — `types.go:7-15` and `types.go:22-25` make it clear no middleware participates in more than one slot.
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### Chain dispatch
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
participant H as proxy HTTP handler
|
||||
participant BT as bodytap.CaptureRequest
|
||||
participant CH as Chain
|
||||
participant DI as Dispatcher
|
||||
participant MW as Middleware (per slot)
|
||||
participant US as Upstream
|
||||
participant CW as CapturingResponseWriter
|
||||
|
||||
H->>BT: CaptureRequest(r, cfg, budget)
|
||||
BT-->>H: body[], truncated, release()
|
||||
H->>CH: RunRequest(ctx, r, Input, Accumulator)
|
||||
loop on_request, registration order
|
||||
CH->>CH: cloneInputFor(in, OnRequest)
|
||||
CH->>DI: Invoke(ctx, spec, mw, call)
|
||||
DI->>MW: mw.Invoke(callCtx, in)
|
||||
MW-->>DI: Output{decision, metadata, mutations?}
|
||||
DI->>DI: filterOutput (clamp deny, gate mutations)
|
||||
DI-->>CH: filtered Output
|
||||
CH->>CH: Accumulator.Emit (allowlist + caps + redact)
|
||||
alt DecisionDeny
|
||||
CH-->>H: denied, merged, rewrite
|
||||
else allow
|
||||
CH->>CH: applyMutations(r, m) and capture rewrite
|
||||
end
|
||||
end
|
||||
CH-->>H: nil, merged, rewrite
|
||||
H->>US: ProxyRequest (with rewrite/mutations applied)
|
||||
US-->>CW: bytes (streamed, tee'd into cap-bounded buf)
|
||||
CW-->>H: passthrough complete
|
||||
H->>CH: RunResponse(ctx, Input{RespBody:CW.Body(),...}, acc)
|
||||
loop on_response, REVERSE order (LIFO)
|
||||
CH->>DI: Invoke (same wrappers)
|
||||
end
|
||||
H->>CH: RunTerminal(ctx, Input{Metadata:full bag}, acc)
|
||||
H->>BT: release() + CW.Release()
|
||||
```
|
||||
|
||||
### Body-tap mechanics (request + response)
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph req[Request capture — bodytap.CaptureRequest]
|
||||
R0[r.Body] --> R1{cfg.MaxRequestBytes > 0?\nUpgrade absent?\nContent-Type allowed?\nCL <= cap?}
|
||||
R1 -- no --> R2[bypass = reason\nbody = nil\nr.Body untouched]
|
||||
R1 -- yes --> R3[Budget.Acquire(cap)]
|
||||
R3 -- denied --> R4[bypass=BypassBudget]
|
||||
R3 -- ok --> R5[io.LimitReader(r.Body, cap+1)\nio.ReadAll]
|
||||
R5 --> R6{len > cap?}
|
||||
R6 -- truncated --> R7[viewable = buf[:cap]\nr.Body = replayReadCloser{buf, tail}]
|
||||
R6 -- whole --> R8[r.Body = NopCloser(bytes.Reader(buf))\nclose original]
|
||||
R7 --> R9[(release captured\nbudget on req end)]
|
||||
R8 --> R9
|
||||
end
|
||||
|
||||
subgraph resp[Response capture — CapturingResponseWriter]
|
||||
W0[client] -.-> CW[Write(p)]
|
||||
CW --> P1[PassthroughWriter.Write(p)\n— bytes leave to client first]
|
||||
P1 --> P2{!stopped?}
|
||||
P2 -- yes --> P3{remaining = cap - buf.Len()}
|
||||
P3 --> P4[buf.Write(p[:take])\nset truncated if take<n]
|
||||
P2 -- no --> P5[silent drop into the tee\n(client write already done)]
|
||||
end
|
||||
```
|
||||
|
||||
The body-tap is the highest-leak-risk surface in this module; three details matter:
|
||||
|
||||
1. **Request capture is "read-and-replay", not "read-and-forward".** `CaptureRequest` always swaps `r.Body` for either a `bytes.Reader` (whole body fit) or a `replayReadCloser` that replays the captured prefix then drains the remaining stream from the original body (`bodytap/request.go:178-201`). This means the **upstream still sees the full body even when the tap truncates**. The original `r.Body` is **not** closed in the truncated branch — `replayReadCloser.Close()` only closes the tail (`bodytap/request.go:199-201`), which is the same reader, so close once on request end is correct, but reviewers should confirm the upstream proxy always reads to EOF (otherwise the tail is leaked).
|
||||
2. **Response capture is a write-through tee.** `CapturingResponseWriter.Write` forwards to the underlying writer **first** (`bodytap/response.go:116-117`), then tees into `buf` under its own mutex. Client never blocks on the tee. `Flusher`/`Hijacker` are preserved via the embedded `responsewriter.PassthroughWriter`. SSE/chunked streams flow through untouched; middlewares only see the bounded prefix.
|
||||
3. **Budget is a single shared semaphore.** `Manager` constructs one `bodytap.Budget` at startup (`manager.go:138-144`, default `256 MiB` from `bodytap/request.go:39`). Every capture pre-acquires its full `MaxRequestBytes` / `MaxResponseBytes` from the budget regardless of actual body size; that prevents a flood of small captures from collectively exceeding the cap, but it also means a misconfigured `MaxRequestBytes = 1 MiB` with 256 concurrent requests already exhausts the default budget. Reviewers should sanity-check the operator-facing defaults that ship with synth-service.
|
||||
|
||||
The framework explicitly aborts capture (and increments `proxy.middleware.capture_bypass_total`) before reading the first byte when `Upgrade`/`Connection: upgrade` is set (`bodytap/request.go:120-125`), when the content-type isn't in the allowlist (`bodytap/request.go:126-128`), or when the advertised `Content-Length` already exceeds the cap (`bodytap/request.go:131-133`). This is the right place to make sure WebSocket upgrades and large file uploads never reach the buffer.
|
||||
|
||||
## Public contracts
|
||||
|
||||
- **`Middleware` interface** (`middleware.go:14-36`): `ID()`, `Version()`, `Slot()`, `AcceptedContentTypes()`, `MetadataKeys()`, `MutationsSupported()`, `Invoke(ctx, *Input) (*Output, error)`, `Close()`. `MetadataKeys()` is the **closed set** the middleware is allowed to emit — the accumulator drops anything outside it (`metadata.go:71-75`). `Close` must be idempotent (called even when `Invoke` was never reached).
|
||||
- **`Factory` interface** (`middleware.go:44-47`): `ID()`, `New(rawConfig []byte) (Middleware, error)`. `RawConfig` is opaque JSON bytes on the wire (`spec.go:6-12`); each factory owns its own typed config.
|
||||
- **`Decision` type** (`types.go:59-69`): `Allow=0`, `Deny=1`, `Passthrough=2`. Default-zero is permissive — important because every middleware that omits `Decision` gets `Allow`. Dispatcher clamps `Deny` to `Passthrough` outside `SlotOnRequest` (`dispatcher.go:153-157`).
|
||||
- **`Mutations`** (`types.go:196-201`): `HeadersAdd`/`HeadersRemove` (filtered through `headerpolicy.go`), `BodyReplace` (gated through `bodypolicy.go`), and `RewriteUpstream`. `RewriteUpstream` is **last-write-wins** within the on_request slot (`chain.go:170-172`, locked down by `TestChain_RunRequest_LatestRewriteWins`).
|
||||
- **Metadata propagation keys** (`keys.go`): all keys live in a single file and follow `^[a-z][a-z0-9_-]*(\.[a-z0-9_-]*)+$` (`metadata.go:8`). Framework-injected error tagging uses `mw.<id>.error_kind` (`keys.go:81`) so operators can distinguish framework-emitted entries from middleware-emitted ones.
|
||||
|
||||
## Invariants
|
||||
|
||||
- **Per-request context isolation.** `cloneInputFor` deep-copies every mutable field (`Headers`, `RespHeaders`, `Metadata`, `Body`, `RespBody`, `UserGroups`, `UserGroupNames`) before each invocation (`chain.go:286-308`). A misbehaving middleware that mutates `in.Headers` only corrupts its own copy.
|
||||
- **Body-tap bounded by capture limit.** Request side uses `io.LimitReader(r.Body, limit+1)` (`bodytap/request.go:152`) — the `+1` is how the code detects truncation (`bodytap/request.go:160`); the surfaced buffer is sliced back down to `limit`. Response side stops teeing once `buf.Len() >= cap` (`bodytap/response.go:121-133`). Neither side can grow the buffer past the configured cap.
|
||||
- **Headers/body redaction order.** Accumulator runs `Scan(value)` **before** counting cost (`metadata.go:81-82`), so the byte budgets are computed against post-redaction sizes. `Scan` order is PEM → JWT → AWS key → bearer → Luhn-validated CC (`redaction.go:25-51`) — the comment block in `redaction.go:8-13` is explicit that this is best-effort, not DLP.
|
||||
- **No middleware can starve the chain.** Every invocation runs inside `context.WithTimeout(ctx, clampTimeout(spec.Timeout))` in a separate goroutine (`dispatcher.go:51-94`), with the deadline race-`select`ed against the result channel. A blocked middleware fires the timeout path, gets fail-mode'd, and `IncError(kind=timeout)`. Timeouts are clamped to `[10ms, 5s]` (`types.go:80-86`, `dispatcher.go:174-185`).
|
||||
- **Panic recovery.** `recover()` captures the panic, logs only the type + a 4 KiB stack prefix (no panic value — avoids leaking secrets the middleware was processing), and produces a `panicError` that flows through fail-mode (`dispatcher.go:64-76`).
|
||||
- **Chain immutability + atomic swap.** `chainTable` is cloned on every `Rebuild`/`Invalidate*` and swapped via `atomic.Pointer` (`manager.go:44-69`, `manager.go:221-300`). Readers (`ChainFor`) are lock-free; writers serialise on `writeMu`. The retired chain is `Close`-d in a background goroutine bounded by `chainCloseTimeout = 2 * MaxTimeout` (`manager.go:21-22`, `manager.go:326-346`), so in-flight invocations finish on the old chain after the swap.
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
|
||||
- **Chain ordering deterministic from synth output?** `Manager.buildChain` iterates `b.Specs` in slice order and appends to `bound` (`manager.go:366-391`); `NewChain` then partitions by slot but **preserves slice order within each slot** (`chain.go:50-60`). So order on the wire = order observed at runtime. Synth must therefore emit specs in the intended execution order — there is no per-spec `Priority` field. Worth flagging.
|
||||
- **Decision short-circuit semantics.** `RunRequest` returns immediately on `DecisionDeny` (`chain.go:164-167`) **with the metadata accumulated so far** plus the `denied.Metadata`. Callers that ignore `merged` on deny will lose framework-injected `mw.<id>.error_kind` entries. The proxy runtime is the only caller; confirm it always feeds `merged` into the access log on the deny path as well.
|
||||
- **`UpstreamRewrite` `AuthHeader` bypass** (`types.go:218-235`). The `AuthHeader`/`StripHeaders` fields *intentionally* bypass the header denylist on the basis that the proxy itself rewrites auth. The denylist still blocks middleware-emitted `HeadersAdd: Authorization=...`. This is a delicate carve-out — review the runtime consumer to confirm only the trusted upstream-build path unpacks `AuthHeader`, never the generic `applyMutations` loop.
|
||||
- **`replayReadCloser.Close` only closes the tail** (`bodytap/request.go:199-201`). The replay buffer doesn't own a resource, so this is correct, but it conflates "replay finished" with "underlying body closed". If a caller `Close()`s without reading to EOF, the original body is closed but the captured prefix is lost; harmless for the proxy path (upstream always reads to EOF) but worth a doc-comment.
|
||||
|
||||
### Security
|
||||
|
||||
- **Body-tap memory bounds.** Discussed above — bounded by `MaxBodyCapBytes = 1 MiB` per direction (`types.go:77`) and the shared `Budget` (default 256 MiB). The concerning case is the **deep-copy in `cloneInputFor`** (`chain.go:300-306`): every middleware invocation gets its **own copy** of `Body` and `RespBody`. A chain of N middlewares with a 1 MiB body allocates N MiB of transient bytes per request. With `MaxMiddlewaresPerChain = 16` (`types.go:103`) that's up to 16 MiB extra per in-flight request. Worth pricing into the budget model.
|
||||
- **Header redaction completeness.** `denyHeaders` (`headerpolicy.go:5-17`) covers the auth/forwarding family and framing (`Content-Length`, `Transfer-Encoding`, `Trailer`). `denyHeaderPrefixes` covers `X-Authenticated-*`, `X-Forwarded-*`, `X-Remote-*`, `X-NetBird-*`. Notably absent: `Range`, `If-Match`/`If-None-Match` (mutation could cause cache poisoning), `Origin`/`Referer`. Not necessarily wrong, but worth a deliberate decision.
|
||||
- **Metadata key collisions across middlewares.** The accumulator has no cross-middleware uniqueness check; two middlewares with the same key in their allowlist can both emit it, and both copies land in `merged` (`metadata.go:51-99`). Downstream consumers must tolerate duplicates. Worth documenting.
|
||||
- **Deny rendering.** `RenderDenyResponse` only allows codes matching `^[a-z][a-z0-9._-]{0,63}$` (`decision.go:9`), redacts/truncates message + detail values, caps `Details` at 8 entries (`decision.go:42-50`), clamps status to `[400,499]\{401}` (`decision.go:65-73`). The deny body type is fixed; middlewares cannot inject arbitrary JSON.
|
||||
|
||||
### Concurrency
|
||||
|
||||
- **Per-request state vs shared state in factories.** Each `Factory.New` is called once per chain build; the returned `Middleware` instance is **shared across all requests** for that chain. `Invoke` must be reentrant. The framework does not enforce this — a buggy middleware that holds per-call state on the struct will silently race. Suggest a `// Invoke must be safe for concurrent use` doc on the interface.
|
||||
- **`chainTable` clone-on-write** is correct, but `addChain`/`removeChain` mutate the *cloned* table before the swap (`manager.go:71-108`), and they're called under `writeMu`. Readers only ever see the post-swap pointer. Good.
|
||||
- **`Chain.inflight` WaitGroup**. `Run*` does `Add(1)`/`Done()` (`chain.go:142-143`, `chain.go:194-195`, `chain.go:225-226`); `Close` waits on it bounded by ctx (`chain.go:75-85`). One concern: a *new* `RunRequest` can `Add(1)` *after* `Close` started waiting if the caller still holds a stale chain pointer. `WaitGroup` does not panic on this if the count was already > 0 at `Wait` time, but it does panic if `Add` happens after `Wait` returns and another `Wait` runs. `Close` is documented one-shot, so single-`Wait` is fine, but callers must drop the chain reference before calling `Close`. Worth a code comment near `Close`.
|
||||
- **Goroutine leaks.** `Dispatcher.Invoke` spawns one goroutine per call and *always* writes to a buffered (cap=1) channel (`dispatcher.go:62-76`), so even if the timeout fires the goroutine completes its send and exits. No leak.
|
||||
- **`closeChainsAsync`** detaches retired chains into a goroutine (`manager.go:326-346`). If `Manager` is never GC'd this is fine, but there's no shutdown hook to wait on outstanding closes. Reviewers should confirm the proxy shutdown path explicitly drains in-flight requests before tearing down `Manager`, or accept that the last chain-close round may be cut short on exit.
|
||||
|
||||
### Performance
|
||||
|
||||
- **Allocations per request.** `cloneInputFor` allocates new slices for `Headers`, `RespHeaders`, `Metadata`, `Body`, `RespBody`, `UserGroups`, `UserGroupNames` — once per middleware per request. For a typical 5-middleware chain on a 1 KiB body that's ~10 small slice allocs plus one `Body` copy each. Not a hot-path crisis, but `sync.Pool` for the per-call `Input` would be a natural follow-up.
|
||||
- **Accumulator allocates a fresh `allowSet` per `Emit` call** (`metadata.go:55-58`). One per middleware per slot pass = up to 48 per request. Cheap, but worth noting.
|
||||
- **Regex cost.** `Scan` runs five regex passes on every accepted metadata value (`redaction.go:25-51`). Bounded by `MaxMetadataValueBytes = 4 KiB` so worst case is small.
|
||||
|
||||
### Observability
|
||||
|
||||
- **Per-middleware metrics.** `proxy.middleware.requests_total{middleware,target_id,outcome}` (`metrics.go:34-41`), `duration_ms`, `invocations_total`, `errors_total{kind}`, `metadata_rejected_total{reason}`, `header_mutation_blocked_total{header}`, `capture_bypass_total{reason}`. Comprehensive surface; operators can alert on `errors_total{kind=panic}` and `errors_total{kind=timeout}` separately. **Latency histogram is in milliseconds with default OTel buckets** — for a 10ms–5s timeout range default buckets cover OK, but a custom bucket set centred on 1–500ms would resolve the agent-network response-parser tail better.
|
||||
- **Decision logs.** Panic logs (`dispatcher.go:69`) include `request_id`, type, and stack but not the panic value (safe). `Chain.Close` logs middleware-close errors at debug (`chain.go:91`). `applyMutations` logs body-replace rejections at warn (`chain.go:278`). No log on the deny path itself — by design, since the access-log terminal middleware is expected to record outcomes.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| `proxy/internal/middleware/chain_test.go:77` | `RunRequest` threads metadata across on_request middlewares (regression for the "later mw can't see earlier mw's emissions" bug). |
|
||||
| `chain_test.go:110` | `RunResponse` reverse-order threading. |
|
||||
| `chain_test.go:142` | `cost_meter`-shaped scenario: response_parser registered after cost_meter still emits *before* cost_meter sees the bag (guards the `cost.skipped=missing_tokens` regression). |
|
||||
| `chain_test.go:178` | `UpstreamRewrite` last-write-wins. |
|
||||
| `chain_test.go:206` | No middleware emits → nil rewrite. |
|
||||
| `chain_test.go:224` | Rewrite filtered when `CanMutate=false`. |
|
||||
| `chain_test.go:245` | `Input.UserGroups` propagates verbatim through `cloneInputFor`. |
|
||||
| `chain_test.go:304` | Terminal middlewares see the full accumulated bag + prior terminal emissions. |
|
||||
|
||||
**Gaps** worth raising with the author:
|
||||
- No direct test for `Dispatcher.Invoke` timeout / panic / fail-mode behaviour at the framework level (covered indirectly by built-in tests, but a unit test pinning `errors_total{kind=...}` labels would be cheap insurance).
|
||||
- No test for `bodytap.CaptureRequest` truncated replay (the upstream-sees-full-body invariant is exactly the kind of thing a regression would silently break).
|
||||
- No test for `Budget` exhaustion behaviour under concurrency.
|
||||
- No test for `Manager.InvalidateMiddleware` + `LiveServiceCheck` race (the auth-revocation race the comment at `manager.go:33-38` calls out is the load-bearing reason for `LiveServiceCheck`).
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- **No middleware-to-middleware RPC.** Side-channel is metadata only.
|
||||
- **No streaming body inspection.** Middlewares see a bounded prefix; SSE / chunked parsing happens against that prefix in the response middleware.
|
||||
- **No per-spec priority.** Order is registration order in the spec slice.
|
||||
- **No retry / circuit-breaker** on middleware errors. Fail-mode is binary (open/closed) and per-spec.
|
||||
- **Mutations cannot rewrite the request URL path or query** — only `RewriteUpstream` can change scheme/host (+ optional path replacement, see `types.go:218-235`).
|
||||
- **Redaction is best-effort.** Explicitly documented in `redaction.go:8-13`. Not a DLP solution.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Upstream wire shape: [../modules/10-shared-api.md](10-shared-api.md) (Spec/RawConfig encoding from management).
|
||||
- Built-in middlewares using this framework: [../modules/31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md).
|
||||
- Runtime wiring (where `Manager`, `Chain`, and `bodytap` are consumed by the HTTP handler): [../modules/33-proxy-runtime.md](33-proxy-runtime.md).
|
||||
- End-to-end request flow including capture + chain dispatch: [../01-end-to-end-flows.md](../01-end-to-end-flows.md).
|
||||
- Top-level architecture: [../00-overview.md](../00-overview.md).
|
||||
365
docs/agent-networks/modules/31-proxy-middleware-builtin.md
Normal file
365
docs/agent-networks/modules/31-proxy-middleware-builtin.md
Normal file
@@ -0,0 +1,365 @@
|
||||
# proxy/middleware-builtin — the LLM chain
|
||||
|
||||
The registry-mounted middleware set the proxy executes on every agent-network
|
||||
LLM request. The two highest-blast-radius areas are the **capture-pointer
|
||||
semantics** and the **limit_check ⇒ limit_record** record-once invariant.
|
||||
|
||||
Sibling module: [32-proxy-llm-parsers.md](./32-proxy-llm-parsers.md) — the SDK
|
||||
adapters + pricing catalog this chain delegates to.
|
||||
|
||||
---
|
||||
|
||||
## Module boundary
|
||||
|
||||
This module is the registry-mounted middleware set the proxy executes on
|
||||
every agent-network LLM request. Each sub-package registers itself via
|
||||
`init()`
|
||||
([builtin.go:32–34](../../../proxy/internal/middleware/builtin/builtin.go));
|
||||
the proxy server anonymous-imports the set
|
||||
([all_test.go:11–19](../../../proxy/internal/middleware/builtin/all_test.go))
|
||||
so the registry is populated at boot. The chain is wired by the management
|
||||
synthesiser and executed by the framework
|
||||
(`proxy/internal/middleware/{chain,dispatcher,accumulator}.go` — both out
|
||||
of scope). Everything here reads from / writes to one envelope: the
|
||||
`middleware.KV` metadata bag plus `middleware.Mutations` for header/body
|
||||
rewrites.
|
||||
|
||||
## The 8 middlewares
|
||||
|
||||
| Name | Slot | Inputs (metadata read) | Outputs (metadata written) | Side effects |
|
||||
|---|---|---|---|---|
|
||||
| `llm_request_parser` | OnRequest | `Input.{URL,Body,BodyTruncated}` | `llm.{provider,model,stream,request_prompt_raw,capture_truncated}` | none |
|
||||
| `llm_router` | OnRequest | `llm.model`, `Input.{URL,UserGroups}` | `llm.{resolved_provider_id,authorising_groups}`, `llm_policy.{decision,reason}` | upstream rewrite + auth strip/inject |
|
||||
| `llm_limit_check` | OnRequest | `llm.{resolved_provider_id,model}`, `Input.{AccountID,UserID,UserGroups}` | `llm.{selected_policy_id,attribution_group_id,attribution_window_seconds}`, `llm_policy.{decision,reason}` | gRPC `CheckLLMPolicyLimits` |
|
||||
| `llm_identity_inject` | OnRequest | `llm.{resolved_provider_id,authorising_groups}`, `Input.{UserEmail,UserID,UserGroups,UserGroupNames}` | none | header strip/inject + optional body rewrite |
|
||||
| `llm_guardrail` | OnRequest | `llm.{model,request_prompt_raw}` | `llm_policy.{decision,reason}`, `llm.request_prompt` | none (model allowlist deny) |
|
||||
| `llm_response_parser` | OnResponse | `llm.provider`, `Input.{RespHeaders,RespBody,Status}` | `llm.{input,output,total,cached_input,cache_creation}_tokens`, `llm.response_completion` | none |
|
||||
| `cost_meter` | OnResponse | `llm.{provider,model}`, token buckets | `cost.usd_total` or `cost.skipped` | pricing lookup |
|
||||
| `llm_limit_record` | OnResponse | `llm.{attribution_group_id,attribution_window_seconds,input_tokens,output_tokens}`, `cost.usd_total` | none | gRPC `RecordLLMUsage` |
|
||||
|
||||
[all_test.go:26–40](../../../proxy/internal/middleware/builtin/all_test.go)
|
||||
locks the ID set; adding or removing one is a conscious extension.
|
||||
|
||||
## Files
|
||||
|
||||
| File | LOC | Notes |
|
||||
|---|---:|---|
|
||||
| `builtin.go` | 86 | Registry + `FactoryContext` (ctx, data dir, meter, logger, mgmt client) |
|
||||
| `all_test.go` | 41 | Locks the 8-ID registry surface |
|
||||
| `agentnetwork_chain_integration_test.go` | 319 | Live sqlite + real gRPC bufconn; gate→recorder wire path |
|
||||
| `llm_request_parser/*` | 162 / 66 / 356 | Provider detection, body parse, prompt extraction with capture-pointer gating |
|
||||
| `llm_router/*` | 385 / 84 / 586 | Three-pass route selection (model → groups → path-prefix) |
|
||||
| `llm_limit_check/*` | 196 / 38 / 182 | Pre-flight `CheckLLMPolicyLimits` (2s, fail-open) |
|
||||
| `llm_identity_inject/*` | 440 / 108 / 666 | HeaderPair (LiteLLM) + JSONMetadata (Portkey) + ExtraHeaders |
|
||||
| `llm_guardrail/*` | 176 / 82 / 75 / 219 / 217 | Model allowlist + optional prompt capture with PII redaction |
|
||||
| `llm_response_parser/*` | 258 / 222 / 43 / 433 / 169 / 111 | Buffered + SSE accumulation; AWS event-stream accumulator (`streaming_bedrock.go`) for Bedrock; capture-pointer gates completion emit |
|
||||
| `cost_meter/*` | 181 / 84 / 439 | Token → USD via `proxy/internal/llm/pricing` |
|
||||
| `llm_limit_record/*` | 144 / 35 / 191 | Post-flight `RecordLLMUsage` (5s, debug-on-error) |
|
||||
|
||||
## Per-middleware
|
||||
|
||||
### llm_request_parser
|
||||
|
||||
Detects the LLM provider via `llm.DetectParser` (URL sniff) or by name via
|
||||
`llm.ParserByName` when synthesiser stamps `provider_id`
|
||||
([middleware.go:96–99](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)).
|
||||
**Path-routed providers short-circuit first:** `parseVertexPath` and
|
||||
`parseBedrockPath` ([middleware.go:85–94](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go))
|
||||
pull the model + vendor out of the URL before parser selection runs — Vertex
|
||||
from `/v1/projects/.../publishers/{pub}/models/{model}:{action}` (publisher →
|
||||
vendor via `vertexPublisherVendor`), Bedrock from `/model/{id}/{action}` with
|
||||
`normalizeBedrockModel` stripping the region prefix + version suffix. See
|
||||
[50-path-routed-providers.md](./50-path-routed-providers.md) for the full path
|
||||
grammar. For body-routed providers it decodes the body into `RequestFacts`
|
||||
(model + stream) and extracts the prompt. On
|
||||
`capture_prompt=true` (or absent — see capture-pointer semantics below) the
|
||||
prompt is run through `llm_guardrail.RedactPII` when `redact_pii=true` and
|
||||
truncated rune-safely to 3500 bytes
|
||||
([middleware.go:109–122](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)).
|
||||
**Key invariant:** redaction is parser-side, not guardrail-side — access-log
|
||||
reads `llm.request_prompt_raw` directly.
|
||||
|
||||
### llm_router
|
||||
|
||||
Three-pass route selection in `matchRoute`
|
||||
([middleware.go:241–300](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)):
|
||||
filter by `Models` claim → vendor-pin (a vendor-tagged request never crosses to
|
||||
another vendor's route) → filter by `AllowedGroupIDs` intersection → model
|
||||
precedence over path → tie-break by longest `UpstreamPath` prefix match.
|
||||
Model-miss returns `llm_policy.model_not_routable`; known-but-unauthorised
|
||||
returns `llm_policy.no_authorised_provider`. **Key invariant:** auth-header
|
||||
strip+inject rides on `UpstreamRewrite.{StripHeaders,AuthHeader}`
|
||||
([middleware.go:606–646](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
— NOT `HeadersAdd/HeadersRemove` — because the framework's mutation gate
|
||||
blocks `Authorization` on the generic header path.
|
||||
|
||||
**Path-routed providers route before the model table.** `Invoke` checks
|
||||
`isVertexPath` / `isBedrockPath`
|
||||
([middleware.go:138–216](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
ahead of the model lookup, so a path-carried model can't be claimed by a
|
||||
same-vendor body-routed provider. `matchPathRoute` enforces the route's `Models`
|
||||
allowlist (empty = catch-all) even though the model came from the URL.
|
||||
Two path-only behaviours:
|
||||
- **Vertex unmeterable publisher** — when `llm_request_parser` emits no
|
||||
`llm.provider` (e.g. Gemini/`google`), the router denies with
|
||||
`llm_policy.unmeterable_publisher` (403) rather than forward it uncounted.
|
||||
- **GCP token minting** — when the route carries `GCPServiceAccountKeyB64`
|
||||
(set from a `keyfile::` api_key), `gcpBearer` mints + caches a short-lived
|
||||
OAuth2 token per request instead of injecting a static value; a bad key or
|
||||
unreachable token endpoint denies with `llm_policy.upstream_auth_failed`
|
||||
(502). Bedrock uses its static bearer token directly (no minting).
|
||||
- **`/bedrock` prefix** — an optional `/bedrock` gateway-namespace prefix is
|
||||
accepted and stripped via `RewriteUpstream.StripPathPrefix` so the native
|
||||
`/model/...` path reaches the upstream.
|
||||
|
||||
Full treatment in [50-path-routed-providers.md](./50-path-routed-providers.md).
|
||||
|
||||
### llm_limit_check
|
||||
|
||||
Pre-flight gate. Reads `llm.resolved_provider_id`, calls
|
||||
`CheckLLMPolicyLimits` with a 2s context timeout
|
||||
([middleware.go:24, 97–106](../../../proxy/internal/middleware/builtin/llm_limit_check/middleware.go)),
|
||||
on allow stamps `llm.selected_policy_id`, `llm.attribution_group_id`,
|
||||
`llm.attribution_window_seconds`. **Key invariant:** fail-open. Nil
|
||||
`MgmtClient`, empty provider id, or RPC error returns `allowNoAttribution()`
|
||||
— management outage doesn't take down every LLM request. Operators audit via
|
||||
the access-log; a future flag may switch this to fail-closed.
|
||||
|
||||
### llm_identity_inject
|
||||
|
||||
Dispatches per-rule between LiteLLM-shaped `HeaderPair`
|
||||
([middleware.go:169](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go))
|
||||
and Portkey-shaped `JSONMetadata`
|
||||
([middleware.go:292](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go)).
|
||||
Identity is the peer's email (or `UserID` fallback); tags are the
|
||||
**authorising-groups intersection** emitted by `llm_router`, not the full
|
||||
`UserGroups` — a peer in 5 groups authorised under 1 only tags as that 1.
|
||||
**Anti-spoof:** every `HeadersAdd` is preceded by a `HeadersRemove` of the
|
||||
same name; the framework runs `Remove` before `Add` so client-supplied
|
||||
identity never reaches the upstream. Body-level inject (`tags_in_body`,
|
||||
`end_user_id_in_body`) is skipped on empty / truncated / non-JSON bodies so
|
||||
header attribution stays intact.
|
||||
|
||||
### llm_guardrail
|
||||
|
||||
Model allowlist deny + optional prompt-capture-with-redaction. Allowlist
|
||||
match is case-insensitive via `normaliseModel`; empty allowlist disables the
|
||||
check. Prompt capture reads `llm.request_prompt_raw` and emits
|
||||
`llm.request_prompt` only when `prompt_capture.enabled`
|
||||
([middleware.go:149–165](../../../proxy/internal/middleware/builtin/llm_guardrail/middleware.go)).
|
||||
**Key invariant:** `RedactPII` is the exported function the parsers call —
|
||||
single PII contract across all three keys.
|
||||
|
||||
### llm_response_parser
|
||||
|
||||
Buffered and SSE paths share one `Invoke`
|
||||
([middleware.go:102–127](../../../proxy/internal/middleware/builtin/llm_response_parser/middleware.go)):
|
||||
content-type sniffing dispatches to `invokeBuffered` (JSON, status<400) or
|
||||
`invokeStreaming` (text/event-stream, partial bodies tolerated). Streaming
|
||||
delegates to `accumulateStream`
|
||||
([streaming.go:21–30](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go))
|
||||
using `llm.NewScanner`. A third path, `accumulateBedrockStream`
|
||||
([streaming_bedrock.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go)),
|
||||
decodes the AWS binary event-stream (`application/vnd.amazon.eventstream`)
|
||||
returned by Bedrock's `-stream` actions — InvokeModel `chunk` frames wrap a
|
||||
base64 Anthropic event, Converse frames carry text + a trailing usage block.
|
||||
Cached / cache-creation buckets emit only when non-zero, preserving the existing
|
||||
token schema.
|
||||
|
||||
### cost_meter
|
||||
|
||||
Reads `llm.provider` + `llm.model` + token buckets, looks up per-1k rate via
|
||||
`pricing.Loader`, emits `cost.usd_total` or a closed-set `cost.skipped`
|
||||
reason (`missing_provider/model/tokens`, `unparseable_tokens`, `zero_tokens`,
|
||||
`unknown_model`). Loader's hot-reload goroutine is bound to proxy-lifetime
|
||||
context via `startReloader`. **Key invariant:** provider-shape switch lives
|
||||
in `pricing.Table.Cost` (sibling doc) — `cost_meter` stays provider-agnostic.
|
||||
|
||||
### llm_limit_record
|
||||
|
||||
Post-flight write. Always returns `DecisionAllow`; response has already been
|
||||
served so RPC errors mustn't surface (logged at `Debugf`). Skip-on-no-signal
|
||||
at line 81 (zero tokens + zero cost). **Key invariant:** the
|
||||
skip-on-missing-attribution guard at line 98 is a safety net independent of
|
||||
the framework's deny short-circuit — if the gate denied and the framework
|
||||
still runs the recorder, the recorder skips on absent
|
||||
`UserID`+`groupID`+`UserGroups` and no phantom counter materialises.
|
||||
|
||||
## Full-chain diagram (canonical order)
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[HTTP request] --> B[llm_request_parser<br/>OnRequest]
|
||||
B -->|llm.provider, llm.model,<br/>llm.stream, llm.request_prompt_raw| C[llm_router<br/>OnRequest]
|
||||
C -->|llm.resolved_provider_id,<br/>llm.authorising_groups,<br/>upstream rewrite + auth| D[llm_limit_check<br/>OnRequest]
|
||||
D -->|deny path| Z1[403 llm_policy.*]
|
||||
D -->|allow + llm.selected_policy_id,<br/>llm.attribution_group_id,<br/>llm.attribution_window_seconds| E[llm_identity_inject<br/>OnRequest]
|
||||
E -->|header strip+inject<br/>+ optional body rewrite| F[llm_guardrail<br/>OnRequest]
|
||||
F -->|deny: model_blocked| Z2[403 llm_policy.model_blocked]
|
||||
F -->|allow + llm.request_prompt| G[upstream LLM call]
|
||||
G --> H[llm_response_parser<br/>OnResponse]
|
||||
H -->|llm.{input,output,total,cached_input,cache_creation}_tokens,<br/>llm.response_completion| I[cost_meter<br/>OnResponse]
|
||||
I -->|cost.usd_total or cost.skipped| J[llm_limit_record<br/>OnResponse]
|
||||
J --> K[response to client]
|
||||
```
|
||||
|
||||
## limit_check ⇒ limit_record record-once invariant
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant LC as llm_limit_check
|
||||
participant M as management gRPC
|
||||
participant U as upstream LLM
|
||||
participant LR as llm_limit_record
|
||||
participant DB as sqlite consumption table
|
||||
|
||||
LC->>M: CheckLLMPolicyLimits (2s)
|
||||
alt allow
|
||||
M-->>LC: selected_policy_id, attribution_group_id, window_s
|
||||
LC->>U: stamps attribution metadata
|
||||
U-->>LR: response + tokens (via llm_response_parser + cost_meter)
|
||||
LR->>M: RecordLLMUsage (5s, debug-on-error)
|
||||
M->>DB: increment (user, group, window) row
|
||||
else deny
|
||||
M-->>LC: llm_policy.token_cap_exceeded
|
||||
Note over LR: framework short-circuits; even if invoked,<br/>recorder skips on absent UserID+groupID+UserGroups
|
||||
else mgmt nil / rpc error
|
||||
LC-->>LC: allowNoAttribution() — fail open
|
||||
Note over LR: no window_s ⇒ recorder books only account-level<br/>budget rules (which run independently)
|
||||
end
|
||||
```
|
||||
|
||||
The integration test
|
||||
[agentnetwork_chain_integration_test.go](../../../proxy/internal/middleware/builtin/agentnetwork_chain_integration_test.go)
|
||||
exercises all three branches against a real sqlite store + bufconn gRPC —
|
||||
no mocks. Tests: `TestChain_AllowPath_StampsAttributionAndRecordsCounter`
|
||||
(line 130), `TestChain_DenyPath_GateRejectsAndNoConsumptionWritten` (line
|
||||
207), `TestChain_CapExhaustTransition` (line 265).
|
||||
|
||||
## Public contracts (per-middleware JSON config)
|
||||
|
||||
| Middleware | Config shape |
|
||||
|---|---|
|
||||
| `llm_request_parser` | `{provider_id?, redact_pii?, capture_prompt?: *bool}` ([factory.go:19–37](../../../proxy/internal/middleware/builtin/llm_request_parser/factory.go)) |
|
||||
| `llm_router` | `{providers: [{id, models, upstream_scheme, upstream_host, upstream_path?, auth_header_name, auth_header_value, allowed_group_ids}]}` |
|
||||
| `llm_limit_check` | `{}` — pulls `MgmtClient` from `FactoryContext` |
|
||||
| `llm_identity_inject` | `{providers: [{provider_id, header_pair?|json_metadata?, extra_headers?}]}` |
|
||||
| `llm_guardrail` | `{model_allowlist: []string, prompt_capture: {enabled, redact_pii}}` |
|
||||
| `llm_response_parser` | `{redact_pii?, capture_completion?: *bool}` |
|
||||
| `cost_meter` | `{pricing_path?}` (basename inside data-dir; defaults `pricing.yaml`) |
|
||||
| `llm_limit_record` | `{}` — same pattern as `llm_limit_check` |
|
||||
|
||||
All factories accept empty / null / `{}` / whitespace as zero-value config;
|
||||
only structurally invalid JSON is rejected so misconfig surfaces at chain
|
||||
build time.
|
||||
|
||||
## Invariants
|
||||
|
||||
1. **limit_check ↔ limit_record paired.** They MUST appear together. Gate
|
||||
stamps attribution metadata on the request leg; recorder reads it on the
|
||||
response leg. If a chain contains only the recorder, the
|
||||
skip-on-missing-attribution guard at
|
||||
[llm_limit_record/middleware.go:81–87, 98–103](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go)
|
||||
keeps counters consistent but no enforcement runs. Only-gate means
|
||||
counters never tick and headroom appears infinite.
|
||||
|
||||
2. **`capture_prompt` / `capture_completion` pointer semantics.** Both are
|
||||
`*bool`. `nil` = "preserve legacy emit" (back-compat default for
|
||||
non-agent-network callers and pre-toggle tests). `false` = suppress the
|
||||
key entirely (access-log row carries zero prompt / completion content).
|
||||
`true` = emit. The synthesiser sets the pointer explicitly to the
|
||||
account's `EnablePromptCollection` toggle. The handling lives
|
||||
in [llm_request_parser/factory.go:55–61](../../../proxy/internal/middleware/builtin/llm_request_parser/factory.go)
|
||||
and the symmetric [llm_response_parser/middleware.go:62–68](../../../proxy/internal/middleware/builtin/llm_response_parser/middleware.go);
|
||||
a missing pointer must not be treated as `false` (that would suppress
|
||||
capture for legacy non-agent-network callers).
|
||||
`redact_pii` is an orthogonal `bool` controlling **form** of emitted
|
||||
content, not whether it's emitted.
|
||||
|
||||
3. **`redact_pii` is parser-side.** Both parsers import
|
||||
`llm_guardrail.RedactPII` and run it BEFORE stamping the metadata bag.
|
||||
Load-bearing because the access-log sink reads `llm.request_prompt_raw`
|
||||
and `llm.response_completion` directly — by the time `llm_guardrail`
|
||||
runs its own pass on `llm.request_prompt`, the raw key has already been
|
||||
stamped. Tests: `TestInvoke_RedactPii_RedactsBeforeEmittingRawPrompt`,
|
||||
`TestInvoke_RedactPii_RedactsCompletionBeforeEmit`.
|
||||
|
||||
4. **Metadata allowlist enforcement.** Every middleware declares
|
||||
`MetadataKeys()`. The framework accumulator drops any KV outside that
|
||||
allowlist. When adding a new key, also extend the docstring in
|
||||
`middleware/keys.go`.
|
||||
|
||||
5. **Closed deny-code set.** All deny paths emit one of:
|
||||
`llm_policy.model_not_routable`, `llm_policy.no_authorised_provider`,
|
||||
`llm_policy.model_blocked`, `llm_policy.token_cap_exceeded`,
|
||||
`llm_policy.unmeterable_publisher` (path-routed Vertex publisher with no
|
||||
parser → 403), `llm_policy.upstream_auth_failed` (GCP token mint failure →
|
||||
502), or the management-supplied code on `llm_limit_check`. These surface
|
||||
verbatim; arbitrary middleware text never reaches the wire.
|
||||
|
||||
## Things to scrutinise
|
||||
|
||||
**Correctness.** `llm_router` model match treats an empty `Models` slice as
|
||||
"claim every model"
|
||||
([middleware.go:238–248](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
for gateway-style providers — confirm no real provider record ships with an
|
||||
empty `Models` by accident. Path-prefix tie-break falls back to declaration
|
||||
order when no candidate prefix-matches, so the synthesiser must emit a
|
||||
deterministic order. `llm_limit_record` discards `strconv.ParseInt` errors
|
||||
([middleware.go:78–80](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go))
|
||||
— relies on `llm_response_parser` always emitting parseable values; spot-check
|
||||
the streaming partial path on truncated bodies.
|
||||
|
||||
**Security.** Auth headers must NEVER appear on `Mutations.HeadersAdd/Remove`
|
||||
for the router — a direct headers path would bypass the framework gate. The
|
||||
capture-pointer handling is the kind of place a bug ships PII to logs
|
||||
silently; every synthesiser config path must set the pointer explicitly.
|
||||
`llm_identity_inject` body inject silently skips on a
|
||||
non-object `metadata` field
|
||||
([middleware.go:262–270](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go))
|
||||
— header path still attributes, but body-level tag-budget enforcement
|
||||
doesn't run for that request.
|
||||
|
||||
**Concurrency.** `cost_meter` shares a `pricing.Loader` via
|
||||
`atomic.Pointer[Table]`; readers always see a consistent table. Every
|
||||
middleware is a stateless value receiver. Integration test uses real bufconn
|
||||
gRPC — race detector is the meaningful bar.
|
||||
|
||||
**Perf.** Hot path is `lookupKV` linear scan over <10 KVs; `cost_meter.Cost`
|
||||
is O(1); SSE accumulation is single-pass. No map allocation per call.
|
||||
|
||||
**Observability.** Every deny stamps `llm_policy.decision=deny` and a
|
||||
matching `llm_policy.reason` — access-log can pivot on either.
|
||||
`llm_limit_record` only logs at `Debugf` on RPC failure
|
||||
([middleware.go:125–130](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go));
|
||||
operators need an alternate signal (metric on `RecordLLMUsage` failures) for
|
||||
counter accuracy.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| File | Tests | Notes |
|
||||
|---|---:|---|
|
||||
| `all_test.go` | 1 | Registry surface lock |
|
||||
| `agentnetwork_chain_integration_test.go` | 3 | Allow/deny/cap-exhaust vs live sqlite + bufconn gRPC |
|
||||
| `llm_request_parser/middleware_test.go` | 18 | `provider_id` bypass, redaction, capture-pointer, rune-safe truncation |
|
||||
| `llm_router/middleware_test.go` | 19 | Three-pass match, deny codes, path-prefix tie-break, header strip+inject |
|
||||
| `llm_limit_check/middleware_test.go` | 6 | Allow/deny, fail-open on nil mgmt / RPC error, attribution stamping |
|
||||
| `llm_identity_inject/middleware_test.go` | 28 | HeaderPair, JSONMetadata, ExtraHeaders, body inject, anti-spoof |
|
||||
| `llm_guardrail/middleware_test.go` | 15 | Allowlist case-insensitivity, prompt capture toggle, deny shape |
|
||||
| `llm_guardrail/redact_test.go` | 15 | Email, SSN, phone (E.164 + NA), bearer, IPv4; fixture-driven |
|
||||
| `llm_response_parser/middleware_test.go` | 18 | Buffered OAI+Anthro, capture-pointer, redact, truncation |
|
||||
| `llm_response_parser/streaming_test.go` | 7 | OAI usage frame, Anthro message_delta, truncated body best-effort |
|
||||
| `cost_meter/middleware_test.go` | 17 | Each skip reason, provider-shape, pricing loader integration |
|
||||
| `llm_limit_record/middleware_test.go` | 7 | Skip-on-no-signal, skip-on-missing-attribution, RPC failure swallowed |
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Sibling: [32-proxy-llm-parsers.md](./32-proxy-llm-parsers.md) — SDK adapters
|
||||
+ SSE framer + pricing loader.
|
||||
- Path-routed providers (Vertex AI + Bedrock), `keyfile::` credential, GCP
|
||||
token minting, `/bedrock` prefix:
|
||||
[50-path-routed-providers.md](./50-path-routed-providers.md).
|
||||
- Upstream config: `management/server/agentnetwork/synthesizer` (out of scope).
|
||||
- Framework: `proxy/internal/middleware/{chain,dispatcher,accumulator,registry}.go`.
|
||||
- Metadata key registry: `proxy/internal/middleware/keys.go`.
|
||||
- gRPC surface: `proto.ProxyServiceClient.{CheckLLMPolicyLimits,RecordLLMUsage}`.
|
||||
392
docs/agent-networks/modules/32-proxy-llm-parsers.md
Normal file
392
docs/agent-networks/modules/32-proxy-llm-parsers.md
Normal file
@@ -0,0 +1,392 @@
|
||||
# proxy/llm-parsers — SDK adapters + pricing + SSE
|
||||
|
||||
The runtime-agnostic LLM library: the OpenAI Responses API (`/v1/responses`)
|
||||
and the older Chat Completions API (`/v1/chat/completions`), the Anthropic
|
||||
Messages API (`/v1/messages`), the SSE wire format (`event:` / `data:` lines,
|
||||
`\n\n` framing, CRLF tolerance), and per-provider token accounting (OpenAI's
|
||||
cached-prompt **subset** vs Anthropic's cache_read **additive** model). The
|
||||
pricing table's per-provider cost formula is the highest-leverage place a
|
||||
small bug would silently mis-bill operators.
|
||||
|
||||
Sibling module: [31-proxy-middleware-builtin.md](./31-proxy-middleware-builtin.md)
|
||||
— the 8 middlewares that consume this package's parsers + pricing loader.
|
||||
|
||||
---
|
||||
|
||||
## Module boundary
|
||||
|
||||
`proxy/internal/llm` is the runtime-agnostic LLM library shared by every
|
||||
middleware that needs to understand provider-specific shapes. Zero
|
||||
proxy-framework dependencies:
|
||||
|
||||
- `parser.go` — `Parser` interface, `Provider` enum, public factories
|
||||
(`Parsers`, `DetectParser`, `ParserByName`).
|
||||
- `openai.go` / `anthropic.go` / `bedrock.go` — per-provider `Parser` impls.
|
||||
- `sse.go` — SSE scanner (`Scanner`, `Event`, `NewScanner`).
|
||||
- `errors.go` — sentinels callers branch on with `errors.Is`.
|
||||
- `pricing/` — embedded-default + hot-reload override table with
|
||||
symlink-safe Unix loader (build-tagged stub elsewhere).
|
||||
- `fixtures/` — captured request/response/stream bodies the tests replay.
|
||||
|
||||
The package carries zero proxy-framework dependencies so the same parsers can
|
||||
be reused later by a WASM adapter
|
||||
([parser.go:1–6](../../../proxy/internal/llm/parser.go)).
|
||||
|
||||
## Files
|
||||
|
||||
| File | LOC | Notes |
|
||||
|---|---:|---|
|
||||
| `parser.go` | 104 | Interface + factories + `Provider{Unknown,OpenAI,Anthropic}` enum |
|
||||
| `openai.go` | 347 | Chat Completions + Completions + Responses API; cached_tokens subset |
|
||||
| `openai_test.go` | 222 | 11 tests; fixture replay + cached/Responses-API matrix |
|
||||
| `anthropic.go` | 172 | Messages + legacy `/v1/complete`; cache_read + cache_creation additive |
|
||||
| `anthropic_test.go` | 154 | 7 tests including streaming-extraction-skipped contract |
|
||||
| `bedrock.go` | 190 | AWS Bedrock InvokeModel (snake_case) + Converse (camelCase) response shapes; model lives in URL path |
|
||||
| `bedrock_test.go` | — | InvokeModel + Converse usage shapes; AWS event-stream content-type → `ErrStreamingUnsupported` on buffered `ParseResponse` |
|
||||
| `sse.go` | 117 | `bufio`-backed scanner; CRLF normalised; trailing-event handling |
|
||||
| `sse_test.go` | 175 | 12 tests; fixture replay + multiline + size limits |
|
||||
| `parser_test.go` | 53 | `Parsers()`, `DetectParser`, provider enum values |
|
||||
| `errors.go` | 31 | 6 sentinels: `Err{Unknown,Unsupported}Provider/Model`, `Err{NotLLM,Malformed}Response`, `ErrStreamingUnsupported`, `ErrMalformedRequest` |
|
||||
| `pricing/pricing.go` | 421 | `Loader`, `Table`, `Entry`; embedded defaults + atomic swap + mtime reload |
|
||||
| `pricing/pricing_unix.go` | 69 | `O_NOFOLLOW` + fstat-from-FD + 1 MiB cap |
|
||||
| `pricing/pricing_other.go` | 21 | Stub returning "not supported on this platform" |
|
||||
| `pricing/pricing_test.go` | 432 | 21 tests — symlink rejection, reload race, path traversal, oversize |
|
||||
| `pricing/defaults_pricing.yaml` | 85 | go:embed source of truth |
|
||||
| `fixtures/*` | 21–59 | OAI chat/responses/stream + Anthro messages/stream + pricing starter |
|
||||
|
||||
## Request body → parser dispatch
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[HTTP request<br/>URL + JSON body] --> B{ParserByName?<br/>provider_id config set}
|
||||
B -- yes --> P[matched Parser]
|
||||
B -- no --> C[DetectParser]
|
||||
C --> D{loop Parsers<br/>OpenAIParser, AnthropicParser}
|
||||
D -- DetectFromURL match --> P
|
||||
D -- no match --> X[ok=false<br/>middleware skips]
|
||||
P --> E[ParseRequest body]
|
||||
E -->|err: ErrMalformedRequest| Y[middleware emits provider only]
|
||||
E --> F[RequestFacts<br/>model + stream]
|
||||
P --> G[ExtractPrompt body]
|
||||
G --> H[joinMessages<br/>extractContentParts<br/>decodeStringOrJoin]
|
||||
H --> I[prompt text<br/>or empty]
|
||||
F --> J[stamps llm.model + llm.stream]
|
||||
I --> K[stamps llm.request_prompt_raw<br/>subject to capture_prompt gate]
|
||||
```
|
||||
|
||||
OpenAI's URL hints
|
||||
([openai.go:27–33](../../../proxy/internal/llm/openai.go)) include
|
||||
both `/v1/chat/completions` and the bare `/chat/completions` — the latter
|
||||
covers Cloudflare AI Gateway, which rewrites the canonical version segment.
|
||||
Anthropic's hints are `/v1/messages` and `/v1/complete`
|
||||
([anthropic.go:14–17](../../../proxy/internal/llm/anthropic.go)).
|
||||
Both implementations use case-insensitive substring matching so a proxy prefix
|
||||
strip / rewrite doesn't defeat detection.
|
||||
|
||||
`ParserByName` ([parser.go:93–103](../../../proxy/internal/llm/parser.go))
|
||||
is the **agent-network bypass**: the synthesiser knows which parser to use
|
||||
because it built the synth service from the catalog, so it stamps
|
||||
`provider_id` on the parser config and the middleware skips URL sniffing
|
||||
entirely. This is what makes the same parser set work whether the request
|
||||
flows to OpenAI direct, to LiteLLM, to Portkey, or to any gateway with a
|
||||
non-canonical URL shape.
|
||||
|
||||
**Path-routed providers (Vertex AI, Bedrock) bypass both `ParserByName` and
|
||||
`DetectParser`.** The model and the parser surface live in the URL path, so the
|
||||
request middleware extracts them directly (`parseVertexPath` /
|
||||
`parseBedrockPath`) before the parser-selection step. For Vertex the publisher
|
||||
segment picks the parser (`anthropic` → Anthropic parser; `google`/Gemini →
|
||||
none, request denied as unmeterable). For Bedrock the dedicated `BedrockParser`
|
||||
handles the response. Full treatment in
|
||||
[50-path-routed-providers.md](./50-path-routed-providers.md).
|
||||
|
||||
## Streaming response → SSE chunker → response parser → completion + token count
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant U as upstream LLM
|
||||
participant LR as llm_response_parser<br/>(OnResponse)
|
||||
participant S as llm.NewScanner<br/>(SSE framer)
|
||||
participant P as Parser-specific accumulator<br/>(accumulateOpenAIStream<br/>or accumulateAnthropicStream)
|
||||
|
||||
U-->>LR: text/event-stream<br/>(buffered prefix in RespBody)
|
||||
LR->>S: NewScanner(bytes.NewReader(body))
|
||||
loop until EOF or [DONE]
|
||||
S-->>LR: Event{Type, Data}
|
||||
LR->>P: dispatch per event.Type<br/>(OpenAI: data-only<br/>Anthropic: named events)
|
||||
P-->>P: accumulate completion text<br/>track usage from final frame
|
||||
end
|
||||
P-->>LR: llm.Usage + completion string
|
||||
LR->>LR: appendUsage stamps<br/>llm.{input,output,total,cached_input,cache_creation}_tokens
|
||||
LR->>LR: truncateCompletion(3500 bytes, rune-safe)
|
||||
LR->>LR: redactPII if redact_pii && captureCompletion
|
||||
```
|
||||
|
||||
`Scanner.Next`
|
||||
([sse.go:44–87](../../../proxy/internal/llm/sse.go)) returns one
|
||||
event per `\n\n` boundary; multiple `data:` lines join with `\n`; comment lines
|
||||
(starting with `:`) are skipped per the SSE spec; a trailing event without a
|
||||
closing blank line is still returned before `io.EOF` so a server that closes
|
||||
the connection cleanly doesn't lose the last frame
|
||||
([sse.go:55–58](../../../proxy/internal/llm/sse.go)). CRLF is
|
||||
normalised in `trimEOL` so fixtures captured from live servers replay
|
||||
unchanged.
|
||||
|
||||
## Per-provider
|
||||
|
||||
### OpenAI
|
||||
|
||||
[openai.go:54–67](../../../proxy/internal/llm/openai.go) defines
|
||||
`openAIRequest` with three prompt fields: `messages` (Chat Completions),
|
||||
`prompt` (legacy), `input` (Responses API). The decoder uses
|
||||
`json.RawMessage` so each shape is parsed lazily.
|
||||
|
||||
`ParseResponse`
|
||||
([openai.go:117–146](../../../proxy/internal/llm/openai.go))
|
||||
accepts both naming conventions: Chat Completions returns
|
||||
`prompt_tokens`/`completion_tokens`, Responses API returns
|
||||
`input_tokens`/`output_tokens`. `pickInt64` prefers Responses-API names and
|
||||
falls back — same parser handles both endpoints without per-route config.
|
||||
`openAICachedTokens` mirrors the fallback for
|
||||
`input_tokens_details.cached_tokens` vs `prompt_tokens_details.cached_tokens`.
|
||||
|
||||
**Key invariant:** `CachedInputTokens` for OpenAI is a SUBSET of
|
||||
`InputTokens`. The cost meter clamps to guard against malformed upstream
|
||||
responses where `cached > total`.
|
||||
|
||||
### Anthropic
|
||||
|
||||
[anthropic.go:37–49](../../../proxy/internal/llm/anthropic.go)
|
||||
defines `anthropicRequest` covering Messages API (`system` + `messages[]`)
|
||||
and legacy `/v1/complete` (`prompt` string). `ExtractPrompt` emits
|
||||
`system: <text>` first when present, then per-message `role: content`.
|
||||
|
||||
`ParseResponse`
|
||||
([anthropic.go:82–104](../../../proxy/internal/llm/anthropic.go))
|
||||
fills three independent token buckets: `InputTokens`, `CacheReadInputTokens`,
|
||||
`CacheCreationInputTokens`. Latter two are **additive** (not subset).
|
||||
`TotalTokens` sums all four so downstream dashboards render one "tokens"
|
||||
number without double-counting.
|
||||
|
||||
`ExtractCompletion` walks `content[]` `{type, text}` parts and concatenates
|
||||
non-empty text with newlines, falling back to legacy `completion`.
|
||||
|
||||
### Bedrock
|
||||
|
||||
[bedrock.go](../../../proxy/internal/llm/bedrock.go) implements the
|
||||
`Parser` interface for the AWS Bedrock runtime. Bedrock is **path-routed**: the
|
||||
model lives in the URL (`/model/{id}/{action}`), so the request middleware
|
||||
extracts it (see [50-path-routed-providers.md](./50-path-routed-providers.md))
|
||||
and `ParseRequest` is a deliberate no-op. The parser's real work is on the
|
||||
response leg, covering both Bedrock body shapes:
|
||||
|
||||
- **InvokeModel** — vendor-native. Anthropic-on-Bedrock returns snake_case usage
|
||||
(`input_tokens`, `output_tokens`, `cache_read_input_tokens`,
|
||||
`cache_creation_input_tokens`) with the same additive cache buckets as
|
||||
first-party Anthropic.
|
||||
- **Converse** — unified camelCase (`inputTokens`, `outputTokens`,
|
||||
`totalTokens`). `firstNonZero` folds the two naming conventions into one
|
||||
`Usage`; when Converse omits `totalTokens` the parser sums the buckets.
|
||||
|
||||
`ProviderName()` returns `"bedrock"` — its own `defaults_pricing.yaml` block,
|
||||
keyed by the **normalised** model id (region prefix + version suffix stripped by
|
||||
the request parser). `ParseResponse` returns `ErrStreamingUnsupported` for an
|
||||
AWS binary event-stream content-type (`application/vnd.amazon.eventstream`,
|
||||
`isAWSEventStream`) so the caller routes to the streaming accumulator instead.
|
||||
|
||||
### SSE framing
|
||||
|
||||
`Scanner` is `bufio`-backed, 64 KiB read buffer, 1 MiB max line so a
|
||||
malicious upstream can't blow process memory
|
||||
([sse.go:33–38, 97–100](../../../proxy/internal/llm/sse.go)).
|
||||
`splitField` strips one space after the `:` per the SSE spec. Documented
|
||||
`not safe for concurrent use`; every consumer creates a fresh scanner per
|
||||
response body. Streaming accumulators live in the middleware package
|
||||
([llm_response_parser/streaming.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go))
|
||||
but use `llm.NewScanner` so the framing contract stays here.
|
||||
|
||||
### Pricing catalog
|
||||
|
||||
`Table.Cost`
|
||||
([pricing.go:129–174](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
is the cost formula — most security-relevant math in this module:
|
||||
|
||||
| Provider | Formula |
|
||||
|---|---|
|
||||
| `openai` | `(inTokens − clamped) × InputPer1K + clamped × CachedInputPer1K + outTokens × OutputPer1K` where `clamped = min(cachedInput, inTokens)` |
|
||||
| `anthropic`, `bedrock` | `inTokens × InputPer1K + cachedInput × CacheReadPer1K + cacheCreation × CacheCreationPer1K + outTokens × OutputPer1K` |
|
||||
| default | `inTokens × InputPer1K + outTokens × OutputPer1K` |
|
||||
|
||||
`bedrock` shares the Anthropic additive-cache formula
|
||||
([pricing.go:172-174](../../../proxy/internal/llm/pricing/pricing.go)):
|
||||
Anthropic-on-Bedrock reports the same additive cache buckets, while non-Anthropic
|
||||
Bedrock models (Nova, Llama) simply report zero in those buckets so cost reduces
|
||||
to `input + output`.
|
||||
|
||||
Each per-bucket rate falls back to `InputPer1K` when zero — operators opt in
|
||||
to discounts by setting the field.
|
||||
|
||||
`Loader`
|
||||
([pricing.go:212–268](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
overlays an optional `pricing.yaml` from data-dir on top of the go:embed
|
||||
defaults. Atomic pointer swap means readers never observe a partial update.
|
||||
The mtime-poll reloader (30s default cadence) keeps the previous table on
|
||||
parse failure so cost annotation never goes blank during a botched edit.
|
||||
|
||||
`defaults_pricing.yaml` is the source of truth for built-in pricing.
|
||||
Operator overrides only carry the entries they want to change.
|
||||
|
||||
## Public contracts
|
||||
|
||||
**`Parser` interface**
|
||||
([parser.go:50–66](../../../proxy/internal/llm/parser.go)):
|
||||
|
||||
```go
|
||||
type Parser interface {
|
||||
Provider() Provider
|
||||
ProviderName() string
|
||||
DetectFromURL(path string) bool
|
||||
ParseRequest(body []byte) (RequestFacts, error)
|
||||
ParseResponse(status int, contentType string, body []byte) (Usage, error)
|
||||
ExtractPrompt(body []byte) string
|
||||
ExtractCompletion(status int, contentType string, body []byte) string
|
||||
}
|
||||
```
|
||||
|
||||
Adding a provider means implementing this interface and appending to the
|
||||
slice returned by `Parsers()` ([parser.go:78–84](../../../proxy/internal/llm/parser.go)).
|
||||
Order matters: `DetectFromURL` ties resolve by registration order.
|
||||
`Parsers()` today returns `{OpenAIParser, AnthropicParser, BedrockParser}`.
|
||||
|
||||
**`Provider` enum**
|
||||
([parser.go:8–18](../../../proxy/internal/llm/parser.go)):
|
||||
`ProviderUnknown = 0`, `ProviderOpenAI = 1`, `ProviderAnthropic = 2`,
|
||||
`ProviderBedrock = 3`. Numeric values are persisted in nothing today but treat
|
||||
them as wire-stable — new providers must take fresh numbers.
|
||||
|
||||
**`Pricing` lookup**
|
||||
([pricing.go:129](../../../proxy/internal/llm/pricing/pricing.go)):
|
||||
|
||||
```go
|
||||
func (t *Table) Cost(provider, model string, inTokens, outTokens, cachedInput, cacheCreation int64) (float64, bool)
|
||||
```
|
||||
|
||||
Nil-safe: `t.Cost` on a nil receiver returns `(0, false)`
|
||||
([pricing.go:130–132](../../../proxy/internal/llm/pricing/pricing.go)).
|
||||
`ok=false` means provider or model is absent from the loaded table; the caller
|
||||
emits `cost.skipped=unknown_model`.
|
||||
|
||||
## Invariants
|
||||
|
||||
1. **Cross-platform pricing build.** `pricing_unix.go` carries the only
|
||||
functional `loadPricing` (uses `syscall.O_NOFOLLOW` and `f.Stat()` on an
|
||||
open descriptor — both Unix-only). `pricing_other.go` is a build-tag
|
||||
fallback that returns `"not supported on this platform"`
|
||||
([pricing_other.go:14–16](../../../proxy/internal/llm/pricing/pricing_other.go)).
|
||||
The proxy is Linux-only in production today; a Windows port needs an
|
||||
equivalent path-as-handle implementation. Reviewers building on Windows
|
||||
should expect this surface to return an error at startup if an override
|
||||
file is configured.
|
||||
|
||||
2. **SSE scanner handles partial chunks.** A buffered prefix that doesn't end
|
||||
in `\n\n` still yields its accumulated event before `io.EOF`
|
||||
([sse.go:55–58](../../../proxy/internal/llm/sse.go)). Tests:
|
||||
`TestSSEScanner_OpenAIFixture`, `TestSSEScanner_AnthropicFixture`,
|
||||
`TestSSEScanner_MultilineData`, `TestSSEScanner_CRLF`. The streaming
|
||||
accumulators ride on this: `accumulateAnthropicStream` and
|
||||
`accumulateOpenAIStream` `break` on any scanner error to return partial
|
||||
usage rather than aborting
|
||||
([streaming.go:68–73, 144–150](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go)).
|
||||
|
||||
3. **`defaults_pricing.yaml` is the source of truth.** Compiled into the
|
||||
binary via `//go:embed`
|
||||
([pricing.go:29–30](../../../proxy/internal/llm/pricing/pricing.go)).
|
||||
`DefaultTable()` parses once and panics on parse failure
|
||||
([pricing.go:42–49](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
— by design: a broken embedded YAML must not ship to production.
|
||||
|
||||
4. **Loader path validation.** `resolveMiddlewareDataPath`
|
||||
([pricing.go:370–394](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
rejects absolute paths, traversal segments, and basenames that fail
|
||||
`basenameRegex = ^[a-zA-Z0-9._-]+$`. The resolved path must remain
|
||||
inside `baseDir` even after `filepath.Clean`. Tests:
|
||||
`TestNewLoader_PathValidation`, `TestNewLoader_PathValidation_Extended`,
|
||||
`TestNewLoader_SymlinkOutsideBaseDirRejected`, `TestNewLoader_SymlinkRejected`.
|
||||
|
||||
5. **Unix loader symlink safety.** `O_NOFOLLOW` on open, `f.Stat()` on the
|
||||
open descriptor (never re-stat by path), `info.Mode().IsRegular()` check,
|
||||
`io.LimitReader(f, maxPricingBytes+1)` with a final size assertion
|
||||
([pricing_unix.go:25–57](../../../proxy/internal/llm/pricing/pricing_unix.go)).
|
||||
A mid-read symlink swap is detected because the fstat is on the original
|
||||
fd. Test: `TestNewLoader_RejectsOversizedFile_FixesM4`.
|
||||
|
||||
6. **`yaml.NewDecoder(...).KnownFields(true)`**
|
||||
([pricing.go:397–398](../../../proxy/internal/llm/pricing/pricing.go))
|
||||
rejects YAML files that carry fields not in the schema. A typo in an
|
||||
operator override file fails loud instead of silently zeroing rates.
|
||||
|
||||
## Things to scrutinise
|
||||
|
||||
**Correctness.** Verify OpenAI cached-prompt clamp at
|
||||
[pricing.go:147–149](../../../proxy/internal/llm/pricing/pricing.go)
|
||||
short-circuits before subtraction. `Anthropic.TotalTokens` sums all four
|
||||
buckets (in + out + cache_read + cache_creation) — downstream dashboards
|
||||
need to know this differs from `input + output`.
|
||||
`OpenAIParser.ExtractPrompt` falls through `messages → input → prompt`; a
|
||||
request sending all three reports only `messages` (uncommon but worth
|
||||
noting).
|
||||
|
||||
**Security.** `Scanner.maxLine = 1 MiB`; a 2 MiB single-line `data:` event
|
||||
errors from `Scanner.Next` and both accumulators stop with partial usage.
|
||||
Pricing file 1 MiB cap is orders of magnitude larger than realistic. Confirm
|
||||
new schema additions are mirrored in both `pricingFile` and `Entry`;
|
||||
`KnownFields(true)` will reject silently-typo'd operator overrides
|
||||
otherwise.
|
||||
|
||||
**Concurrency.** `Loader.table` is `atomic.Pointer[Table]`; readers never
|
||||
block or see a torn table. `Loader.Reload` is one goroutine, cancelled via
|
||||
context (`TestLoader_ReloadBackgroundLoopCancellation`). `DefaultTable()`
|
||||
uses `sync.Once`. Per-call `Scanner` instances mean no shared state across
|
||||
concurrent response-parser calls.
|
||||
|
||||
**Perf.** `Table.Cost` is two map lookups + multiplications, O(1).
|
||||
`Scanner.Next` is one `ReadString('\n')` per line. Pricing reload poll 30s.
|
||||
|
||||
**Observability.** Reload failures count via `metric.Int64Counter` keyed
|
||||
`plugin`; warning log rate-limited at 5 min so a broken file doesn't flood.
|
||||
Parser errors return sentinels — middleware uses `errors.Is` to map to the
|
||||
right `cost.skipped` reason.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| File | Tests | Coverage highlights |
|
||||
|---|---:|---|
|
||||
| `parser_test.go` | 3 | `Parsers()` shape lock, `DetectParser` URL matrix, provider enum stability |
|
||||
| `openai_test.go` | 11 | Chat Completions + Responses API + legacy `prompt`; cached-tokens subset for both naming conventions; fixture replays |
|
||||
| `anthropic_test.go` | 7 | Messages + legacy `/v1/complete`; streaming REJECTED on `ParseResponse` (must use scanner); fixture replays |
|
||||
| `sse_test.go` | 12 | Fixture replay both providers; multiline `data:`; CRLF; comment skip; trailing-event-without-blank-line; oversize rejection |
|
||||
| `pricing/pricing_test.go` | 21 | Provider-shape switch; cached-rate fallback; cached-clamp; symlink rejection (target outside basedir + symlink to file); path validation matrix; oversize rejection; reload-keeps-previous-on-parse-error; mtime change detection; goroutine cancellation |
|
||||
|
||||
**Fixtures** ([proxy/internal/llm/fixtures/](../../../proxy/internal/llm/fixtures/)):
|
||||
`openai_chat_completion.json` (chat.completions with usage),
|
||||
`openai_responses.json` (Responses API shape),
|
||||
`openai_stream.txt` (3 deltas + usage + `[DONE]`),
|
||||
`anthropic_messages.json` (Messages API non-streaming),
|
||||
`anthropic_stream.txt` (full 7-event sequence: message_start →
|
||||
content_block_{start,delta×2,stop} → message_delta (usage) → message_stop),
|
||||
`pricing.yaml` (realistic-pricing starter for operator overrides).
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Sibling: [31-proxy-middleware-builtin.md](./31-proxy-middleware-builtin.md)
|
||||
— the chain that calls `llm.Parsers()`, `llm.ParserByName`,
|
||||
`llm.NewScanner`, `pricing.NewLoader`.
|
||||
- Path-routed providers (Vertex AI + Bedrock), credential syntax, and the
|
||||
Bedrock AWS event-stream accumulator:
|
||||
[50-path-routed-providers.md](./50-path-routed-providers.md).
|
||||
- Direct callers: `llm_request_parser/middleware.go:82–94`,
|
||||
`llm_response_parser/middleware.go:113–123`,
|
||||
`llm_response_parser/streaming.go:65, 142`, `cost_meter/factory.go:49–57`.
|
||||
- Related elsewhere: the agent-network synthesiser stamping `provider_id`
|
||||
is covered in the management-side module guide; proxy server boot +
|
||||
`FactoryContext` construction is covered in the proxy-framework guide.
|
||||
194
docs/agent-networks/modules/33-proxy-runtime.md
Normal file
194
docs/agent-networks/modules/33-proxy-runtime.md
Normal file
@@ -0,0 +1,194 @@
|
||||
# proxy/runtime — translate + serve + log
|
||||
|
||||
> **Risk level:** High — every config push from management is translated here, and the chain runs on every HTTP request to a synth target.
|
||||
> **Backward-compat impact:** Additive at the wire (`PathTargetOptions.middlewares`, `agent_network`, `disable_access_log`, capture caps) and on the proxy `Server` struct (`MiddlewareDataDir`, `MiddlewareCaptureBudgetBytes`). Non-agent-network targets stay on the no-middleware fast path.
|
||||
|
||||
## Module boundary
|
||||
|
||||
Turns the synth-service wire format from `ProxyService.SyncMappings`/`GetMappingUpdate` into in-process middleware chains and runs them on top of the existing `httputil.ReverseProxy`. Four concerns: (a) **translate** — `proto.MiddlewareConfig` → validated `middleware.Spec` (proxy/middleware_translate.go) + self-register the eight built-ins (proxy/middleware_register.go); (b) **boot + rebuild** — construct the `middleware.Manager`, share the OTel meter, install the live-service check, rebuild per-path chains on every `addMapping`/`modifyMapping` (proxy/server.go); (c) **serve** — resolve chain at request time, capture bodies under a global budget, invoke `RunRequest`/`RunResponse`/`RunTerminal`, render deny responses, apply `UpstreamRewrite` (proxy/internal/proxy/reverseproxy.go); (d) **log + tag** — emit access-log entries with the new `agent_network` flag, gate emission on `EnableLogCollection` via `DisableAccessLog` (proxy/internal/accesslog).
|
||||
|
||||
**Inert for non-agent-network targets**: nil or empty chain → existing fast path (reverseproxy.go:127-139); `SuppressAccessLog` defaults false so the access-log middleware emits unchanged.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| ---- | ---- |
|
||||
| proxy/middleware_translate.go | proto→Spec translation; slot/failmode/timeout mapping; caps |
|
||||
| proxy/middleware_translate_test.go | translator unit tests |
|
||||
| proxy/middleware_register.go | blank-imports the eight builtins for `init()` registration |
|
||||
| proxy/server.go | `initMiddlewareManager`, `rebuildMiddlewareChains`, `isLiveService`, `buildMiddlewareBindings`, new Server fields, `protoToMapping` stamps AgentNetwork/DisableAccessLog/CaptureConfig/Middlewares |
|
||||
| proxy/internal/proxy/reverseproxy.go | `WithMiddlewareManager`, chain dispatch, body capture, `applyUpstreamRewrite`/`Headers`, `buildRequestInput`, response-leg respInput identity fields |
|
||||
| proxy/internal/proxy/reverseproxy_test.go | `TestBuildRequestInput_PropagatesIdentityAndGroups` |
|
||||
| proxy/internal/proxy/context.go | `agentNetwork`, `suppressAccessLog`, `userGroupNames` on `CapturedData` |
|
||||
| proxy/internal/proxy/servicemapping.go | new `PathTarget` fields |
|
||||
| proxy/internal/proxy/agent_network_chain_realstack_test.go | end-to-end self-contained chain test |
|
||||
| proxy/internal/accesslog/logger.go | `logEntry.AgentNetwork` → `proto.AccessLog` |
|
||||
| proxy/internal/accesslog/middleware.go | reads `GetAgentNetwork()`; gates `l.log` on `!GetSuppressAccessLog()` |
|
||||
| proxy/internal/accesslog/middleware_test.go | suppress/default/preserves-usage assertions |
|
||||
| proxy/internal/auth/middleware_test.go | tunnel-peer group propagation contract |
|
||||
| proxy/internal/metrics/metrics.go | `Meter()` getter for the middleware manager |
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### Synth-service ingestion → translate → register → serve
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Management SyncMappings/GetMappingUpdate] --> B["processMappings\nserver.go:1492"]
|
||||
B --> C{Mapping type}
|
||||
C -->|CREATED| D["addMapping → setupHTTPMapping → updateMapping"]
|
||||
C -->|MODIFIED| E["modifyMapping → cleanupMappingRoutes → setupHTTPMapping → updateMapping"]
|
||||
C -->|REMOVED| F["removeMapping → cleanupMappingRoutes → invalidateMiddlewareChains"]
|
||||
D --> G["protoToMapping\nserver.go:2181"]
|
||||
E --> G
|
||||
G --> H["translateMiddlewareConfigs\nmiddleware_translate.go:55"]
|
||||
G --> I["translateMiddlewareCaptureConfig\nmiddleware_translate.go:18"]
|
||||
H --> J["[]middleware.Spec on PathTarget"]
|
||||
I --> K["*bodytap.Config on PathTarget"]
|
||||
J --> L["proxy.AddMapping\nservicemapping.go:118"]
|
||||
K --> L
|
||||
L --> M["rebuildMiddlewareChains\nserver.go:2017 → Manager.Rebuild"]
|
||||
F --> N["Manager.Invalidate(serviceID)"]
|
||||
```
|
||||
|
||||
### Per-request lifecycle through the chain + accesslog
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
participant C as Client
|
||||
participant M as accesslog.Middleware
|
||||
participant A as auth.Middleware (Protect)
|
||||
participant RP as ReverseProxy.ServeHTTP
|
||||
participant CH as middleware.Chain
|
||||
participant U as Upstream
|
||||
C->>M: HTTP request
|
||||
M->>M: NewCapturedData(requestID), WithCapturedData(ctx)
|
||||
M->>A: next.ServeHTTP
|
||||
A->>A: Private → ValidateTunnelPeer → stamp UserID/Email/Groups/GroupNames/AuthMethod
|
||||
A->>RP: next.ServeHTTP
|
||||
RP->>RP: findTargetForRequest → targetResult
|
||||
RP->>RP: stamp ServiceID/AccountID/AgentNetwork/SuppressAccessLog on CapturedData
|
||||
RP->>RP: resolveChain via Manager.ChainFor
|
||||
alt chain == nil or Empty
|
||||
RP->>U: httputil.ReverseProxy.ServeHTTP (fast path)
|
||||
else chain non-empty
|
||||
RP->>RP: bodytap.CaptureRequest (global budget)
|
||||
RP->>CH: RunRequest
|
||||
CH-->>RP: denyOutput? requestMeta + upstreamRewrite
|
||||
alt deny
|
||||
RP->>C: RenderDenyResponse
|
||||
else allow
|
||||
RP->>RP: capturingWriter + applyUpstreamRewrite/Headers
|
||||
RP->>U: httputil.ReverseProxy.ServeHTTP(respWriter)
|
||||
U-->>RP: response
|
||||
RP->>CH: RunResponse (respInput carries UserGroups)
|
||||
RP->>CH: RunTerminal (merged request+response metadata)
|
||||
end
|
||||
end
|
||||
RP-->>M: handler returns
|
||||
M->>M: build logEntry incl. AgentNetwork
|
||||
alt SuppressAccessLog == true
|
||||
M->>M: skip l.log; still trackUsage
|
||||
else default
|
||||
M->>M: l.log → goroutine SendAccessLog
|
||||
end
|
||||
```
|
||||
|
||||
### EnableLogCollection suppression path
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
S["agentnetwork.Settings.EnableLogCollection"] --> B["synthesizer: target.DisableAccessLog = !EnableLogCollection"]
|
||||
B --> P["proto PathTargetOptions.disable_access_log (field 13)"]
|
||||
P --> T["protoToMapping reads GetDisableAccessLog()\nserver.go:2211"]
|
||||
T --> M["PathTarget.DisableAccessLog\nservicemapping.go:47"]
|
||||
M --> R["ServeHTTP: cd.SetSuppressAccessLog\nreverseproxy.go:106"]
|
||||
R --> G["accesslog middleware: if !GetSuppressAccessLog l.log\nmiddleware.go:95"]
|
||||
R --> U["trackUsage unconditional — bandwidth telemetry preserved"]
|
||||
```
|
||||
|
||||
**Ingestion** lands as a `ProxyMapping` batch on `handleSyncMappingsStream`/`handleMappingStream`. `processMappings` dispatches to `addMapping`/`modifyMapping`/`removeMapping`; HTTP goes `setupHTTPMapping → updateMapping → protoToMapping`. `protoToMapping` (server.go:2181) is the single translation surface that materialises `[]middleware.Spec`, `*bodytap.Config`, `AgentNetwork`, `DisableAccessLog` onto each `PathTarget`; `updateMapping` finishes with `s.proxy.AddMapping(m)` (atomic swap under `mappingsMux`) and `s.rebuildMiddlewareChains(svcID, m)`.
|
||||
|
||||
At **request time** the access-log middleware stamps `CapturedData`; the auth chain runs (Private services lift `peer_group_ids` from `ValidateTunnelPeer` — auth/middleware_test.go:322). `ReverseProxy.ServeHTTP` resolves the chain; nil or empty → original `httputil.ReverseProxy`, no body capture. When a chain matches, body is captured under the global budget, `RunRequest` produces an `UpstreamRewrite` (`llm_router` selects a provider, rewrites scheme/host/path, injects `Authorization`), and `RunResponse`+`RunTerminal` run after the upstream returns. The terminal slot sees the merged metadata bag — that's how `llm_limit_record` ships the consumption sample. The **access-log** addition: `logEntry.AgentNetwork` from `GetAgentNetwork()` onto `proto.AccessLog.AgentNetwork`; the gate at middleware.go:95 honors `EnableLogCollection`, skipping `l.log` but keeping `trackUsage` so bandwidth telemetry survives.
|
||||
|
||||
## Public contracts touched
|
||||
|
||||
- `proxy.Server.MiddlewareDataDir` (string) — base dir for file-backed middleware config (server.go:238-241).
|
||||
- `proxy.Server.MiddlewareCaptureBudgetBytes` (int64) — process-wide capture cap; defaults to 256 MiB (server.go:248-250).
|
||||
- `proxy/internal/proxy.WithMiddlewareManager(*middleware.Manager) Option` — new option on `NewReverseProxy`; nil keeps the fast path (reverseproxy.go:48-56).
|
||||
- `proxy/internal/proxy.PathTarget` adds `Middlewares`, `CaptureConfig`, `AgentNetwork`, `DisableAccessLog` (servicemapping.go:27-51), all zero-default.
|
||||
- `proxy/internal/proxy.CapturedData` adds `agentNetwork`, `suppressAccessLog`, `userGroupNames` behind `sync.RWMutex`; slices deep-copied (context.go:47-66, 183-258).
|
||||
- `accesslog.logEntry.AgentNetwork` + `proto.AccessLog.AgentNetwork` (logger.go:131, 268).
|
||||
- `metrics.Metrics.Meter()` exposes the OTel meter for the middleware manager (metrics.go:53-58).
|
||||
|
||||
## Invariants
|
||||
|
||||
- **Synth-service updates are live (no proxy restart).** Every `MODIFIED` flows through `modifyMapping → cleanupMappingRoutes` (invalidates chains) `→ setupHTTPMapping → updateMapping → rebuildMiddlewareChains`. **ProxyMapping.Private preservation:** the relevant logic lives in `management/internals/shared/grpc/proxy.go:shallowCloneMapping`, not this module, but it surfaces here — if a `MODIFIED` synth service arrives `private=false`, auth skips `ValidateTunnelPeer`, `CapturedData.UserGroups` stays empty, and `llm_router` denies with `llm_policy.no_authorised_provider` until a management restart re-pushes the snapshot. This module assumes `mapping.GetPrivate()` is correct on every batch.
|
||||
- **`EnableLogCollection=false` suppresses access-log writes but middleware still runs.** Gate is one `if !cd.GetSuppressAccessLog()` immediately around `l.log(entry)` (middleware.go:95); `trackUsage` runs below the gate. Locked by `TestMiddleware_SuppressAccessLog_PreservesUsageTracking` (middleware_test.go:139).
|
||||
- **`agent_network` flag on access-log entries is set when the chain processed the request.** Source `target.AgentNetwork`, stamped at reverseproxy.go:105, read at accesslog/middleware.go:86.
|
||||
- **auth → builtin group propagation.** `Protect` writes `UserGroups`/`UserGroupNames`; `buildRequestInput` (reverseproxy.go:333) copies them into `middleware.Input`. The response-leg `respInput` (reverseproxy.go:196-223) also carries `UserEmail`/`UserGroups`/`UserGroupNames` — `llm_limit_record` needs `UserGroups` to ship `group_ids` so management's group-targeted budget rules match (comment at reverseproxy.go:211-215).
|
||||
- **Empty chains stay on the fast path.** `ServeHTTP` skips body capture and the run sequence when `chain == nil || chain.Empty()` (reverseproxy.go:127).
|
||||
- **Self-registration is the only way a builtin reaches the registry.** `middleware_register.go` blank-imports each builtin; `init()` adds the factory to `mwbuiltin.DefaultRegistry()`. Missing it → translator drops the entry with a warn (translate.go:97).
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
- **Translate edge cases** — drops on nil cfg, empty ID, unknown ID, UNSPECIFIED slot; each logs one warn; volume bounded by `MaxMiddlewaresPerChain`.
|
||||
- **Re-translate without dropping in-flight requests** — `Manager.Rebuild` is the only call from `rebuildMiddlewareChains`. Reverse proxy reads `ChainFor` once per request (reverseproxy.go:327) and runs the captured `*Chain` for the whole request. Verify in module 30 that `Rebuild` swaps atomically.
|
||||
- **ProxyMapping.Private preservation** — enforced management-side in `shallowCloneMapping`. Proxy-side regression catches: `TestProtect_PrivateService_TunnelPeerGroupsPropagate` + the integration test.
|
||||
- **Body-capture cleanup** — `defer releaseBudget()` (reverseproxy.go:145) and `defer capturingWriter.Release()` (reverseproxy.go:180) must run on every return; confirm no future `return` lands between acquisition and defer.
|
||||
- **`applyUpstreamRewrite` clones the URL** — `cloned := *orig` value-copies `*url.URL`; safe because overwritten fields are strings, not slices/maps (reverseproxy.go:285-292).
|
||||
|
||||
### Security
|
||||
- **Translate validates every config** — registry membership rejects unknown IDs; UNSPECIFIED slot drops; ID-less drops; raw config copied (not aliased) at translate.go:109.
|
||||
- **`AuthHeader`/`StripHeaders` only reachable via `UpstreamRewrite`** — regular mutation surface goes through the framework denylist (`Authorization`/`Cookie` blocked); only the router middleware can replace `Authorization` (reverseproxy.go:296-304). Confirm in module 30 nothing outside the proxy-trusted path populates `UpstreamRewrite.AuthHeader`.
|
||||
- **`stampNetBirdIdentity` strips client-sent values first** (reverseproxy.go:742-743) — anti-spoof for `X-NetBird-User`/`X-NetBird-Groups`; control chars filtered; comma-bearing labels dropped (reverseproxy_test.go:1217/:1243/:1193).
|
||||
- **Auth → group propagation** — `auth/middleware_test.go:322` and `:366` cover the contract. If auth ever stops calling `ValidateTunnelPeer` for Private services, every agent-network request silently denies.
|
||||
|
||||
### Concurrency
|
||||
- **Chain replacement under in-flight requests** — `findTargetForRequest` takes `mappingsMux.RLock`; `AddMapping` writes. `resolveChain` calls `ChainFor` once; even if `Rebuild` swaps mid-request, in-flight requests keep running on the captured pointer.
|
||||
- **`CapturedData` mutation across slots** — accessors take `sync.RWMutex`; slices deep-copied on both Set and Get. Verify no caller mutates the returned slice expecting it to land back.
|
||||
- **`Manager.Invalidate` race** — `removeMapping` invalidates after `cleanupMappingRoutes`; mapping read happens before chain resolution, so requests before invalidate run captured chains; later ones fail `findTargetForRequest`.
|
||||
- **`Logger.log` goroutine** — `logSem` caps at `maxLogWorkers = 4096`; overflow → `dropped.Add(1)` + debug log. Middleware test uses a buffered channel and 150ms negative-assertion window — review whether 150ms holds on slow CI.
|
||||
|
||||
### Backward compatibility
|
||||
- **Non-agent-network services unaffected** — `protoToMapping` reads new fields only when `opts != nil`; defaults leave `Middlewares`/`CaptureConfig` nil → chain resolves nil → fast path. Existing `reverseproxy_test.go` (non-chain) still passes.
|
||||
- **`disable_access_log` is proto field 13, default false** — every existing target unset; gate is no-op. Locked by `TestMiddleware_SuppressAccessLog_DefaultEmitsLog` (middleware_test.go:104).
|
||||
- **`Server` additions optional** — 256 MiB default when `MiddlewareCaptureBudgetBytes ≤ 0` (server.go:1997-2000).
|
||||
|
||||
### Performance
|
||||
- **Translate cost per push** — O(n) with per-entry registry lookup and `config_json` copy; negligible vs. the upstream gRPC unmarshal.
|
||||
- **Empty-chain hot path** — one `ChainFor` map lookup + one `chain.Empty()` check; no allocation delta vs. pre-PR.
|
||||
- **Body capture buffer churn** — `bodytap.CaptureRequest` allocates `MaxRequestBytes` per chain-hitting request; `releaseBudget` ties allocation to the 256 MiB proxy-wide budget. Confirm in module 30 the budget is a hard cap.
|
||||
|
||||
### Observability
|
||||
- **Metrics** — `Metrics.Meter()` shared with `middleware.NewMetrics` (server.go:1990-1993) so middleware instruments land in the same prometheus exporter. No new metrics defined here.
|
||||
- **Access-log accuracy** — every entry carries `AgentNetwork`; terminal-slot metadata merged into `CapturedData.Metadata` (reverseproxy.go:238-241).
|
||||
- **Deny logs at `Infof`** (reverseproxy.go:170) — review whether `Info` is too noisy at high deny rates; consider Debug or rate-limit.
|
||||
|
||||
## Test coverage
|
||||
|
||||
| Test file | Locks down |
|
||||
| --------- | ---------- |
|
||||
| proxy/middleware_translate_test.go | Empty/nil → nil; field preservation; unknown ID skip; nil registry permissive; timeout clamping; fail-mode + slot incl. UNSPECIFIED-drop; empty-ID drop; truncation above + at `MaxMiddlewaresPerChain` |
|
||||
| proxy/internal/proxy/reverseproxy_test.go | Rewrite host/headers/cookies/query; trusted proxy; path forwarding; classifyProxyError; X-NetBird-User/Groups anti-spoof + CSV-join + control-char/comma rejection + fallback-to-ID; `TestBuildRequestInput_PropagatesIdentityAndGroups` (UserGroups/Email/GroupNames/AgentNetwork reach `middleware.Input`) |
|
||||
| proxy/internal/proxy/agent_network_chain_realstack_test.go | **The end-to-end integration test.** Drives a real agent-network request through `ReverseProxy.ServeHTTP` with the chain the synthesizer produces, against an in-process management gRPC (bufconn) backed by a real sqlite store + real `agentnetwork.Manager`, plus an `httptest` upstream — no external infrastructure or real LLM. Guarantees: (1) response-leg `respInput` carries `UserGroups` so `llm_limit_record` ships non-empty `group_ids` and the admin-group consumption row increments; (2) `RedactPii=true` redacts both prompt and completion on captured metadata; (3) the full chain runs against a real management stack. **Line 189-211 inlines the proto→Spec mapping** instead of calling the proxy's private `translateMiddlewareConfig` — keep that inline mirror in sync with `proxy/middleware_translate.go` or the test silently diverges from production. |
|
||||
| proxy/internal/accesslog/middleware_test.go | `SuppressAccessLog=true` skips `SendAccessLog` (150ms negative wait); default emits one send (2s positive); usage tracking runs under suppression |
|
||||
| proxy/internal/auth/middleware_test.go | `TestProtect_PrivateService_TunnelPeerGroupsPropagate` proves `peer_group_ids` reach `CapturedData.UserGroups`; `TestProtect_PrivateService_TunnelPeerDenied` proves rejected peers 403 without reaching the handler |
|
||||
|
||||
The integration test runs in a few seconds with no external infrastructure — exercising the real synthesizer, `Manager.Rebuild`, `ServeHTTP` dispatch, and `llm_limit_record` writing a real consumption row through the real `agentnetwork.Manager` over real gRPC.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- **Translator does not validate `RawConfig` JSON** — factory's job at `New([]byte)`. Confirm in module 30 that a per-binding factory failure doesn't poison the rest of the chain.
|
||||
- **No throttle on management push rate** — every `MODIFIED` triggers `Manager.Rebuild`. Mitigation upstream.
|
||||
- **Streaming responses (SSE)** — body capture is streaming-aware, but response-leg middleware runs only after the response completes; long SSE streams delay `llm_limit_record` until close.
|
||||
- **OIDC-only path doesn't carry tunnel-peer groups** — agent-network synth services rely on the Private tunnel-peer path; JWT groups claim is the only carrier for non-Private OIDC.
|
||||
- **`agent_network` flag on L4 entries** not added; HTTP-only.
|
||||
- **`mw.capture.bypass_reason` metadata key** documented at reverseproxy.go:151,184; namespace this in module 30/31 to avoid collisions.
|
||||
|
||||
## Cross-references
|
||||
- Upstream: [shared/api](10-shared-api.md), [proxy/middleware-framework](30-proxy-middleware-framework.md), [proxy/middleware-builtin](31-proxy-middleware-builtin.md), [proxy/llm-parsers](32-proxy-llm-parsers.md)
|
||||
- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level: [../00-overview.md](../00-overview.md)
|
||||
228
docs/agent-networks/modules/40-dashboard.md
Normal file
228
docs/agent-networks/modules/40-dashboard.md
Normal file
@@ -0,0 +1,228 @@
|
||||
# dashboard — UI for agent-networks
|
||||
|
||||
This module documents code that lives in the **dashboard repo** (under
|
||||
`src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`), not
|
||||
in this repo. It is co-located here so backend readers see the full picture.
|
||||
|
||||
> **Risk level:** Medium. The new surface is isolated under `src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`, but it also reshapes the sidebar, splits `/peers`, renames `reverse-proxy/clusters` → `self-hosted-proxies`, and overlays the Control Center graph. Regressions here would be cross-cutting.
|
||||
> **Backward-compat impact:** Additive on the API side. Breaking on URL/navigation: `/peers` redirects to `/peers/devices` (src/app/(dashboard)/peers/page.tsx:7-15), `/reverse-proxy/clusters` was renamed to `/reverse-proxy/self-hosted-proxies`, the sidebar lost Access Control / Networks / Reverse Proxy / DNS / standalone Guardrails / Consumption / Activity (Navigation.tsx:165-171 — routes still resolve via URL), and the standalone `/agent-network/{access-log,consumption,global-controls}` routes are gone in favor of `/agent-network/observability`.
|
||||
|
||||
## Module boundary
|
||||
|
||||
The dashboard is the only place an operator interacts with agent-networks: provider catalog, configured providers, policies, guardrails, account-level budget rules, account settings (collection / redaction toggles), per-request access log, and consumption rollups all render, paginate, and edit here. Data flows in via SWR (`useFetchApi`) keyed by REST URL. One big context provider (`src/modules/agent-network/AIProvidersProvider.tsx`) aggregates five resources (providers, policies, guardrails, budget rules, settings) plus the proxy access-log stream filtered to `agent_network=true`, and exposes `add* / update* / toggle* / delete*` mutators that call through `useApiCall` and re-`mutate()` SWR. Pages mount the provider once at the top and compose presentational tables and modals beneath. The control-center page additionally fetches `/agent-network/{providers,policies}` directly (control-center/page.tsx:123-130) to overlay graph nodes.
|
||||
|
||||
## What the UI delivers
|
||||
|
||||
- **AI Observability** page with four tabs: Access Logs, Budget Dashboard,
|
||||
Budget Settings, Log Settings (replaces the standalone access-log,
|
||||
consumption, and global-controls routes).
|
||||
- **Providers** page: provider catalog + connect/edit wizard with per-vendor
|
||||
copy (LiteLLM, Portkey, Bifrost, Cloudflare, Vercel, OpenRouter, custom).
|
||||
- **Policies** page: group → provider authorization with per-policy Limits
|
||||
(minute-granular windows) + guardrail attach.
|
||||
- **Guardrails** page: reusable model-allowlist + prompt-capture sets.
|
||||
- **Account controls**: Log Collection / Prompt Collection / Redact PII toggles.
|
||||
- **Budget rules**: account-level rules reusing the policy Limits UI.
|
||||
- **Control Center overlay**: provider + agent-policy nodes on the graph.
|
||||
- **Navigation + peers reshaping**: peers split into Devices / Agents,
|
||||
`reverse-proxy/clusters` renamed to `self-hosted-proxies`, sidebar
|
||||
repackaged for agent-network focus.
|
||||
|
||||
## Surface added
|
||||
|
||||
### New pages
|
||||
|
||||
| Route | Purpose | Backing module(s) |
|
||||
| ----- | ------- | ----------------- |
|
||||
| `/agent-network` | Redirect to `/agent-network/providers` | page.tsx:7-15 |
|
||||
| `/agent-network/providers` | List + connect providers; header surfaces per-account base URL | providers/page.tsx + AgentProvidersTable + AIProviderModal |
|
||||
| `/agent-network/policies` | Group → Provider authorization with per-policy Limits + Guardrail attach | policies/page.tsx + AgentPoliciesTable + AgentPolicyModal |
|
||||
| `/agent-network/guardrails` | Reusable guardrail sets (model allowlist + prompt capture) | guardrails/page.tsx + AgentGuardrailsTable + AgentGuardrailModal |
|
||||
| `/agent-network/observability` | Tabs: Access Logs / Budget Dashboard / Budget Settings / Log Settings | observability/page.tsx |
|
||||
| `/peers/devices`, `/peers/agents` | Split of `/peers`, shared via `PeersListView` keyed by `kind` | peers/{devices,agents}/page.tsx |
|
||||
| `/reverse-proxy/self-hosted-proxies` | Renamed from `clusters` | self-hosted-proxies/page.tsx |
|
||||
|
||||
Removed in favor of `/agent-network/observability`: `/agent-network/access-log`, `/agent-network/consumption`, `/agent-network/global-controls`.
|
||||
|
||||
### New modules under src/modules/agent-network
|
||||
|
||||
| File | Role |
|
||||
| ---- | ---- |
|
||||
| AIProvidersProvider.tsx (~1158 LOC) | Aggregates every agent-network resource via SWR; normalises snake↔camel; exposes mutators; holds wizard-open state |
|
||||
| AIProviderModal.tsx (~1268 LOC) | Connect / edit provider wizard with per-vendor copy (Bifrost, Portkey, LiteLLM, Cloudflare, Vercel, OpenRouter, custom) |
|
||||
| AIProviderLogo + useProviderCatalog | Catalog-driven brand swatch + SWR hook over `/agent-network/catalog/providers` |
|
||||
| AgentPoliciesTable + AgentPolicyModal + AgentPolicyGuardrailsTab + AgentPolicyLimitsTab | Policies; modal has 3 tabs (Rule, Limits, Guardrails) |
|
||||
| AgentGuardrailsTable + AgentGuardrailModal + AgentGuardrailBrowseModal + AgentGuardrailChecksCell | Guardrails CRUD + attach-from-policy |
|
||||
| AgentBudgetRulesTable + AgentBudgetRuleModal | Account-level budget rules; modal reuses AgentPolicyLimitsTab verbatim |
|
||||
| AgentAccountControlsCard | Three account-wide toggles (Log Collection / Prompt Collection / Redact PII) |
|
||||
| AgentAccessLogTable + AgentAccessLogExpandedRow | Access log on `/events/proxy?agent_network=true` |
|
||||
| AgentConsumptionPanel + AgentConsumptionTable | Token + cost panel: charts + counter table |
|
||||
| table/AgentProvidersTable + AgentProviderActionCell | Providers table + per-row actions |
|
||||
| data/mockData.ts | Domain types and a few residual `MOCK_*` constants (see scrutinize) |
|
||||
|
||||
### Touched non-agent-network areas
|
||||
|
||||
- **control-center**: agent-network overlay (provider + agent-policy nodes); removed the All Networks dropdown; hid the Networks tab in FlowSelector (FlowSelector.tsx:9-14 — enum value kept so `?tab=networks` still type-checks); wrapped `ControlCenterView` in `AIProvidersProvider` (page.tsx:73-83); `agentPolicyNode` clicks routed to a separate state slot (page.tsx:1871-1874). New node renderers: nodes/ProviderNode.tsx, nodes/AgentPolicyNode.tsx (registered at utils/nodes.ts:21-22).
|
||||
- **peers**: Split into Devices and Agents sub-routes; shared via `PeersListView` keyed by `kind` (PeersListView.tsx:24-95). New compact-toolbar `UserFilterSelector` (users/UserFilterSelector.tsx).
|
||||
- **reverse-proxy**: Folder rename `clusters/` → `self-hosted-proxies/`; deleted `ClustersFeaturesCell.tsx`, `ClusterTypeIndicator.tsx`; new ReverseProxyClusterTargetSelector for cluster target type; Private toggle on target modal; body-capture knobs removed; new ReverseProxyEventExpandedRow.
|
||||
- **events**: `ReverseProxyEventsUserCell` rewritten with user + peer fallback (ReverseProxyEventsUserCell.tsx:14-21), shared with the access-log table.
|
||||
- **navigation**: Full repackaging in Navigation.tsx — Agent Network items flattened (no collapsible parent), distinct icons per item; Access Control, Networks, Reverse Proxy, DNS, standalone Guardrails, Consumption, Activity removed (still URL-reachable, per lines 165-171).
|
||||
|
||||
## Architecture & flow
|
||||
|
||||
### Page → Provider → Table/Modal hierarchy
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
Nav[Navigation.tsx]
|
||||
Nav --> ProvidersPage[/agent-network/providers/]
|
||||
Nav --> PoliciesPage[/agent-network/policies/]
|
||||
Nav --> GuardrailsPage[/agent-network/guardrails/]
|
||||
Nav --> ObsPage[/agent-network/observability/]
|
||||
|
||||
ProvidersPage --> AIPP1[AIProvidersProvider]
|
||||
PoliciesPage --> AIPP2[AIProvidersProvider]
|
||||
GuardrailsPage --> AIPP3[AIProvidersProvider]
|
||||
ObsPage --> AIPP4[AIProvidersProvider]
|
||||
ObsPage -.wraps.-> GroupsProvider
|
||||
ObsPage -.wraps.-> PeersProvider
|
||||
|
||||
AIPP1 --> ProvTable[AgentProvidersTable]
|
||||
ProvTable --> ProvModal[AIProviderModal]
|
||||
AIPP2 --> PolTable[AgentPoliciesTable]
|
||||
PolTable --> PolModal[AgentPolicyModal]
|
||||
PolModal --> PolGuardTab[AgentPolicyGuardrailsTab]
|
||||
PolModal --> PolLimitsTab[AgentPolicyLimitsTab]
|
||||
PolGuardTab --> GuardBrowse[AgentGuardrailBrowseModal]
|
||||
PolGuardTab --> GuardModal[AgentGuardrailModal]
|
||||
AIPP3 --> GuardTable[AgentGuardrailsTable]
|
||||
GuardTable --> GuardModal
|
||||
AIPP4 --> Tabs[Tabs]
|
||||
Tabs --> AccessLog[AgentAccessLogTable]
|
||||
Tabs --> Consumption[AgentConsumptionPanel]
|
||||
Tabs --> BudgetRules[AgentBudgetRulesTable]
|
||||
Tabs --> AccountCtl[AgentAccountControlsCard]
|
||||
BudgetRules --> BudgetModal[AgentBudgetRuleModal]
|
||||
BudgetModal -.reuses.-> PolLimitsTab
|
||||
```
|
||||
|
||||
### AI Observability tab page
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
Page[AIObservabilityPage] --> RA[RestrictedAccess<br/>permission.services.read]
|
||||
RA --> GP[GroupsProvider]
|
||||
GP --> PP[PeersProvider]
|
||||
PP --> AIP[AIProvidersProvider]
|
||||
AIP --> Tabs[Tabs / TabsList]
|
||||
Tabs --> T1[Access Logs<br/>AgentAccessLogTable]
|
||||
Tabs --> T2[Budget Dashboard<br/>AgentConsumptionPanel]
|
||||
Tabs --> T3[Budget Settings<br/>AgentBudgetRulesTable]
|
||||
Tabs --> T4[Log Settings<br/>AgentAccountControlsCard]
|
||||
T1 -.GET.-> EP[/events/proxy?agent_network=true/]
|
||||
T2 -.GET poll 5s.-> CONS[/agent-network/consumption/]
|
||||
T3 -.GET/PUT.-> BR[/agent-network/budget-rules/]
|
||||
T4 -.GET/PUT.-> ST[/agent-network/settings/]
|
||||
```
|
||||
|
||||
### Data fetch path
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
Page[Page component] --> Prov[AIProvidersProvider]
|
||||
Prov -->|useFetchApi| SWR[(SWR cache<br/>key = URL)]
|
||||
SWR -.GET.-> P[/agent-network/providers/]
|
||||
SWR -.GET.-> POL[/agent-network/policies/]
|
||||
SWR -.GET.-> G[/agent-network/guardrails/]
|
||||
SWR -.GET.-> BR[/agent-network/budget-rules/]
|
||||
SWR -.GET ignoreError.-> ST[/agent-network/settings/]
|
||||
SWR -.GET.-> CAT[/agent-network/catalog/providers/]
|
||||
SWR -.GET pageSize=100.-> EVT[/events/proxy agent_network=true/]
|
||||
Prov --> Mut[useApiCall.post/put/del]
|
||||
Mut -.on success.-> MutateSWR[SWR mutate keys]
|
||||
Prov --> Children[Tables / Modals via useAIProviders]
|
||||
```
|
||||
|
||||
Every list view reaches management through SWR over `/api/agent-network/*`. The provider context maps snake-case payloads to camelCase domain types (`fromAPI`, `policyFromAPI`, `guardrailFromAPI`, `budgetRuleFromAPI`, `settingsFromAPI`, `accessLogFromAPI` — AIProvidersProvider.tsx:138-562) and back via matching `*ToRequest` adaptors. The access log piggy-backs on `/events/proxy` with `agent_network=true&page_size=100` (line 707-709) and decodes LLM-specific fields from per-event `metadata`. Group IDs on events are resolved to current names through the surrounding GroupsProvider catalog (lines 515-521, 717-731) — no extra round trip. Mutators run `*ToRequest`, await `useApiCall.post/put/del`, call SWR `mutate()`, then `notify`. Errors caught and surfaced via `notify` — no exceptions escape into render. The Connect Provider modal's open state lives in the provider itself (`isWizardOpen` at lines 732-735) so the providers-page empty-state CTA and the table's + button share one modal. Control-center re-fetches `/agent-network/{providers,policies}` directly on top of `AIProvidersProvider` — SWR de-dupes but the code path is harder to reason about.
|
||||
|
||||
## Public contracts consumed
|
||||
|
||||
- `GET/POST /api/agent-network/providers`, `PUT/DELETE /:id`
|
||||
- `GET/POST /api/agent-network/policies`, `PUT/DELETE /:id`
|
||||
- `GET/POST /api/agent-network/guardrails`, `PUT/DELETE /:id`
|
||||
- `GET/POST /api/agent-network/budget-rules`, `PUT/DELETE /:id`
|
||||
- `GET/PUT /api/agent-network/settings` (ignoreError-tolerant; 404 = not yet bootstrapped — auto-bootstrap on first provider create via `bootstrap_cluster` field — AIProvidersProvider.tsx:737-760)
|
||||
- `GET /api/agent-network/catalog/providers` (read-only declarative; backend owns vendor list, IDs, brand colors, models, extra_headers, identity_injection — useProviderCatalog.ts:6-95)
|
||||
- `GET /api/agent-network/consumption` (polled every 5s on Budget Dashboard — ConsumptionPanel.tsx:53,65-71)
|
||||
- `GET /api/events/proxy?agent_network=true&page_size=100` (shared with Proxy Events)
|
||||
- `permission?.services?.read` gates every agent-network route via RestrictedAccess.
|
||||
|
||||
`AIProviderId` is a closed union in dashboard types (data/mockData.ts:8-21) but the converter tolerates anything the backend ships — unknown ids fall through to `"custom"` (AIProvidersProvider.tsx:497-506). Catalog values are pure read-through: anything declared in `extra_headers` renders in the modal automatically, copy keyed by header name (`EXTRA_HEADER_UI` in AIProviderModal.tsx:61-89), labeled-fallback for unknown ones.
|
||||
|
||||
## Invariants
|
||||
|
||||
- Provider context wrap order on user-attribution pages: `GroupsProvider > PeersProvider > AIProvidersProvider` (observability/page.tsx:87-89). Reverse it and access-log group resolution silently drops names.
|
||||
- Every agent-network route checks `permission?.services?.read` via `RestrictedAccess` (observability/page.tsx:85, providers/page.tsx:184, policies/page.tsx:53, guardrails/page.tsx:55).
|
||||
- Modal `key={open ? 1 : 0}` pattern is used to force unmount/remount on close so internal `useState` resets between edits (AgentBudgetRuleModal.tsx:60, AgentPolicyModal.tsx:66). Removing this would leak prior-row state into a new-row session.
|
||||
- `mockData.ts` is the canonical home for ALL agent-network domain types; `MOCK_*` constants must never reach a production code path. One leak remains (below).
|
||||
|
||||
## Things to scrutinize
|
||||
|
||||
### Correctness
|
||||
|
||||
- **Tab-state URL hand-off is one-way.** observability/page.tsx:53-58 reads `?tab=` on mount (despite the file comment at line 28 saying URL hand-off is future) but `setTab` does NOT push back, so reload preserves the chosen tab only if it came in via the link. Inconsistent with control-center (page.tsx:1817-1831).
|
||||
- **Provider overlay runs only in `applySingleGroupView` / `applyPeerView`** (control-center/page.tsx:557, 1159-1166). User view does NOT show providers — if agent-network is a primary lens, that's a gap.
|
||||
- **Two useEffects race to invalidate the control-center layout.** page.tsx:1655-1657 drops `layoutInitialized` when `agentPolicies` / `agentProviders` arrive; the main effect (1786-1799) also lists them as deps. Functional but fragile — watch for flash-of-empty-graph.
|
||||
- **`updateProvider` / `updatePolicy` / `updateBudgetRule` use `??` on `enabled`** (AIProvidersProvider.tsx:784, 859, 1018). Toggle paths are safe; any caller sending `enabled: false` thinking "leave it off" gets `existing.enabled` instead. Audit modal callers.
|
||||
- **Form validation in modals is minimal.** Window-seconds picker — mockData.ts:209-215 documents "minimum 60 — one minute" but there is no matching UI guard in PolicyLimitsTab; the backend validator is the enforcement point.
|
||||
|
||||
### Security
|
||||
|
||||
- **No client-side enforcement claims** — every cap, allowlist, and toggle is display + edit; proxy is the source of truth for deny decisions (AccessLogTable.tsx:177-191 renders backend-emitted `denyReason` as-is).
|
||||
- **Prompt display is gated by what the backend stamps.** When `enable_prompt_collection` is OFF the proxy must not put prompt/completion into event metadata; the dashboard renders whatever it gets verbatim (AccessLogTable lines 532-534, AccessLogExpandedRow.tsx:42-57). No UI filter on top of backend collection switches.
|
||||
- Account Controls disables `Redact PII` when `Prompt Collection` is off (AgentAccountControlsCard.tsx:122) and clears it on off-transition (line 100), but relies on backend to enforce the same gate at write — confirm PUT handler rejects `redact_pii=true && enable_prompt_collection=false`.
|
||||
- **Bifrost identity-header overrides**: empty-string vs nil semantics documented in AIProvidersProvider.tsx:772-781 ("omitted = preserve, empty = explicit clear"). Mishandling could leak group attribution to a header the operator thought disabled. Focused read of Bifrost code path in AIProviderModal.tsx recommended.
|
||||
|
||||
### Accessibility
|
||||
|
||||
- Observability TabsList (observability/page.tsx:96-113) uses the shared Tabs component — should inherit Radix roving-tabindex. All four TabsTriggers carry only icon + text, no `aria-label`; fine because text is visible.
|
||||
- Modal focus traps are inherited from the shared Modal; agent-network modals don't override them. Quick keyboard pass recommended.
|
||||
- `EndpointBadge` Copy button (providers/page.tsx:66-76) has an `aria-label`, good.
|
||||
|
||||
### Performance
|
||||
|
||||
- `AgentConsumptionPanel` polls `/agent-network/consumption` every 5s (ConsumptionPanel.tsx:53,70). Tab switches unmount the panel, so the poll stops — verify in network panel.
|
||||
- `AgentAccessLogTable` is hard-capped at 100 rows via `page_size=100` (AIProvidersProvider.tsx:707-709). Server-side pagination is future work; high-traffic tenants miss everything past row 100 — known limitation.
|
||||
- Observability page mounts providers ONCE at page level (observability/page.tsx:87-89); tab switches keep SWR cache hot. Moving the provider mount inside `TabsContent` would re-fetch the access log on every switch.
|
||||
|
||||
### Visual consistency
|
||||
|
||||
- The observability tab style mirrors peers/page.tsx. Outer Tabs `pt-4 pb-0 mb-0`, TabsList `px-8` (observability/page.tsx:94-96) — confirm chrome height matches so the page doesn't visually jump.
|
||||
- Sidebar: `Boxes` for Providers, `AccessControlIcon` for Policies, `TelescopeIcon` for AI Observability (Navigation.tsx:113,120,133). Reusing `AccessControlIcon` makes Policies look identical to the (now hidden) Access Control item — if Access Control ever comes back, they collide.
|
||||
- `AgentNetworkIcon` is used in breadcrumbs on every agent-network page but NOT in the sidebar (per-page icons instead). Deliberate departure — record so it doesn't get reverted.
|
||||
|
||||
## Test coverage
|
||||
|
||||
- **Cypress**: One file (`cypress/e2e/test.cy.ts`) covering only the install-page copy-to-clipboard flow. NOTHING covers agent-network UI.
|
||||
- **Component / unit tests**: `src/utils/version.test.ts` is the only `.test.*` file in the repo. The agent-network modules ship without component tests.
|
||||
- Data-cy hooks exist on key controls: `save-account-controls` (AgentAccountControlsCard.tsx:71), `enable-log-collection`, `enable-prompt-collection`, `redact-pii`, plus existing `data-cy={policy.name}` / `data-cy={provider.name}` on ActiveInactiveRow. Sufficient hooks for Cypress flows; none written yet.
|
||||
- **Tooling gap (pre-existing):** `npm run lint` (`next lint`) is broken in Next 16 — the `lint` subcommand was removed from the Next CLI in 16.x, so the dashboard effectively has no working lint gate. The fix is to add either a flat-config `eslint .` script or wire ESLint via an explicit `eslint-config-next` invocation.
|
||||
|
||||
## Known limitations / explicit non-goals
|
||||
|
||||
- **`data/mockData.ts` still contains `MOCK_GROUPS`, `MOCK_PROVIDERS`, `MOCK_PEERS`.** Only `MOCK_GROUPS` is referenced from production — AgentPoliciesTable.tsx:45,76 uses it as a name-lookup fallback when a policy references a group ID the real GroupsProvider doesn't know about. `MOCK_PROVIDERS` / `MOCK_PEERS` are unreferenced; safe to delete. The file is `/* eslint-disable */` so dead-code warnings don't flag them.
|
||||
- **Tab-state URL hand-off on observability page is one-way** (read-only).
|
||||
- **Access log hard-capped at 100 rows**; no server-side pagination.
|
||||
- **No optimistic updates.** All mutations are round-trip; failures rollback via SWR revalidation.
|
||||
- **`FlowView.NETWORKS` retained but hidden** from FlowSelector (FlowSelector.tsx:9-14). Old `?tab=networks` links still route to the hidden view because `applyNetworksView` still runs.
|
||||
- **Redirects are not query-preserving** — `router.replace("/peers/devices")` (peers/page.tsx:13) strips any incoming filter params.
|
||||
- **Control-center cross-fetches** `/agent-network/{providers,policies}` directly on top of `AIProvidersProvider`. Could be collapsed.
|
||||
- **Sidebar permanently hides Access Control, Networks, Reverse Proxy, standalone Guardrails, DNS, Activity, Consumption.** Routes still resolve via URL (Navigation.tsx:165-171); intentional.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Upstream API contracts: [shared/api](10-shared-api.md)
|
||||
- Backend persistence: [management/store](20-management-store.md)
|
||||
- Backend handler wiring: [management/handlers + wiring](22-management-handlers-wiring.md)
|
||||
- End-to-end flow narrative: [../01-end-to-end-flows.md](../01-end-to-end-flows.md)
|
||||
- Top-level overview: [../00-overview.md](../00-overview.md)
|
||||
251
docs/agent-networks/modules/50-path-routed-providers.md
Normal file
251
docs/agent-networks/modules/50-path-routed-providers.md
Normal file
@@ -0,0 +1,251 @@
|
||||
# path-routed providers — Vertex AI + Bedrock
|
||||
|
||||
This guide pulls the **path-routed** provider story together in one place
|
||||
because it crosses the catalog, the synthesiser, the request parser, and the
|
||||
router. The relevant building blocks are the `llm_router` /
|
||||
`llm_request_parser` middlewares
|
||||
([31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md)), the
|
||||
per-provider parser surface ([32-proxy-llm-parsers.md](32-proxy-llm-parsers.md)),
|
||||
and the synthesiser's catalog → `ProviderRoute` mapping
|
||||
([21-management-agentnetwork.md](21-management-agentnetwork.md)).
|
||||
|
||||
Sibling modules: [31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md)
|
||||
(router + request parser) and [32-proxy-llm-parsers.md](32-proxy-llm-parsers.md)
|
||||
(Bedrock parser + pricing).
|
||||
|
||||
---
|
||||
|
||||
## What "path-routed" means
|
||||
|
||||
Most catalog providers carry the model in the request **body** (`{"model": …}`),
|
||||
so `llm_router` selects an upstream by matching the model name against each
|
||||
provider's `Models` claim. Two providers instead carry the model in the **URL
|
||||
path**, so they are routed by path before the model/vendor table is consulted:
|
||||
|
||||
| Catalog id | Style flag | Request path shape |
|
||||
|---|---|---|
|
||||
| `vertex_ai_api` | `IsVertexPathStyle` → `ProviderRoute.Vertex` | `/v1/projects/{project}/locations/{region}/publishers/{publisher}/models/{model}:{action}` |
|
||||
| `bedrock_api` | `IsBedrockPathStyle` → `ProviderRoute.Bedrock` | `/model/{modelId}/{action}` (optionally behind `/bedrock`) |
|
||||
|
||||
The catalog declares the style with
|
||||
[`catalog.IsVertexPathStyle` / `catalog.IsBedrockPathStyle`](../../../management/server/agentnetwork/catalog/catalog.go)
|
||||
and the synthesiser copies the result onto the router route as the `Vertex` /
|
||||
`Bedrock` booleans
|
||||
([synthesizer.go:450-451](../../../management/server/agentnetwork/synthesizer.go)).
|
||||
On the request leg `llm_router.Invoke` dispatches `isVertexPath` / `isBedrockPath`
|
||||
**before** the model lookup
|
||||
([llm_router/middleware.go:138-216](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
so a model the parser extracted from the path can't be claimed by a same-vendor
|
||||
*body-routed* provider (e.g. `claude-*` on `api.anthropic.com`).
|
||||
|
||||
## Google Vertex AI (`vertex_ai_api`)
|
||||
|
||||
### Catalog entry
|
||||
|
||||
`KindProvider`, parser surface left unset on the catalog entry — the request
|
||||
parser picks the parser from the URL **publisher** segment, not from
|
||||
`ParserID`. Upstream host is `<region>-aiplatform.googleapis.com`
|
||||
(`https://aiplatform.googleapis.com` for the `global` location). The catalog
|
||||
lists the Claude-on-Vertex lineup (`claude-opus-4-*`, `claude-sonnet-4-*`,
|
||||
`claude-haiku-4-5`, `claude-fable-5`) at the same per-token rates as the
|
||||
first-party Anthropic entry
|
||||
([catalog.go:333-363](../../../management/server/agentnetwork/catalog/catalog.go)).
|
||||
|
||||
### Credential — service-account OAuth (`keyfile::`)
|
||||
|
||||
Vertex does **not** accept a static API key. The operator sets the provider
|
||||
`api_key` to:
|
||||
|
||||
```
|
||||
keyfile::<base64 of the GCP service-account JSON key>
|
||||
```
|
||||
|
||||
The synthesiser recognises the `keyfile::` prefix in `providerAuthHeader`
|
||||
([synthesizer.go:897-903](../../../management/server/agentnetwork/synthesizer.go)),
|
||||
emits **no** static auth value, and carries the base64 key material on the
|
||||
route as `GCPServiceAccountKeyB64`
|
||||
([factory.go:56-61](../../../proxy/internal/middleware/builtin/llm_router/factory.go)).
|
||||
At request time the router mints a short-lived OAuth2 access token from the key
|
||||
(cloud-platform scope) and injects `Authorization: Bearer <access-token>` —
|
||||
never the key itself
|
||||
([llm_router/middleware.go:621-692](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)):
|
||||
|
||||
- One auto-refreshing `oauth2.TokenSource` is cached per key (keyed by a
|
||||
SHA-256 of the base64 material), so token minting happens once and refreshes
|
||||
amortise across requests.
|
||||
- Mint / refresh is bounded by a 10s timeout HTTP client (`gcpTokenTimeout`) so
|
||||
a slow Google token endpoint can't hang the request.
|
||||
- A malformed key or an unreachable token endpoint fails the request with
|
||||
`llm_policy.upstream_auth_failed` at HTTP **502** (an upstream problem, not a
|
||||
policy denial) — see `denyUpstreamAuth`.
|
||||
|
||||
### Metering — Anthropic-on-Vertex only
|
||||
|
||||
The request parser extracts `{publisher, model, action}` from the path
|
||||
(`parseVertexPath`, [llm_request_parser/middleware.go:237-263](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)),
|
||||
strips the `@version` suffix from the model, and maps the publisher to a parser
|
||||
surface via `vertexPublisherVendor`:
|
||||
|
||||
- `anthropic` → `llm.provider="anthropic"` → metered through the Anthropic
|
||||
parser, priced under the **`anthropic`** block in `defaults_pricing.yaml`
|
||||
(the parser emits the standard Anthropic provider label, so Vertex Claude
|
||||
reuses first-party Anthropic prices).
|
||||
- `openai` → `llm.provider="openai"` (reserved; not in the catalog lineup
|
||||
today).
|
||||
- anything else (notably `google` / Gemini) → empty vendor → **no parser**.
|
||||
|
||||
**Gemini is intentionally denied as unmeterable.** When the parser emits no
|
||||
`llm.provider` for a Vertex publisher, `llm_router` returns
|
||||
`llm_policy.unmeterable_publisher` (403) rather than forwarding the request
|
||||
uncounted — serving it would bypass token / budget metering
|
||||
([llm_router/middleware.go:144-162, 712-728](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)).
|
||||
A Gemini parser would lift this restriction; until then the `google` publisher
|
||||
is omitted from the catalog.
|
||||
|
||||
> Caveat: cross-region inference profiles in `eu` / `apac` carry a ~10% price
|
||||
> premium that the base per-token rates do **not** model — cost annotations for
|
||||
> those regions read low. Operators who need exact regional billing override
|
||||
> the affected entries in `pricing.yaml`.
|
||||
|
||||
## AWS Bedrock (`bedrock_api`)
|
||||
|
||||
### Catalog entry
|
||||
|
||||
`KindProvider`, upstream host `bedrock-runtime.<region>.amazonaws.com`. Metered
|
||||
models are the Anthropic-on-Bedrock lineup (`anthropic.claude-*`) plus Amazon
|
||||
Nova and Llama 3.3 entries
|
||||
([catalog.go:300-332](../../../management/server/agentnetwork/catalog/catalog.go)).
|
||||
Anthropic-on-Bedrock reuses the first-party Claude prices (with additive cache
|
||||
buckets); Nova / Llama report no cache, so cost is `input + output`.
|
||||
|
||||
### Credential — static bearer token
|
||||
|
||||
Bedrock uses the **AWS Bedrock API key** as a static bearer. The operator sets
|
||||
the provider `api_key` directly (no `keyfile::` prefix); the catalog template
|
||||
is `Authorization: Bearer ${API_KEY}`
|
||||
([catalog.go:306-307](../../../management/server/agentnetwork/catalog/catalog.go)).
|
||||
No token minting — the synthesiser substitutes the key into the template and
|
||||
the router injects the resulting `Authorization` header after stripping inbound
|
||||
vendor auth (including client-supplied AWS SigV4 material: `X-Amz-Date`,
|
||||
`X-Amz-Security-Token`, `X-Amz-Content-Sha256`, see `strippedAuthHeaders`).
|
||||
|
||||
### Model id form — cross-region inference profiles
|
||||
|
||||
Bedrock model ids in the request path must be the cross-region
|
||||
**inference-profile** form, e.g.
|
||||
`eu.anthropic.claude-sonnet-4-5-20250929-v1:0`. The bare
|
||||
`anthropic.claude-…` id is rejected by AWS. `normalizeBedrockModel`
|
||||
([llm_request_parser/middleware.go:398-414](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go))
|
||||
strips the region prefix (`us.` / `eu.` / `apac.` / `global.`), an optional ARN
|
||||
wrapper, and the `-YYYYMMDD-vN[:N]` version/throughput suffix so the normalised
|
||||
id (`anthropic.claude-sonnet-4-5`) matches the catalog/pricing key.
|
||||
|
||||
### Supported endpoints + actions
|
||||
|
||||
`/model/{modelId}/{action}` where action ∈ `invoke`,
|
||||
`invoke-with-response-stream`, `converse`, `converse-stream`
|
||||
([llm_request_parser/middleware.go:363-390](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)).
|
||||
`invoke` / `converse` are non-streaming; the `-stream` actions set the streaming
|
||||
flag.
|
||||
|
||||
- **InvokeModel** body uses the vendor-native shape — for Anthropic that means
|
||||
`"anthropic_version":"bedrock-2023-05-31"` and snake_case usage with additive
|
||||
cache buckets.
|
||||
- **Converse** uses the unified camelCase shape with a precomputed `totalTokens`.
|
||||
- The `BedrockParser` reads both shapes on the response leg
|
||||
([bedrock.go](../../../proxy/internal/llm/bedrock.go)); the request parser
|
||||
doesn't need to distinguish them (`ParseRequest` is a no-op — model + stream
|
||||
come from the path).
|
||||
|
||||
### Streaming — AWS binary event-stream
|
||||
|
||||
The `-stream` actions return `application/vnd.amazon.eventstream` (the AWS
|
||||
binary event-stream framing), and streaming **is metered**.
|
||||
`accumulateBedrockStream`
|
||||
([llm_response_parser/streaming_bedrock.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go))
|
||||
decodes the frames with `aws-sdk-go-v2/aws/protocol/eventstream`:
|
||||
|
||||
- InvokeModel `chunk` frames wrap a base64 `{"bytes":…}` payload carrying a
|
||||
vendor-native (Anthropic) stream event — folded through the shared Anthropic
|
||||
stream accumulator.
|
||||
- Converse `contentBlockDelta` frames carry text; the trailing `metadata` frame
|
||||
carries the final usage block.
|
||||
- A truncated stream (cut at the body-tap capture cap) decodes best-effort:
|
||||
frames up to the cut are applied and partial usage is returned.
|
||||
|
||||
### Optional `/bedrock` gateway-namespace prefix
|
||||
|
||||
Clients may place an optional `/bedrock` prefix before the native path
|
||||
(`/bedrock/model/{modelId}/{action}`) to disambiguate Bedrock from other
|
||||
providers that also use `/model/...`. Both the request parser
|
||||
(`trimBedrockNamespace`) and the router (`splitBedrockNamespace`) accept it.
|
||||
When the prefix is present, the router sets
|
||||
`RewriteUpstream.StripPathPrefix = "/bedrock"` so the **native** path
|
||||
(`/model/...`) is what reaches `bedrock-runtime.<region>.amazonaws.com`
|
||||
([llm_router/middleware.go:168-184, 320-348](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)).
|
||||
|
||||
## Model allowlist on path-routed providers
|
||||
|
||||
Because the model lives in the URL rather than the body, a path-routed provider
|
||||
credential could otherwise be used for any model the upstream supports. The
|
||||
router still enforces the route's `Models` allowlist via `matchPathRoute`
|
||||
([llm_router/middleware.go:370-416](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)):
|
||||
|
||||
1. Filter to routes of the matching style (`Vertex` / `Bedrock`).
|
||||
2. Filter to routes whose `AllowedGroupIDs` authorise the caller's groups
|
||||
(else `no_authorised_provider`).
|
||||
3. Filter to routes that **claim the requested model**. As with body-routed
|
||||
providers, an **empty `Models` list = catch-all** (serve any model);
|
||||
a non-empty list serves only the listed models (else `model_not_routable`).
|
||||
4. Multiple survivors disambiguate by longest `UpstreamPath` prefix match.
|
||||
|
||||
So an operator who lists explicit models on a Vertex/Bedrock provider gets a
|
||||
hard allowlist; an operator who leaves `Models` empty accepts every model the
|
||||
upstream serves (still subject to the unmeterable-publisher gate on Vertex).
|
||||
|
||||
Model-less OpenAI endpoints (`GET /v1/models`) are **never** routed to a
|
||||
Vertex/Bedrock provider — `matchModelless` skips path-routed routes
|
||||
([llm_router/middleware.go:427-462](../../../proxy/internal/middleware/builtin/llm_router/middleware.go))
|
||||
so a model-listing call can't be rewritten onto an upstream that would 404 it.
|
||||
|
||||
## Catalog ↔ pricing cross-check
|
||||
|
||||
Catalog prices and context windows are cross-checked against LiteLLM's
|
||||
`model_prices_and_context_window.json`. The proxy's embedded
|
||||
`defaults_pricing.yaml` covers **every metered first-party model** the catalog
|
||||
enumerates — guarded by
|
||||
`TestDefaultTable_FirstPartyModelCoverage`
|
||||
([pricing/defaults_coverage_test.go](../../../proxy/internal/llm/pricing/defaults_coverage_test.go)),
|
||||
which fails if a catalog model has no embedded price. Bedrock entries are keyed
|
||||
by the **normalised** id the request parser emits (region prefix + version
|
||||
suffix stripped). Vertex Claude carries no Bedrock-style prefix, so it prices
|
||||
straight off the `anthropic` block.
|
||||
|
||||
## Things to scrutinise
|
||||
|
||||
**Security.** The Vertex service-account key is never forwarded — only a minted
|
||||
short-lived bearer. Confirm the key material stays out of access logs (it lives
|
||||
on `ProviderRoute.GCPServiceAccountKeyB64`, not in any emitted metadata key).
|
||||
The unmeterable-publisher deny is the only thing standing between an
|
||||
operator-misconfigured Vertex provider and unmetered Gemini traffic; verify
|
||||
`vertexPublisherVendor` stays conservative (deny by default for unknown
|
||||
publishers).
|
||||
|
||||
**Correctness.** `normalizeBedrockModel` is the join between the wire id and the
|
||||
pricing key — a model that normalises to something not in `defaults_pricing.yaml`
|
||||
meters at `cost.skipped=unknown_model` rather than failing the request. The
|
||||
`/bedrock` prefix strip must run on both the parser side (so the model is
|
||||
extracted) and the router side (so the upstream path is native); a regression in
|
||||
either silently breaks the other.
|
||||
|
||||
**Metering caveats.** eu/apac cross-region Bedrock + Vertex profiles carry a
|
||||
~10% premium not modelled by base pricing — flagged in both the catalog comment
|
||||
and `defaults_pricing.yaml`. Operators needing exact regional billing override
|
||||
the relevant entries.
|
||||
|
||||
## Cross-references
|
||||
|
||||
- Router + request-parser detail: [31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md)
|
||||
- Bedrock parser + pricing + SSE / event-stream: [32-proxy-llm-parsers.md](32-proxy-llm-parsers.md)
|
||||
- Catalog → route synthesis + `keyfile::` handling: [21-management-agentnetwork.md](21-management-agentnetwork.md)
|
||||
- Overview: [../00-overview.md](../00-overview.md)
|
||||
2
go.mod
2
go.mod
@@ -35,6 +35,7 @@ require (
|
||||
github.com/DeRuina/timberjack v1.4.2
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.38.3
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1
|
||||
github.com/aws/aws-sdk-go-v2/config v1.31.6
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.18.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
|
||||
@@ -156,7 +157,6 @@ require (
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
|
||||
github.com/awnumar/memcall v0.4.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6 // indirect
|
||||
|
||||
@@ -398,7 +398,42 @@ configure_domain() {
|
||||
return 0
|
||||
}
|
||||
|
||||
apply_agent_network_preset() {
|
||||
# Agent-network turnkey install: built-in Traefik + NetBird Proxy with
|
||||
# NB_PROXY_PRIVATE=true, dashboard locked to agent-network-only mode.
|
||||
# Bypasses every reverse-proxy / proxy / CrowdSec prompt. The only
|
||||
# inputs we still need from the operator are the domain (handled by
|
||||
# configure_domain via NETBIRD_DOMAIN env var or interactive prompt)
|
||||
# and the ACME email — both honor env vars first and fall back to a
|
||||
# prompt only when unset. CrowdSec is intentionally off.
|
||||
REVERSE_PROXY_TYPE="0"
|
||||
ENABLE_PROXY="true"
|
||||
ENABLE_CROWDSEC="false"
|
||||
|
||||
if [[ -n "${NETBIRD_LETSENCRYPT_EMAIL}" ]]; then
|
||||
TRAEFIK_ACME_EMAIL="${NETBIRD_LETSENCRYPT_EMAIL}"
|
||||
else
|
||||
TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email)
|
||||
fi
|
||||
|
||||
echo "" > /dev/stderr
|
||||
echo "Agent-network preset enabled (NETBIRD_AGENT_NETWORK=true):" > /dev/stderr
|
||||
echo " - reverse proxy: built-in Traefik" > /dev/stderr
|
||||
echo " - NetBird Proxy: enabled with NB_PROXY_PRIVATE=true" > /dev/stderr
|
||||
echo " - dashboard: NETBIRD_AGENT_NETWORK_ONLY=true" > /dev/stderr
|
||||
echo " - CrowdSec: disabled" > /dev/stderr
|
||||
echo " - Let's Encrypt email: ${TRAEFIK_ACME_EMAIL}" > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
}
|
||||
|
||||
configure_reverse_proxy() {
|
||||
# Short-circuit: agent-network preset locks every reverse-proxy /
|
||||
# proxy / CrowdSec choice and bypasses the interactive prompts.
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
apply_agent_network_preset
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Prompt for reverse proxy type
|
||||
REVERSE_PROXY_TYPE=$(read_reverse_proxy_type)
|
||||
|
||||
@@ -910,6 +945,15 @@ NGINX_SSL_PORT=443
|
||||
# Letsencrypt
|
||||
LETSENCRYPT_DOMAIN=none
|
||||
EOF
|
||||
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
cat <<EOF
|
||||
# Agent-network preset: dashboard hides the standard NetBird surfaces
|
||||
# and exposes only the AI Observability + agent-network configuration
|
||||
# pages. Paired with NB_PROXY_PRIVATE=true on the proxy side.
|
||||
NETBIRD_AGENT_NETWORK_ONLY=true
|
||||
EOF
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -946,6 +990,17 @@ NB_PROXY_PROXY_PROTOCOL=true
|
||||
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
|
||||
EOF
|
||||
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
cat <<EOF
|
||||
# Agent-network preset: turn the proxy into the private reverse-proxy
|
||||
# ingress for agent-network synth services. Disables the public-facing
|
||||
# surface so the proxy serves only synth-generated routes (the
|
||||
# llm_router-driven LLM endpoints) and the per-account inbound
|
||||
# listeners on the embedded netstack.
|
||||
NB_PROXY_PRIVATE=true
|
||||
EOF
|
||||
fi
|
||||
|
||||
if [[ "$ENABLE_CROWDSEC" == "true" && -n "$CROWDSEC_BOUNCER_KEY" ]]; then
|
||||
cat <<EOF
|
||||
NB_PROXY_CROWDSEC_API_URL=http://crowdsec:8080
|
||||
|
||||
@@ -116,6 +116,24 @@ func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, p
|
||||
c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
// injectAllProxyPolicies prepares an account for the per-peer network-map
|
||||
// computation. It prepends the in-memory agent-network services synthesised
|
||||
// from the account's current provider/policy state to account.Services so
|
||||
// the existing InjectProxyPolicies + injectPrivateServicePolicies walks pick
|
||||
// them up alongside persisted reverse-proxy services. Synthesised services
|
||||
// are never persisted; the account is loaded fresh per cycle so re-prepending
|
||||
// is safe and idempotent. Accounts without agent-network providers get an
|
||||
// empty synth slice — no behaviour change.
|
||||
func (c *Controller) injectAllProxyPolicies(ctx context.Context, account *types.Account) {
|
||||
synth, err := c.repo.SynthesizeAgentNetworkServices(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("synthesise agent-network services for account %s: %v", account.Id, err)
|
||||
} else if len(synth) > 0 {
|
||||
account.Services = append(synth, account.Services...)
|
||||
}
|
||||
account.InjectProxyPolicies(ctx)
|
||||
}
|
||||
|
||||
func (c *Controller) CountStreams() int {
|
||||
return c.peersUpdateManager.CountStreams()
|
||||
}
|
||||
@@ -150,7 +168,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
@@ -281,7 +299,15 @@ func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID s
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
// The affected-peer path MUST mirror sendUpdateAccountPeers (line 171)
|
||||
// here: injectAllProxyPolicies prepends the synthesised agent-network
|
||||
// services BEFORE InjectProxyPolicies + private-service policies run.
|
||||
// Previously this path called only account.InjectProxyPolicies, which
|
||||
// skipped the synth-services prepend — so peer-level changes
|
||||
// (proxy restart, embedded peer connect/disconnect) propagated a
|
||||
// network map that omitted the synth DNS zone, and the agent kept
|
||||
// resolving against the stale or absent record.
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
@@ -399,7 +425,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return fmt.Errorf("failed to get validated peers: %v", err)
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
@@ -603,7 +629,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
@@ -874,7 +900,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
c.injectAllProxyPolicies(ctx, account)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
@@ -3,7 +3,9 @@ package controller
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -16,6 +18,10 @@ type Repository interface {
|
||||
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
|
||||
GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error)
|
||||
// SynthesizeAgentNetworkServices returns the in-memory reverse-proxy
|
||||
// services synthesised from the account's agent-network provider/policy
|
||||
// state. Empty for accounts without agent-network providers.
|
||||
SynthesizeAgentNetworkServices(ctx context.Context, accountID string) ([]*service.Service, error)
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
@@ -50,6 +56,10 @@ func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID s
|
||||
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
func (r *repository) SynthesizeAgentNetworkServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
return agentnetwork.SynthesizeServices(ctx, r.store, accountID)
|
||||
}
|
||||
|
||||
func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) {
|
||||
return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package agentnetwork
|
||||
|
||||
import "github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
|
||||
// init registers the agent-network service synthesiser with the affectedpeers
|
||||
// resolver. Agent-network reverse-proxy services are synthesised on demand and
|
||||
// never persisted, so the resolver can't load them from the store; without them
|
||||
// it can't fold the embedded proxy peer into the affected set on a client
|
||||
// group/peer change, and the proxy never learns a newly authorised client until
|
||||
// it reconnects. Registered here (rather than via a direct
|
||||
// affectedpeers→agentnetwork import) to avoid an import cycle
|
||||
// (agentnetwork → account → affectedpeers).
|
||||
func init() {
|
||||
affectedpeers.SetAgentNetworkSynthesizer(SynthesizeServices)
|
||||
}
|
||||
749
management/internals/modules/agentnetwork/catalog/catalog.go
Normal file
749
management/internals/modules/agentnetwork/catalog/catalog.go
Normal file
@@ -0,0 +1,749 @@
|
||||
// Package catalog defines the static set of Agent Network providers
|
||||
// recognized by the management server. The catalog is consulted both to
|
||||
// validate provider_id on create/update and to surface the available
|
||||
// providers (and their models) to the dashboard.
|
||||
package catalog
|
||||
|
||||
import "github.com/netbirdio/netbird/shared/management/http/api"
|
||||
|
||||
// Model is the in-memory representation of a catalog model.
|
||||
type Model struct {
|
||||
ID string
|
||||
Label string
|
||||
InputPer1k float64
|
||||
OutputPer1k float64
|
||||
ContextWindow int
|
||||
}
|
||||
|
||||
// ProviderKind groups catalog entries for UI presentation. The split
|
||||
// is semantic, not technical:
|
||||
// - KindProvider: the upstream is a vendor's first-party API (OpenAI,
|
||||
// Anthropic, Mistral, Bedrock, etc.) — NetBird talks straight to
|
||||
// the model provider.
|
||||
// - KindGateway: the upstream is itself a routing / aggregation layer
|
||||
// in front of multiple providers (LiteLLM, Portkey, Helicone, …).
|
||||
// These typically need NetBird identity stamped onto upstream
|
||||
// requests so the gateway's analytics and budgets attribute to the
|
||||
// real caller; that's what IdentityInjection is for.
|
||||
// - KindCustom: the catch-all "OpenAI-compatible self-hosted endpoint"
|
||||
// entry (vLLM, Ollama, custom inference servers).
|
||||
//
|
||||
// Frontend uses Kind to group the provider Select in the modal so an
|
||||
// operator can spot at a glance which catalog entries proxy other
|
||||
// providers vs. talk straight to one. Backend doesn't dispatch on Kind
|
||||
// today; it's purely a presentation hint.
|
||||
type ProviderKind string
|
||||
|
||||
const (
|
||||
KindProvider ProviderKind = "provider"
|
||||
KindGateway ProviderKind = "gateway"
|
||||
KindCustom ProviderKind = "custom"
|
||||
)
|
||||
|
||||
// Provider is the in-memory representation of a catalog provider.
|
||||
type Provider struct {
|
||||
ID string
|
||||
Name string
|
||||
Description string
|
||||
DefaultHost string
|
||||
// Kind groups this entry for UI presentation; see ProviderKind.
|
||||
Kind ProviderKind
|
||||
// AuthHeaderName is the HTTP header the provider's API expects
|
||||
// the credential under (e.g. "Authorization" for OpenAI,
|
||||
// "x-api-key" for Anthropic). Combined with AuthHeaderTemplate
|
||||
// at synthesis time to inject the auth header on every upstream
|
||||
// request.
|
||||
AuthHeaderName string
|
||||
AuthHeaderTemplate string
|
||||
DefaultContentType string
|
||||
BrandColor string
|
||||
// ParserID names the proxy LLM parser surface this provider
|
||||
// speaks (matches llm.Parser.ProviderName: "openai",
|
||||
// "anthropic"). Multiple catalog ids may share a parser surface
|
||||
// (e.g. azure_openai_api and mistral_api both speak the OpenAI
|
||||
// shape). Empty when no parser is yet implemented for the
|
||||
// surface — the proxy middleware then falls back to URL sniffing
|
||||
// or skips request-side enrichment.
|
||||
ParserID string
|
||||
// IdentityInjection, when non-nil, instructs the proxy to stamp
|
||||
// the caller's NetBird identity onto upstream requests under the
|
||||
// configured header names. Used for gateways like LiteLLM that
|
||||
// key budgets and attribution off request headers (the gateway
|
||||
// otherwise has no way to learn which user / group made the call).
|
||||
// The proxy strips the same header names from the inbound request
|
||||
// before stamping ours, so an app can't spoof identity by setting
|
||||
// these headers itself.
|
||||
IdentityInjection *IdentityInjection
|
||||
// ExtraHeaders is a catalog-declared list of additional per-
|
||||
// provider routing/config headers the proxy stamps on every
|
||||
// upstream request. Distinct from AuthHeaderName/Template (which
|
||||
// always carries the API_KEY) and from IdentityInjection (caller
|
||||
// identity). Each entry surfaces an optional input on the
|
||||
// dashboard's provider modal whose value lives on the provider
|
||||
// record's ExtraValues map (keyed by ExtraHeader.Name). Empty
|
||||
// list = no extra inputs rendered. Used today by Portkey for
|
||||
// "x-portkey-config: pc-..." (a saved-config id that resolves
|
||||
// upstream provider + credentials on Portkey's hosted side).
|
||||
ExtraHeaders []ExtraHeader
|
||||
Models []Model
|
||||
}
|
||||
|
||||
// ExtraHeader names a single optional per-provider routing/config
|
||||
// header. Catalog declares N of these per provider type; the operator
|
||||
// fills any subset on the provider record (see Provider.ExtraValues).
|
||||
// At synth time, only entries with a non-empty operator value are
|
||||
// stamped; the proxy's identity-inject middleware applies anti-spoof
|
||||
// (Remove + Add) so a client can't supply these headers themselves.
|
||||
//
|
||||
// UI copy (label / help text / tooltip) for each known Name lives on
|
||||
// the dashboard, not here — the backend's job is just to declare
|
||||
// which wire headers are accepted. New provider needs an extra
|
||||
// header? Add the Name here AND the matching UI copy on the dashboard.
|
||||
type ExtraHeader struct {
|
||||
// Name is the wire header name, e.g. "x-portkey-config".
|
||||
Name string
|
||||
}
|
||||
|
||||
// IdentityInjection describes how the proxy stamps NetBird identity onto
|
||||
// upstream gateway requests. Exactly one shape must be set — they're
|
||||
// mutually exclusive and dispatched by the inject middleware.
|
||||
//
|
||||
// Shape choice tracks the wire convention the upstream gateway uses,
|
||||
// not the vendor name. New gateways with a known shape become a catalog
|
||||
// entry, not a new code path.
|
||||
type IdentityInjection struct {
|
||||
// HeaderPair emits separate headers per identity dimension
|
||||
// (end-user id, tags as CSV). LiteLLM and OpenAI-compatible
|
||||
// self-hosted gateways that read identity from dedicated headers.
|
||||
HeaderPair *HeaderPairInjection
|
||||
// JSONMetadata emits a single header carrying a JSON object with
|
||||
// reserved keys for user / groups / etc. Portkey, Helicone-style
|
||||
// metadata headers, anything that wants a structured envelope.
|
||||
JSONMetadata *JSONMetadataInjection
|
||||
}
|
||||
|
||||
// HeaderPairInjection is the LiteLLM-style wire convention.
|
||||
type HeaderPairInjection struct {
|
||||
// Customizable, when true, marks the wire header names as
|
||||
// operator-overridable: the dashboard surfaces EndUserIDHeader
|
||||
// and TagsHeader as editable inputs (defaults shown as
|
||||
// placeholders) and the synthesizer pulls the actual values from
|
||||
// the provider record's IdentityHeader* fields rather than from
|
||||
// these defaults. An empty operator value disables stamping for
|
||||
// that dimension. Used today for Bifrost, whose log-metadata /
|
||||
// telemetry header prefix (x-bf-lh-* vs x-bf-dim-*) is a
|
||||
// per-operator choice; LiteLLM and similar gateways with a fixed
|
||||
// wire protocol leave this false so the catalog defaults are
|
||||
// authoritative.
|
||||
Customizable bool
|
||||
// EndUserIDHeader receives the caller's display identity (user
|
||||
// email when the peer is attached to a user, else peer.Name),
|
||||
// e.g. "x-litellm-end-user-id".
|
||||
EndUserIDHeader string
|
||||
// TagsHeader receives the caller's NetBird group display names
|
||||
// as a CSV, e.g. "x-litellm-tags".
|
||||
TagsHeader string
|
||||
// TagsInBody, when true, additionally writes the tag list into
|
||||
// the request body's metadata.tags array (a JSON path the
|
||||
// gateway parses for budget enforcement). LiteLLM only honours
|
||||
// metadata.tags for tag-budget gating — its x-litellm-tags
|
||||
// header path feeds spend tracking but bypasses
|
||||
// _tag_max_budget_check entirely. Body inject is skipped when
|
||||
// the request body is empty, truncated, non-JSON, or when an
|
||||
// existing metadata field is a non-object value (defensive: we
|
||||
// never clobber a client-supplied non-object). The header path
|
||||
// remains a robust fallback for spend tracking in those cases.
|
||||
TagsInBody bool
|
||||
// EndUserIDInBody, when true, additionally writes the display
|
||||
// identity into the request body's top-level "user" field (the
|
||||
// OpenAI-standard end-user identifier). LiteLLM resolves the end
|
||||
// user id from headers first then body, so for LiteLLM this is
|
||||
// belt-and-suspenders. It matters when an OpenAI-compatible
|
||||
// gateway downstream of LiteLLM (or OpenAI direct, bypassing
|
||||
// LiteLLM) only reads the body, and as anti-spoof: client-
|
||||
// supplied "user" values are overwritten with our trusted
|
||||
// identity. Same skip rules as TagsInBody.
|
||||
EndUserIDInBody bool
|
||||
}
|
||||
|
||||
// JSONMetadataInjection is the Portkey-style wire convention: a single
|
||||
// header carrying a JSON object. NetBird identity fields land under the
|
||||
// configured reserved keys; missing keys (empty string) are skipped at
|
||||
// emit time.
|
||||
type JSONMetadataInjection struct {
|
||||
// Customizable, when true, marks the JSON keys as operator-
|
||||
// overridable. The dashboard surfaces UserKey and GroupsKey as
|
||||
// editable inputs (the catalog values shown as placeholders) and
|
||||
// the synthesizer pulls the actual JSON-key names from the
|
||||
// provider record's IdentityHeader* fields. Same field reuse as
|
||||
// HeaderPair's customizable path — the dimensions (user identity,
|
||||
// groups) are the same, only the wire encoding differs (JSON key
|
||||
// vs HTTP header name). An empty operator value disables emission
|
||||
// for that dimension. Used today for Cloudflare AI Gateway, whose
|
||||
// cf-aig-metadata header accepts arbitrary JSON keys; Portkey
|
||||
// leaves this false because its keys are reserved by the Portkey
|
||||
// schema.
|
||||
Customizable bool
|
||||
// Header is the wire header name carrying the JSON payload, e.g.
|
||||
// "x-portkey-metadata".
|
||||
Header string
|
||||
// UserKey is the JSON key for the caller's display identity.
|
||||
// Portkey reserves "_user" for this dimension.
|
||||
UserKey string
|
||||
// GroupsKey is the JSON key for the caller's NetBird groups,
|
||||
// emitted as a CSV string value (Portkey requires string values).
|
||||
GroupsKey string
|
||||
// MaxValueLength caps each emitted JSON value, in bytes. Portkey
|
||||
// enforces a 128-char limit per value; oversized values are
|
||||
// truncated rather than failing the request. 0 disables the cap.
|
||||
MaxValueLength int
|
||||
}
|
||||
|
||||
// providers is the canonical list of supported Agent Network providers.
|
||||
// Update this list together with the dashboard's PROVIDER_CATALOG.
|
||||
var providers = []Provider{
|
||||
{
|
||||
ID: "openai_api",
|
||||
Kind: KindProvider,
|
||||
Name: "OpenAI API",
|
||||
Description: "GPT, Responses API, and Embeddings",
|
||||
DefaultHost: "api.openai.com",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#10A37F",
|
||||
ParserID: "openai",
|
||||
// Pricing + context windows cross-checked against LiteLLM's
|
||||
// model_prices_and_context_window.json. Notable corrections from
|
||||
// earlier values: o4-mini repriced from $4/$16 to $1.10/$4.40
|
||||
// per MTok, gpt-4o from $5/$15 to $2.50/$10, and the GPT-5
|
||||
// family context windows split between 1.05M for full-size
|
||||
// models and 272K for mini/nano/codex variants.
|
||||
Models: []Model{
|
||||
{ID: "gpt-5.5", Label: "GPT-5.5", InputPer1k: 0.005, OutputPer1k: 0.030, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.5-pro", Label: "GPT-5.5 Pro", InputPer1k: 0.030, OutputPer1k: 0.180, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4", Label: "GPT-5.4", InputPer1k: 0.0025, OutputPer1k: 0.015, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4-pro", Label: "GPT-5.4 Pro", InputPer1k: 0.030, OutputPer1k: 0.180, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4-mini", Label: "GPT-5.4 Mini", InputPer1k: 0.00075, OutputPer1k: 0.0045, ContextWindow: 272000},
|
||||
{ID: "gpt-5.4-nano", Label: "GPT-5.4 Nano", InputPer1k: 0.0002, OutputPer1k: 0.00125, ContextWindow: 272000},
|
||||
{ID: "gpt-5.3-codex", Label: "GPT-5.3 Codex", InputPer1k: 0.00175, OutputPer1k: 0.014, ContextWindow: 272000},
|
||||
{ID: "gpt-5.3-chat-latest", Label: "GPT-5.3 Chat", InputPer1k: 0.00175, OutputPer1k: 0.014, ContextWindow: 128000},
|
||||
{ID: "o4-mini", Label: "o4-mini", InputPer1k: 0.0011, OutputPer1k: 0.0044, ContextWindow: 200000},
|
||||
{ID: "gpt-4.1", Label: "GPT-4.1", InputPer1k: 0.002, OutputPer1k: 0.008, ContextWindow: 1047576},
|
||||
{ID: "gpt-4.1-mini", Label: "GPT-4.1 mini", InputPer1k: 0.0004, OutputPer1k: 0.0016, ContextWindow: 1047576},
|
||||
{ID: "gpt-4.1-nano", Label: "GPT-4.1 nano", InputPer1k: 0.0001, OutputPer1k: 0.0004, ContextWindow: 1047576},
|
||||
{ID: "gpt-4o", Label: "GPT-4o", InputPer1k: 0.0025, OutputPer1k: 0.010, ContextWindow: 128000},
|
||||
{ID: "gpt-4o-mini", Label: "GPT-4o mini", InputPer1k: 0.00015, OutputPer1k: 0.0006, ContextWindow: 128000},
|
||||
{ID: "gpt-4-turbo", Label: "GPT-4 Turbo", InputPer1k: 0.01, OutputPer1k: 0.03, ContextWindow: 128000},
|
||||
{ID: "gpt-3.5-turbo", Label: "GPT-3.5 Turbo", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 16385},
|
||||
{ID: "text-embedding-3-large", Label: "text-embedding-3-large", InputPer1k: 0.00013, OutputPer1k: 0, ContextWindow: 8191},
|
||||
{ID: "text-embedding-3-small", Label: "text-embedding-3-small", InputPer1k: 0.00002, OutputPer1k: 0, ContextWindow: 8191},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "anthropic_api",
|
||||
Kind: KindProvider,
|
||||
Name: "Anthropic API",
|
||||
Description: "Claude Messages API",
|
||||
DefaultHost: "api.anthropic.com",
|
||||
AuthHeaderName: "x-api-key",
|
||||
AuthHeaderTemplate: "${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#D97757",
|
||||
ParserID: "anthropic",
|
||||
// Per Anthropic's current model lineup. Pricing in USD per 1k
|
||||
// tokens. Context windows: 4.6+ family is 1M; Haiku 4.5 stays at
|
||||
// 200K. claude-3-7-sonnet and claude-3-5-haiku retired
|
||||
// 2026-02-19 — dropped from the catalog. claude-opus-4-1
|
||||
// deprecated, retires 2026-08-05 — kept until the cutover.
|
||||
// claude-mythos-5 omitted: Project Glasswing access only, not a
|
||||
// general-availability target. claude-fable-5 requires the
|
||||
// account to be on >= 30-day data retention or all requests
|
||||
// 400.
|
||||
Models: []Model{
|
||||
{ID: "claude-fable-5", Label: "Claude Fable 5", InputPer1k: 0.010, OutputPer1k: 0.050, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-8", Label: "Claude Opus 4.8", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-7", Label: "Claude Opus 4.7", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-6", Label: "Claude Opus 4.6", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-1", Label: "Claude Opus 4.1 (deprecated, retires 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000},
|
||||
{ID: "claude-sonnet-4-6", Label: "Claude Sonnet 4.6", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000},
|
||||
{ID: "claude-sonnet-4-5", Label: "Claude Sonnet 4.5", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000},
|
||||
{ID: "claude-haiku-4-5", Label: "Claude Haiku 4.5", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "azure_openai_api",
|
||||
Kind: KindProvider,
|
||||
Name: "Azure OpenAI API",
|
||||
Description: "Azure-hosted OpenAI deployments",
|
||||
DefaultHost: "<resource>.openai.azure.com",
|
||||
AuthHeaderName: "api-key",
|
||||
AuthHeaderTemplate: "${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#0078D4",
|
||||
ParserID: "openai",
|
||||
// Mirrors openai_api pricing — Azure resells OpenAI models at the
|
||||
// same per-token rates, just under different deployment names.
|
||||
Models: []Model{
|
||||
{ID: "gpt-5.5", Label: "GPT-5.5 (Azure)", InputPer1k: 0.005, OutputPer1k: 0.030, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4", Label: "GPT-5.4 (Azure)", InputPer1k: 0.0025, OutputPer1k: 0.015, ContextWindow: 1050000},
|
||||
{ID: "gpt-5.4-mini", Label: "GPT-5.4 Mini (Azure)", InputPer1k: 0.00075, OutputPer1k: 0.0045, ContextWindow: 272000},
|
||||
{ID: "gpt-5.4-nano", Label: "GPT-5.4 Nano (Azure)", InputPer1k: 0.0002, OutputPer1k: 0.00125, ContextWindow: 272000},
|
||||
{ID: "o4-mini", Label: "o4-mini (Azure)", InputPer1k: 0.0011, OutputPer1k: 0.0044, ContextWindow: 200000},
|
||||
{ID: "gpt-4.1", Label: "GPT-4.1 (Azure)", InputPer1k: 0.002, OutputPer1k: 0.008, ContextWindow: 1047576},
|
||||
{ID: "gpt-4.1-mini", Label: "GPT-4.1 mini (Azure)", InputPer1k: 0.0004, OutputPer1k: 0.0016, ContextWindow: 1047576},
|
||||
{ID: "gpt-4o", Label: "GPT-4o (Azure)", InputPer1k: 0.0025, OutputPer1k: 0.010, ContextWindow: 128000},
|
||||
{ID: "gpt-4o-mini", Label: "GPT-4o mini (Azure)", InputPer1k: 0.00015, OutputPer1k: 0.0006, ContextWindow: 128000},
|
||||
{ID: "gpt-35-turbo", Label: "GPT-3.5 Turbo (Azure)", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 16385},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "bedrock_api",
|
||||
Kind: KindProvider,
|
||||
Name: "AWS Bedrock API",
|
||||
Description: "Anthropic, Meta, Cohere via Bedrock",
|
||||
DefaultHost: "bedrock-runtime.<region>.amazonaws.com",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#FF9900",
|
||||
// Anthropic models on Bedrock take the anthropic.* prefix and
|
||||
// follow the same lineup / pricing as the first-party Anthropic
|
||||
// catalog entry above. claude-3-7-sonnet and claude-3-5-haiku
|
||||
// were retired upstream on 2026-02-19 — dropped from the
|
||||
// Bedrock list too. Amazon Nova entries cross-checked against
|
||||
// LiteLLM (added Nova Micro + the new Nova 2 Lite preview).
|
||||
// Llama 3.3 70B entry kept unchanged — LiteLLM tracks only
|
||||
// per-region Llama 3 entries; standalone 3.3 not yet listed.
|
||||
Models: []Model{
|
||||
{ID: "anthropic.claude-opus-4-8", Label: "Claude Opus 4.8 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "anthropic.claude-opus-4-7", Label: "Claude Opus 4.7 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "anthropic.claude-opus-4-6", Label: "Claude Opus 4.6 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "anthropic.claude-opus-4-1", Label: "Claude Opus 4.1 (Bedrock, deprecated 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000},
|
||||
{ID: "anthropic.claude-sonnet-4-6", Label: "Claude Sonnet 4.6 (Bedrock)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000},
|
||||
{ID: "anthropic.claude-sonnet-4-5", Label: "Claude Sonnet 4.5 (Bedrock)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000},
|
||||
{ID: "anthropic.claude-haiku-4-5", Label: "Claude Haiku 4.5 (Bedrock)", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000},
|
||||
{ID: "meta.llama3-3-70b-instruct", Label: "Llama 3.3 70B (Bedrock)", InputPer1k: 0.00072, OutputPer1k: 0.00072, ContextWindow: 128000},
|
||||
{ID: "amazon.nova-2-lite", Label: "Amazon Nova 2 Lite (Bedrock, preview)", InputPer1k: 0.0003, OutputPer1k: 0.0025, ContextWindow: 1000000},
|
||||
{ID: "amazon.nova-pro", Label: "Amazon Nova Pro (Bedrock)", InputPer1k: 0.0008, OutputPer1k: 0.0032, ContextWindow: 300000},
|
||||
{ID: "amazon.nova-lite", Label: "Amazon Nova Lite (Bedrock)", InputPer1k: 0.00006, OutputPer1k: 0.00024, ContextWindow: 300000},
|
||||
{ID: "amazon.nova-micro", Label: "Amazon Nova Micro (Bedrock)", InputPer1k: 0.000035, OutputPer1k: 0.00014, ContextWindow: 128000},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "vertex_ai_api",
|
||||
Kind: KindProvider,
|
||||
Name: "Google Vertex AI API",
|
||||
Description: "Anthropic Claude models hosted on Vertex AI",
|
||||
DefaultHost: "<region>-aiplatform.googleapis.com",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#4285F4",
|
||||
// Vertex carries the model in the URL path and authenticates with a
|
||||
// service-account-minted OAuth token (api_key = "keyfile::<base64 SA>").
|
||||
// Only Anthropic-on-Vertex is metered today: the request parser maps the
|
||||
// anthropic publisher to the Anthropic parser, so the lineup + prices
|
||||
// mirror the first-party Anthropic catalog (LiteLLM vertex_ai/claude-*
|
||||
// confirms the same per-token rates; cross-region profiles in eu/apac
|
||||
// carry a ~10% premium that base pricing does not model). Gemini (the
|
||||
// google publisher) is intentionally omitted until a Gemini parser
|
||||
// exists — the router denies unmeterable publishers rather than forward
|
||||
// them uncounted.
|
||||
Models: []Model{
|
||||
{ID: "claude-fable-5", Label: "Claude Fable 5 (Vertex)", InputPer1k: 0.010, OutputPer1k: 0.050, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-8", Label: "Claude Opus 4.8 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-7", Label: "Claude Opus 4.7 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-6", Label: "Claude Opus 4.6 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000},
|
||||
{ID: "claude-opus-4-1", Label: "Claude Opus 4.1 (Vertex, deprecated 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000},
|
||||
{ID: "claude-sonnet-4-6", Label: "Claude Sonnet 4.6 (Vertex)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000},
|
||||
{ID: "claude-sonnet-4-5", Label: "Claude Sonnet 4.5 (Vertex)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000},
|
||||
{ID: "claude-haiku-4-5", Label: "Claude Haiku 4.5 (Vertex)", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "mistral_api",
|
||||
Kind: KindProvider,
|
||||
Name: "Mistral API",
|
||||
Description: "Mistral cloud API",
|
||||
DefaultHost: "api.mistral.ai",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#FF7000",
|
||||
ParserID: "openai",
|
||||
// Pricing + context windows cross-checked against LiteLLM. Key
|
||||
// gotchas the marketing page hides:
|
||||
// - `mistral-medium-latest` aliases to Medium 3.1 ($0.40/$2),
|
||||
// NOT Medium 3.5 ($1.50/$7.50). Catalog exposes both.
|
||||
// - `mistral-large-latest` aliases to Large 3 — 262K context,
|
||||
// cheaper than Medium 3.5.
|
||||
// - Magistral models are tuned for reasoning but cap context
|
||||
// at only 40K (vs 128K-262K elsewhere).
|
||||
// - `codestral-latest` still routes to the old 2405 build
|
||||
// ($1/$3) per LiteLLM; the newer codestral-2508 is both
|
||||
// cheaper and longer-context. Both exposed.
|
||||
// - Pixtral was folded into the main Large/Medium series; no
|
||||
// standalone vision entry.
|
||||
Models: []Model{
|
||||
{ID: "mistral-large-latest", Label: "Mistral Large 3", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 262144},
|
||||
{ID: "mistral-medium-latest", Label: "Mistral Medium 3.1", InputPer1k: 0.0004, OutputPer1k: 0.002, ContextWindow: 131072},
|
||||
{ID: "mistral-medium-3-5", Label: "Mistral Medium 3.5", InputPer1k: 0.0015, OutputPer1k: 0.0075, ContextWindow: 262144},
|
||||
{ID: "mistral-small-latest", Label: "Mistral Small 3.2", InputPer1k: 0.00006, OutputPer1k: 0.00018, ContextWindow: 131072},
|
||||
{ID: "magistral-medium-latest", Label: "Magistral Medium (reasoning)", InputPer1k: 0.002, OutputPer1k: 0.005, ContextWindow: 40000},
|
||||
{ID: "magistral-small-latest", Label: "Magistral Small (reasoning)", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 40000},
|
||||
{ID: "devstral-medium-latest", Label: "Devstral Medium 2 (coding)", InputPer1k: 0.0004, OutputPer1k: 0.002, ContextWindow: 256000},
|
||||
{ID: "devstral-small-latest", Label: "Devstral Small 2 (coding)", InputPer1k: 0.0001, OutputPer1k: 0.0003, ContextWindow: 256000},
|
||||
{ID: "codestral-2508", Label: "Codestral 2508", InputPer1k: 0.0003, OutputPer1k: 0.0009, ContextWindow: 256000},
|
||||
{ID: "codestral-latest", Label: "Codestral (legacy 2405)", InputPer1k: 0.001, OutputPer1k: 0.003, ContextWindow: 32000},
|
||||
{ID: "ministral-3-14b-2512", Label: "Ministral 3 14B", InputPer1k: 0.0002, OutputPer1k: 0.0002, ContextWindow: 262144},
|
||||
{ID: "ministral-8b-latest", Label: "Ministral 8B", InputPer1k: 0.00015, OutputPer1k: 0.00015, ContextWindow: 262144},
|
||||
{ID: "ministral-3-3b-2512", Label: "Ministral 3 3B", InputPer1k: 0.0001, OutputPer1k: 0.0001, ContextWindow: 131072},
|
||||
{ID: "mistral-embed", Label: "Mistral Embed", InputPer1k: 0.0001, OutputPer1k: 0, ContextWindow: 8192},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "litellm_proxy",
|
||||
Kind: KindGateway,
|
||||
Name: "LiteLLM Proxy",
|
||||
Description: "Bring your own LiteLLM proxy with NetBird identity stamped on every request",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#0EA5E9",
|
||||
ParserID: "openai",
|
||||
// IdentityInjection requires a LiteLLM virtual key minted with
|
||||
// metadata.allow_client_tags=true; the master key silently drops
|
||||
// caller tags. Tags go out via both the x-litellm-tags header and
|
||||
// body metadata.tags: LiteLLM enforces budgets from the body only,
|
||||
// so the header is the spend-tracking fallback when body injection
|
||||
// can't run. See the Agent Network provider docs for key setup.
|
||||
IdentityInjection: &IdentityInjection{
|
||||
HeaderPair: &HeaderPairInjection{
|
||||
EndUserIDHeader: "x-litellm-end-user-id",
|
||||
TagsHeader: "x-litellm-tags",
|
||||
TagsInBody: true,
|
||||
EndUserIDInBody: true,
|
||||
},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "portkey",
|
||||
Kind: KindGateway,
|
||||
Name: "Portkey AI Gateway",
|
||||
Description: "Portkey AI Gateway with NetBird identity stamped via x-portkey-metadata",
|
||||
DefaultHost: "api.portkey.ai",
|
||||
// Portkey hosted requires x-portkey-api-key (account key)
|
||||
// plus a routing decision per request. The simplest routing
|
||||
// path is a saved Portkey config id stamped via
|
||||
// x-portkey-config — operators paste the pc-... id once and
|
||||
// Portkey resolves the upstream provider + virtual key from
|
||||
// it. ExtraHeaders below surfaces the input. Alternative:
|
||||
// callers author "@org/model" in the body; both flows
|
||||
// coexist (per-request authoring still works without a
|
||||
// configured value).
|
||||
AuthHeaderName: "x-portkey-api-key",
|
||||
AuthHeaderTemplate: "${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#FF5C00",
|
||||
ParserID: "openai",
|
||||
IdentityInjection: &IdentityInjection{
|
||||
JSONMetadata: &JSONMetadataInjection{
|
||||
Header: "x-portkey-metadata",
|
||||
UserKey: "_user",
|
||||
GroupsKey: "groups",
|
||||
MaxValueLength: 128,
|
||||
},
|
||||
},
|
||||
ExtraHeaders: []ExtraHeader{
|
||||
{Name: "x-portkey-config"},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "bifrost",
|
||||
Kind: KindGateway,
|
||||
Name: "Bifrost",
|
||||
Description: "Maxim AI's Bifrost gateway. Point upstream URL at /openai/v1 or /anthropic/v1 on your Bifrost host depending on which body shape your apps use.",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#7C3AED",
|
||||
// ParserID empty: the proxy's request parser sniffs the URL
|
||||
// path. Bifrost's /openai/v1/... contains "/v1/chat/completions"
|
||||
// (matches OpenAIParser.DetectFromURL); /anthropic/v1/messages
|
||||
// contains "/v1/messages" (matches AnthropicParser). Operators
|
||||
// who paste a different prefix get no usage parsing and the
|
||||
// cost meter skips with skipMissingProvider — degraded but
|
||||
// non-fatal.
|
||||
ParserID: "",
|
||||
// Identity-injection headers are operator-customisable. The
|
||||
// HeaderPair values below are PLACEHOLDERS surfaced by the
|
||||
// dashboard; the actual values stamped on the wire come from
|
||||
// the provider record's IdentityHeaderUserID /
|
||||
// IdentityHeaderGroups fields. An empty operator value
|
||||
// disables stamping for that dimension (the inject middleware
|
||||
// already no-ops on empty header names). Defaulting to the
|
||||
// x-bf-dim- family so the values land in Bifrost's
|
||||
// Prometheus/OTEL pipelines when the operator declares the
|
||||
// label names in their client.prometheus_labels config — see
|
||||
// docs.getbifrost.ai/features/telemetry. Operators who use
|
||||
// the always-on x-bf-lh- log-metadata family (no Bifrost-side
|
||||
// declaration required) just edit the inputs.
|
||||
//
|
||||
// Bifrost virtual keys (sk-bf-*) ride Authorization: Bearer.
|
||||
// Operators provision the VK on their Bifrost (UI /
|
||||
// config.json / POST /api/governance/virtual-keys) and paste
|
||||
// the returned sk-bf-... as ${API_KEY}. Pin v1.4+ to avoid
|
||||
// the v1.3.0 x-bf-vk regression (maximhq/bifrost#632).
|
||||
IdentityInjection: &IdentityInjection{
|
||||
HeaderPair: &HeaderPairInjection{
|
||||
EndUserIDHeader: "x-bf-dim-netbird_user_id",
|
||||
TagsHeader: "x-bf-dim-netbird_groups",
|
||||
Customizable: true,
|
||||
},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "cloudflare_ai_gateway",
|
||||
Kind: KindGateway,
|
||||
Name: "Cloudflare AI Gateway",
|
||||
Description: "Cloudflare AI Gateway. Operator pastes the gateway URL (with the upstream provider slug like /openai or /anthropic so the URL sniffer dispatches to the right parser) and a per-gateway authentication token. Recommended setup is BYOK / Stored Keys: Cloudflare manages the upstream provider credential and the gateway token is the only secret NetBird needs.",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "cf-aig-authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#F38020",
|
||||
// ParserID empty: like Bifrost, the proxy's parser-detect
|
||||
// sniffs the URL path. /openai/... contains the OpenAI hint
|
||||
// substrings; /anthropic/v1/messages contains /v1/messages
|
||||
// (matches AnthropicParser). The /compat universal endpoint
|
||||
// also speaks OpenAI shape so OpenAIParser handles it.
|
||||
// Operators who paste a different prefix degrade to no-cost
|
||||
// (skipMissingProvider) but the request still flows.
|
||||
ParserID: "",
|
||||
// cf-aig-metadata is a single header carrying a JSON object;
|
||||
// up to five string/number/boolean values per request. NetBird
|
||||
// occupies two slots (user id + groups CSV) and leaves three
|
||||
// for operator-added context. JSON keys are operator-
|
||||
// customisable so Cloudflare-side log filters can use the
|
||||
// operator's existing label conventions instead of NetBird's
|
||||
// defaults — hence Customizable=true. The dashboard surfaces
|
||||
// the catalog values as placeholders; only the values stored
|
||||
// on the provider record's IdentityHeader* fields land on the
|
||||
// wire (empty operator value = key is omitted from the JSON,
|
||||
// since applyJSONMetadata already skips empty keys).
|
||||
IdentityInjection: &IdentityInjection{
|
||||
JSONMetadata: &JSONMetadataInjection{
|
||||
Header: "cf-aig-metadata",
|
||||
UserKey: "netbird_user_id",
|
||||
GroupsKey: "netbird_groups",
|
||||
Customizable: true,
|
||||
// Cloudflare's docs don't specify a per-value cap;
|
||||
// leaving 0 disables the truncate path. Header-level
|
||||
// constraint is "5 entries max" rather than length.
|
||||
MaxValueLength: 0,
|
||||
},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "vercel_ai_gateway",
|
||||
Kind: KindGateway,
|
||||
Name: "Vercel AI Gateway",
|
||||
Description: "Vercel's unified API for hundreds of models. Single endpoint, OpenAI-compatible body, model dispatch via prefix (openai/..., anthropic/..., google/..., xai/...). Per-user / per-tag attribution lands in Vercel's Custom Reporting API and observability dashboard.",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#000000",
|
||||
// Vercel always speaks OpenAI shape on /v1/chat/completions —
|
||||
// the model prefix in the body picks the upstream provider.
|
||||
// No URL sniffing needed; pin the parser directly.
|
||||
ParserID: "openai",
|
||||
// HeaderPair shape with fixed wire names dictated by Vercel's
|
||||
// Custom Reporting API contract. Customizable=false because
|
||||
// renaming the headers makes Vercel silently stop attributing
|
||||
// — the gateway's reporting endpoint only matches its own
|
||||
// header names. Same fixed-protocol position as LiteLLM.
|
||||
//
|
||||
// Caveats operators should know:
|
||||
// - up to 10 tags total per request (deduped); 11+ → HTTP 400
|
||||
// - each tag must be 1-64 chars
|
||||
// - user up to 256 chars (NetBird user emails fit)
|
||||
// - $0.075 per 1k unique user/tag values written
|
||||
// We don't enforce the caps in the inject middleware today;
|
||||
// operators in groups beyond the 10-tag limit will see Vercel
|
||||
// 400s and need to re-scope their group memberships.
|
||||
IdentityInjection: &IdentityInjection{
|
||||
HeaderPair: &HeaderPairInjection{
|
||||
EndUserIDHeader: "ai-reporting-user",
|
||||
TagsHeader: "ai-reporting-tags",
|
||||
},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "openrouter",
|
||||
Kind: KindGateway,
|
||||
Name: "OpenRouter",
|
||||
Description: "OpenRouter's unified API for hundreds of models. Single endpoint at openrouter.ai/api/v1, OpenAI-compatible body, model dispatch via prefix (anthropic/claude-..., openai/gpt-..., google/gemini-..., etc.). Per-user attribution lands in OpenRouter's analytics via the OpenAI-standard `user` body field; OpenRouter has no groups / tags dimension at request time.",
|
||||
DefaultHost: "openrouter.ai/api/v1",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#6F4FF2",
|
||||
// OpenRouter is single-endpoint OpenAI-shape on /api/v1/chat/completions —
|
||||
// model prefix in the body picks the upstream provider.
|
||||
// Pinning the parser saves URL sniffing.
|
||||
ParserID: "openai",
|
||||
// HeaderPair shape with EndUserIDInBody as the only active
|
||||
// dimension. OpenRouter's per-user attribution is the
|
||||
// OpenAI-standard `user` body field, not a header — and
|
||||
// OpenRouter offers no per-request groups / tags dimension at
|
||||
// all. Customizable=false because the field name is locked by
|
||||
// OpenAI's spec; renaming would just defeat the inject.
|
||||
IdentityInjection: &IdentityInjection{
|
||||
HeaderPair: &HeaderPairInjection{
|
||||
EndUserIDInBody: true,
|
||||
},
|
||||
},
|
||||
// HTTP-Referer + X-OpenRouter-Title surface in OpenRouter's
|
||||
// app rankings and per-app analytics. Operators paste their
|
||||
// own app URL + display name on the provider record so their
|
||||
// requests show under their brand instead of "no app". Both
|
||||
// are static per-deployment, not per-request, hence the
|
||||
// ExtraHeaders mechanism (operator-typed value, stamped on
|
||||
// every request to this provider). Skip X-OpenRouter-Categories
|
||||
// for now — the marketplace-categories dimension is
|
||||
// niche-enough that we'd add it on demand.
|
||||
ExtraHeaders: []ExtraHeader{
|
||||
{Name: "HTTP-Referer"},
|
||||
{Name: "X-OpenRouter-Title"},
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "custom",
|
||||
Kind: KindCustom,
|
||||
Name: "Custom / Self-hosted",
|
||||
Description: "OpenAI-compatible endpoint (vLLM, Ollama, …)",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#9CA3AF",
|
||||
Models: []Model{},
|
||||
},
|
||||
}
|
||||
|
||||
// All returns a copy of the full catalog.
|
||||
func All() []Provider {
|
||||
out := make([]Provider, len(providers))
|
||||
copy(out, providers)
|
||||
return out
|
||||
}
|
||||
|
||||
// Lookup returns the catalog entry with the given id, if any.
|
||||
func Lookup(id string) (Provider, bool) {
|
||||
for _, p := range providers {
|
||||
if p.ID == id {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
return Provider{}, false
|
||||
}
|
||||
|
||||
// IsKnown reports whether the given id refers to a catalog entry.
|
||||
func IsKnown(id string) bool {
|
||||
_, ok := Lookup(id)
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsVertexPathStyle reports whether a provider uses the Google Vertex AI
|
||||
// request shape — the model is carried in the URL path
|
||||
// (/v1/projects/{p}/locations/{r}/publishers/{pub}/models/{model}:{action})
|
||||
// rather than the body, so the proxy routes it by path instead of by model.
|
||||
func IsVertexPathStyle(providerID string) bool {
|
||||
return providerID == "vertex_ai_api"
|
||||
}
|
||||
|
||||
// IsBedrockPathStyle reports whether a provider uses the AWS Bedrock request
|
||||
// shape — the model is carried in the URL path (/model/{modelId}/{action},
|
||||
// action being invoke, invoke-with-response-stream, converse, or
|
||||
// converse-stream) rather than the body, so the proxy routes it by path.
|
||||
func IsBedrockPathStyle(providerID string) bool {
|
||||
return providerID == "bedrock_api"
|
||||
}
|
||||
|
||||
// ToAPIResponse renders a catalog provider as the API representation.
|
||||
func (p Provider) ToAPIResponse() api.AgentNetworkCatalogProvider {
|
||||
models := make([]api.AgentNetworkCatalogModel, 0, len(p.Models))
|
||||
for _, m := range p.Models {
|
||||
models = append(models, api.AgentNetworkCatalogModel{
|
||||
Id: m.ID,
|
||||
Label: m.Label,
|
||||
InputPer1k: m.InputPer1k,
|
||||
OutputPer1k: m.OutputPer1k,
|
||||
ContextWindow: m.ContextWindow,
|
||||
})
|
||||
}
|
||||
kind := api.AgentNetworkCatalogProviderKindProvider
|
||||
switch p.Kind {
|
||||
case KindGateway:
|
||||
kind = api.AgentNetworkCatalogProviderKindGateway
|
||||
case KindCustom:
|
||||
kind = api.AgentNetworkCatalogProviderKindCustom
|
||||
}
|
||||
resp := api.AgentNetworkCatalogProvider{
|
||||
Id: p.ID,
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
DefaultHost: p.DefaultHost,
|
||||
Kind: kind,
|
||||
AuthHeaderTemplate: p.AuthHeaderTemplate,
|
||||
DefaultContentType: p.DefaultContentType,
|
||||
BrandColor: p.BrandColor,
|
||||
Models: models,
|
||||
}
|
||||
if len(p.ExtraHeaders) > 0 {
|
||||
extras := make([]api.AgentNetworkCatalogExtraHeader, 0, len(p.ExtraHeaders))
|
||||
for _, h := range p.ExtraHeaders {
|
||||
extras = append(extras, api.AgentNetworkCatalogExtraHeader{
|
||||
Name: h.Name,
|
||||
})
|
||||
}
|
||||
resp.ExtraHeaders = &extras
|
||||
}
|
||||
// Surface IdentityInjection so the dashboard can decide whether
|
||||
// to render editable inputs vs. a read-only mappings strip per
|
||||
// shape's customizable flag. HeaderPair (Bifrost) and
|
||||
// JSONMetadata (Cloudflare, Portkey) are mutually exclusive on a
|
||||
// given catalog entry; emit whichever shape is set.
|
||||
if p.IdentityInjection != nil {
|
||||
injection := &api.AgentNetworkCatalogIdentityInjection{}
|
||||
if hp := p.IdentityInjection.HeaderPair; hp != nil {
|
||||
injection.HeaderPair = &api.AgentNetworkCatalogHeaderPairInjection{
|
||||
Customizable: hp.Customizable,
|
||||
EndUserIdHeader: hp.EndUserIDHeader,
|
||||
TagsHeader: hp.TagsHeader,
|
||||
}
|
||||
}
|
||||
if jm := p.IdentityInjection.JSONMetadata; jm != nil {
|
||||
injection.JsonMetadata = &api.AgentNetworkCatalogJSONMetadataInjection{
|
||||
Customizable: jm.Customizable,
|
||||
Header: jm.Header,
|
||||
UserKey: jm.UserKey,
|
||||
GroupsKey: jm.GroupsKey,
|
||||
}
|
||||
}
|
||||
if injection.HeaderPair != nil || injection.JsonMetadata != nil {
|
||||
resp.IdentityInjection = injection
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
// addAccessLogEndpoints registers the read-only, server-side-filtered
|
||||
// agent-network access-log listing and the aggregated usage overview.
|
||||
func (h *handler) addAccessLogEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/access-logs", h.listAccessLogs).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/usage/overview", h.getUsageOverview).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getUsageOverview(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
// Reuse the access-log filter for the shared date/user/group/provider/model
|
||||
// params; pagination/sort/search are irrelevant for an aggregate.
|
||||
var filter types.AgentNetworkAccessLogFilter
|
||||
if err := filter.ParseFromRequest(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
// Bound the aggregation window so an unbounded or over-wide query can't load
|
||||
// an account's entire usage history into memory.
|
||||
filter.ApplyUsageOverviewBounds(time.Now())
|
||||
granularity := types.ParseUsageGranularity(r.URL.Query().Get("granularity"))
|
||||
|
||||
buckets, err := h.manager.GetUsageOverview(r.Context(), userAuth.AccountId, userAuth.UserId, filter, granularity)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]api.AgentNetworkUsageBucket, 0, len(buckets))
|
||||
for _, b := range buckets {
|
||||
out = append(out, b.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) listAccessLogs(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var filter types.AgentNetworkAccessLogFilter
|
||||
if err := filter.ParseFromRequest(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rows, total, err := h.manager.ListAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId, filter)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
data := make([]api.AgentNetworkAccessLog, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
data = append(data, row.ToAPIResponse())
|
||||
}
|
||||
|
||||
pageSize := filter.GetLimit()
|
||||
totalPages := 0
|
||||
if pageSize > 0 {
|
||||
totalPages = int((total + int64(pageSize) - 1) / int64(pageSize))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, api.AgentNetworkAccessLogsResponse{
|
||||
Data: data,
|
||||
Page: filter.Page,
|
||||
PageSize: pageSize,
|
||||
TotalRecords: int(total),
|
||||
TotalPages: totalPages,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// addBudgetRuleEndpoints registers the account-level budget rule routes.
|
||||
func (h *handler) addBudgetRuleEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/budget-rules", h.getAllBudgetRules).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/budget-rules", h.createBudgetRule).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.getBudgetRule).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.updateBudgetRule).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.deleteBudgetRule).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllBudgetRules(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rules, err := h.manager.GetAllBudgetRules(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*api.AgentNetworkBudgetRule, 0, len(rules))
|
||||
for _, rule := range rules {
|
||||
out = append(out, rule.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getBudgetRule(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ruleID := mux.Vars(r)["ruleId"]
|
||||
if ruleID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
rule, err := h.manager.GetBudgetRule(r.Context(), userAuth.AccountId, userAuth.UserId, ruleID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, rule.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) createBudgetRule(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkBudgetRuleRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateBudgetRule(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rule := types.NewAccountBudgetRule(userAuth.AccountId)
|
||||
rule.FromAPIRequest(&req)
|
||||
|
||||
created, err := h.manager.CreateBudgetRule(r.Context(), userAuth.UserId, rule)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateBudgetRule(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ruleID := mux.Vars(r)["ruleId"]
|
||||
if ruleID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkBudgetRuleRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateBudgetRule(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rule := &types.AccountBudgetRule{ID: ruleID, AccountID: userAuth.AccountId}
|
||||
rule.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdateBudgetRule(r.Context(), userAuth.UserId, rule)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteBudgetRule(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ruleID := mux.Vars(r)["ruleId"]
|
||||
if ruleID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteBudgetRule(r.Context(), userAuth.AccountId, userAuth.UserId, ruleID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// validateBudgetRule rejects malformed budget rules. It reuses the policy limit
|
||||
// validation since the cap shape is identical, and rejects empty target entries.
|
||||
func validateBudgetRule(req *api.AgentNetworkBudgetRuleRequest) error {
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
if req.TargetGroups != nil {
|
||||
for _, id := range *req.TargetGroups {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "target_groups must not contain empty entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.TargetUsers != nil {
|
||||
for _, id := range *req.TargetUsers {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "target_users must not contain empty entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
return validatePolicyLimits(req.Limits)
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// TestBudgetRuleHandler_RoundTrip seeds a budget rule via the store and asserts
|
||||
// the GET wire shape carries targets and the reused PolicyLimits cap shape. The
|
||||
// create/update/delete success paths go through accountManager.StoreEvent which
|
||||
// this fixture doesn't wire — they are covered by the manager-level no-mock
|
||||
// test (TestAgentNetwork_BudgetRuleCRUD_RealManager).
|
||||
func TestBudgetRuleHandler_RoundTrip(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
rule := &agentNetworkTypes.AccountBudgetRule{
|
||||
ID: "ainbud_test",
|
||||
AccountID: testAccountID,
|
||||
Name: "org-monthly",
|
||||
Enabled: true,
|
||||
TargetGroups: []string{"grp-eng"},
|
||||
TargetUsers: []string{"user-alice"},
|
||||
Limits: agentNetworkTypes.PolicyLimits{
|
||||
TokenLimit: agentNetworkTypes.PolicyTokenLimit{Enabled: true, GroupCap: 100000, UserCap: 10000, WindowSeconds: 2_592_000},
|
||||
BudgetLimit: agentNetworkTypes.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 500, WindowSeconds: 2_592_000},
|
||||
},
|
||||
}
|
||||
require.NoError(t, f.store.SaveAgentNetworkBudgetRule(context.Background(), rule))
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/budget-rules/"+rule.ID, "")
|
||||
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
|
||||
|
||||
var got api.AgentNetworkBudgetRule
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
assert.Equal(t, "org-monthly", got.Name, "name must round-trip")
|
||||
assert.Equal(t, []string{"grp-eng"}, got.TargetGroups, "target groups must round-trip")
|
||||
assert.Equal(t, []string{"user-alice"}, got.TargetUsers, "target users must round-trip")
|
||||
assert.Equal(t, int64(100000), got.Limits.TokenLimit.GroupCap, "token group cap must round-trip")
|
||||
assert.Equal(t, int64(2_592_000), got.Limits.BudgetLimit.WindowSeconds, "budget window must round-trip")
|
||||
}
|
||||
|
||||
// TestBudgetRuleHandler_ListReturnsArray asserts the list endpoint returns a
|
||||
// JSON array (never null) for an account with no rules.
|
||||
func TestBudgetRuleHandler_ListReturnsArray(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/budget-rules", "")
|
||||
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
|
||||
assert.Equal(t, "[]", trimSpace(rec.Body.String()), "empty account must return an empty array, not null")
|
||||
}
|
||||
|
||||
// TestBudgetRuleHandler_RejectsMissingName covers the validation path (which
|
||||
// runs before the manager call, so it works without a wired accountManager).
|
||||
func TestBudgetRuleHandler_RejectsMissingName(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
body := `{
|
||||
"name": "",
|
||||
"limits": {
|
||||
"token_limit": {"enabled": false, "group_cap": 0, "user_cap": 0, "window_seconds": 0},
|
||||
"budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0}
|
||||
}
|
||||
}`
|
||||
rec := f.do(t, http.MethodPost, "/agent-network/budget-rules", body)
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code,
|
||||
"missing name must be rejected as a validation error (not a route/auth 4xx): got %d body=%s", rec.Code, rec.Body.String())
|
||||
assert.Contains(t, rec.Body.String(), "name",
|
||||
"rejection body must name the offending field, proving the validation path: %s", rec.Body.String())
|
||||
}
|
||||
|
||||
// TestBudgetRuleHandler_RejectsSubMinuteWindow proves budget rules reuse the
|
||||
// policy-limit validation (enabled limit needs window >= 60s).
|
||||
func TestBudgetRuleHandler_RejectsSubMinuteWindow(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
body := `{
|
||||
"name": "bad-window",
|
||||
"limits": {
|
||||
"token_limit": {"enabled": true, "group_cap": 1000, "user_cap": 0, "window_seconds": 30},
|
||||
"budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0}
|
||||
}
|
||||
}`
|
||||
rec := f.do(t, http.MethodPost, "/agent-network/budget-rules", body)
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code,
|
||||
"sub-minute window must be rejected as a validation error (not a route/auth 4xx): got %d body=%s", rec.Code, rec.Body.String())
|
||||
assert.Contains(t, rec.Body.String(), "window_seconds",
|
||||
"rejection body must name the offending window_seconds field, proving the validation path: %s", rec.Body.String())
|
||||
}
|
||||
|
||||
// TestSettingsHandler_GetExposesCollectionToggles asserts the GET settings wire
|
||||
// shape carries the account-level collection toggles after a store seed.
|
||||
func TestSettingsHandler_GetExposesCollectionToggles(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
require.NoError(t, f.store.SaveAgentNetworkSettings(context.Background(), &agentNetworkTypes.Settings{
|
||||
AccountID: testAccountID,
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
Subdomain: "violet",
|
||||
EnableLogCollection: true,
|
||||
EnablePromptCollection: true,
|
||||
RedactPii: false,
|
||||
}))
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/settings", "")
|
||||
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
|
||||
|
||||
var got api.AgentNetworkSettings
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
assert.True(t, got.EnableLogCollection, "log collection toggle must surface on the wire")
|
||||
assert.True(t, got.EnablePromptCollection, "prompt collection toggle must surface on the wire")
|
||||
assert.False(t, got.RedactPii, "redact toggle must surface its false value")
|
||||
assert.Equal(t, "violet.eu.proxy.netbird.io", got.Endpoint, "endpoint stays computed from immutable cluster+subdomain")
|
||||
}
|
||||
|
||||
func trimSpace(s string) string {
|
||||
for len(s) > 0 && (s[len(s)-1] == '\n' || s[len(s)-1] == ' ' || s[len(s)-1] == '\t' || s[len(s)-1] == '\r') {
|
||||
s = s[:len(s)-1]
|
||||
}
|
||||
for len(s) > 0 && (s[0] == '\n' || s[0] == ' ' || s[0] == '\t' || s[0] == '\r') {
|
||||
s = s[1:]
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
// addConsumptionEndpoints registers the read-only Agent Network
|
||||
// consumption listing — backs the dashboard's basic counter view.
|
||||
func (h *handler) addConsumptionEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/consumption", h.listConsumption).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) listConsumption(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.manager.ListConsumption(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]api.AgentNetworkConsumption, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, consumptionToAPI(row))
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func consumptionToAPI(c *types.Consumption) api.AgentNetworkConsumption {
|
||||
windowStart := c.WindowStartUTC
|
||||
updatedAt := c.UpdatedAt
|
||||
return api.AgentNetworkConsumption{
|
||||
DimensionKind: api.AgentNetworkConsumptionDimensionKind(c.DimensionKind),
|
||||
DimensionId: c.DimensionID,
|
||||
WindowSeconds: c.WindowSeconds,
|
||||
WindowStartUtc: windowStart,
|
||||
TokensInput: c.TokensInput,
|
||||
TokensOutput: c.TokensOutput,
|
||||
CostUsd: c.CostUSD,
|
||||
UpdatedAt: &updatedAt,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// addGuardrailEndpoints registers all Agent Network guardrail routes.
|
||||
func (h *handler) addGuardrailEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/guardrails", h.getAllGuardrails).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/guardrails", h.createGuardrail).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.getGuardrail).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.updateGuardrail).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.deleteGuardrail).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllGuardrails(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrails, err := h.manager.GetAllGuardrails(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*api.AgentNetworkGuardrail, 0, len(guardrails))
|
||||
for _, g := range guardrails {
|
||||
out = append(out, g.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getGuardrail(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrailID := mux.Vars(r)["guardrailId"]
|
||||
if guardrailID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrail, err := h.manager.GetGuardrail(r.Context(), userAuth.AccountId, userAuth.UserId, guardrailID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, guardrail.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) createGuardrail(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkGuardrailRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateGuardrail(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrail := types.NewGuardrail(userAuth.AccountId)
|
||||
guardrail.FromAPIRequest(&req)
|
||||
|
||||
created, err := h.manager.CreateGuardrail(r.Context(), userAuth.UserId, guardrail)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateGuardrail(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrailID := mux.Vars(r)["guardrailId"]
|
||||
if guardrailID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkGuardrailRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateGuardrail(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrail := &types.Guardrail{
|
||||
ID: guardrailID,
|
||||
AccountID: userAuth.AccountId,
|
||||
}
|
||||
guardrail.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdateGuardrail(r.Context(), userAuth.UserId, guardrail)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteGuardrail(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
guardrailID := mux.Vars(r)["guardrailId"]
|
||||
if guardrailID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteGuardrail(r.Context(), userAuth.AccountId, userAuth.UserId, guardrailID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func validateGuardrail(req *api.AgentNetworkGuardrailRequest) error {
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
|
||||
c := req.Checks
|
||||
if c.ModelAllowlist.Enabled {
|
||||
for _, id := range c.ModelAllowlist.Models {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "model_allowlist.models must not contain empty entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,256 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountID = "acc-1"
|
||||
testUserID = "user-bob"
|
||||
)
|
||||
|
||||
// agentNetworkHandlerFixture builds a real agentnetwork.Manager with
|
||||
// a sqlite store and an always-allow permissions mock, then exposes
|
||||
// the HTTP handlers via a gorilla router. Tests issue requests
|
||||
// through httptest and assert on the wire shape — the same path the
|
||||
// dashboard exercises.
|
||||
type agentNetworkHandlerFixture struct {
|
||||
store store.Store
|
||||
manager agentnetwork.Manager
|
||||
router *mux.Router
|
||||
}
|
||||
|
||||
func newAgentNetworkHandlerFixture(t *testing.T) *agentNetworkHandlerFixture {
|
||||
t.Helper()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("sqlite store not properly supported on Windows yet")
|
||||
}
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(nbtypes.SqliteStoreEngine))
|
||||
|
||||
st, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
perms := permissions.NewMockManager(ctrl)
|
||||
// Always-allow: the handler tests are about wire shape, not
|
||||
// authz. Authz is covered by the manager's own tests.
|
||||
perms.EXPECT().
|
||||
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(true, context.Background(), nil).
|
||||
AnyTimes()
|
||||
|
||||
manager := agentnetwork.NewManager(st, perms, nil, nil)
|
||||
h := &handler{manager: manager}
|
||||
|
||||
router := mux.NewRouter()
|
||||
h.addPolicyEndpoints(router)
|
||||
h.addConsumptionEndpoints(router)
|
||||
h.addBudgetRuleEndpoints(router)
|
||||
h.addSettingsEndpoints(router)
|
||||
|
||||
return &agentNetworkHandlerFixture{
|
||||
store: st,
|
||||
manager: manager,
|
||||
router: router,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *agentNetworkHandlerFixture) do(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
var reader io.Reader
|
||||
if body != "" {
|
||||
reader = strings.NewReader(body)
|
||||
}
|
||||
req := httptest.NewRequest(method, path, reader)
|
||||
if body != "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
f.router.ServeHTTP(rec, req)
|
||||
return rec
|
||||
}
|
||||
|
||||
// seedProvider persists a minimal provider record so policy create
|
||||
// passes the manager's destination_provider_ids existence check.
|
||||
func (f *agentNetworkHandlerFixture) seedProvider(t *testing.T, id string) {
|
||||
t.Helper()
|
||||
require.NoError(t, f.store.SaveAgentNetworkProvider(context.Background(), &agentNetworkTypes.Provider{
|
||||
ID: id,
|
||||
AccountID: testAccountID,
|
||||
ProviderID: "openai_api",
|
||||
Name: "test-" + id,
|
||||
UpstreamURL: "https://api.openai.com",
|
||||
APIKey: "sk-test",
|
||||
Enabled: true,
|
||||
SessionPrivateKey: "test-priv-key",
|
||||
SessionPublicKey: "test-pub-key",
|
||||
}))
|
||||
}
|
||||
|
||||
// TestPolicyHandler_WindowSecondsRoundTrip ports bash 10 to Go:
|
||||
// assert that a policy with window_seconds on both Token + Budget
|
||||
// halves round-trips through GET unchanged AND that legacy
|
||||
// window_hours / window_days are absent from the JSON response. We
|
||||
// seed the policy directly via the store rather than POST-ing
|
||||
// because the create path goes through the manager's
|
||||
// accountManager.StoreEvent which we don't wire in this fixture; the
|
||||
// on-wire shape is what matters here, and the POST validation path
|
||||
// is covered separately by the RejectsSubMinuteWindow test.
|
||||
func TestPolicyHandler_WindowSecondsRoundTrip(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
policy := &agentNetworkTypes.Policy{
|
||||
ID: "ainpol_test",
|
||||
AccountID: testAccountID,
|
||||
Name: "round-trip",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-engineers"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: agentNetworkTypes.PolicyLimits{
|
||||
TokenLimit: agentNetworkTypes.PolicyTokenLimit{Enabled: true, GroupCap: 10000, UserCap: 5000, WindowSeconds: 86_400},
|
||||
BudgetLimit: agentNetworkTypes.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 10.0, UserCapUsd: 2.5, WindowSeconds: 2_592_000},
|
||||
},
|
||||
}
|
||||
require.NoError(t, f.store.SaveAgentNetworkPolicy(context.Background(), policy))
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/policies/"+policy.ID, "")
|
||||
require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String())
|
||||
|
||||
var got api.AgentNetworkPolicy
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
assert.Equal(t, int64(86_400), got.Limits.TokenLimit.WindowSeconds, "token_limit.window_seconds must round-trip")
|
||||
assert.Equal(t, int64(2_592_000), got.Limits.BudgetLimit.WindowSeconds, "budget_limit.window_seconds must round-trip")
|
||||
|
||||
// Legacy field names must NOT appear in the response — would
|
||||
// signal that the management server is still emitting the old
|
||||
// shape and would fool a v1 dashboard into rendering days/hours.
|
||||
assert.NotContains(t, rec.Body.String(), "window_hours",
|
||||
"legacy window_hours field must be absent from the on-wire response")
|
||||
assert.NotContains(t, rec.Body.String(), "window_days",
|
||||
"legacy window_days field must be absent from the on-wire response")
|
||||
}
|
||||
|
||||
// TestPolicyHandler_RejectsSubMinuteWindow ports bash 20 to Go: an
|
||||
// enabled limit with window_seconds < 60 must surface as a 4xx
|
||||
// because anything finer than per-minute produces an untenable
|
||||
// volume of consumption rows for a feature whose value comes from
|
||||
// per-window cap enforcement.
|
||||
func TestPolicyHandler_RejectsSubMinuteWindow(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
f.seedProvider(t, "prov-1")
|
||||
|
||||
body := `{
|
||||
"name": "sub-minute-window",
|
||||
"enabled": true,
|
||||
"source_groups": ["grp-engineers"],
|
||||
"destination_provider_ids": ["prov-1"],
|
||||
"guardrail_ids": [],
|
||||
"limits": {
|
||||
"token_limit": {"enabled": true, "group_cap": 10000, "user_cap": 5000, "window_seconds": 30},
|
||||
"budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0}
|
||||
}
|
||||
}`
|
||||
rec := f.do(t, http.MethodPost, "/agent-network/policies", body)
|
||||
// 422 specifically (InvalidArgument) proves the window-validation path —
|
||||
// a route miss would be 404 and an auth failure 403, so a generic 4xx
|
||||
// would let those false-pass.
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code,
|
||||
"enabled token_limit with window_seconds<60 must be rejected as a validation error: got %d body=%s", rec.Code, rec.Body.String())
|
||||
assert.Contains(t, rec.Body.String(), "window_seconds",
|
||||
"rejection body must name the offending window_seconds field, proving it's the validation path: %s", rec.Body.String())
|
||||
}
|
||||
|
||||
// TestConsumptionHandler_EmptyAccountReturnsArray ports bash 30 to
|
||||
// Go: GET /agent-network/consumption on a clean account always
|
||||
// returns a JSON array (possibly empty), never a 404 / 500. The
|
||||
// dashboard depends on this shape to render its empty state.
|
||||
func TestConsumptionHandler_EmptyAccountReturnsArray(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/consumption", "")
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var rows []api.AgentNetworkConsumption
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rows),
|
||||
"response must always be a JSON array — even when empty: %s", rec.Body.String())
|
||||
assert.Empty(t, rows)
|
||||
}
|
||||
|
||||
// TestConsumptionHandler_PopulatedAccountListsRows mirrors the
|
||||
// /consumption read after a few RecordConsumption calls. Validates
|
||||
// the wire shape carries every field the dashboard reads (dim_kind,
|
||||
// dim_id, window_seconds, window_start_utc, tokens, cost_usd) and
|
||||
// rows are ordered window-newest-first.
|
||||
func TestConsumptionHandler_PopulatedAccountListsRows(t *testing.T) {
|
||||
f := newAgentNetworkHandlerFixture(t)
|
||||
|
||||
require.NoError(t, f.manager.RecordConsumption(
|
||||
context.Background(), testAccountID,
|
||||
agentNetworkTypes.DimensionGroup, "grp-engineers",
|
||||
86_400, 100, 50, 0.0125,
|
||||
))
|
||||
require.NoError(t, f.manager.RecordConsumption(
|
||||
context.Background(), testAccountID,
|
||||
agentNetworkTypes.DimensionUser, testUserID,
|
||||
86_400, 100, 50, 0.0125,
|
||||
))
|
||||
|
||||
rec := f.do(t, http.MethodGet, "/agent-network/consumption", "")
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var rows []api.AgentNetworkConsumption
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rows))
|
||||
require.Len(t, rows, 2, "two RecordConsumption calls must yield two rows")
|
||||
|
||||
// Index by dim_kind so we can assert the full wire shape of each row,
|
||||
// including the dimension id and the aligned window start the dashboard
|
||||
// keys on. Both rows share totals and window.
|
||||
byKind := make(map[string]api.AgentNetworkConsumption, len(rows))
|
||||
for _, row := range rows {
|
||||
assert.Equal(t, int64(100), row.TokensInput)
|
||||
assert.Equal(t, int64(50), row.TokensOutput)
|
||||
assert.InDelta(t, 0.0125, row.CostUsd, 1e-9)
|
||||
assert.Equal(t, int64(86_400), row.WindowSeconds)
|
||||
assert.False(t, row.WindowStartUtc.IsZero(), "window_start_utc must be set on every row")
|
||||
byKind[string(row.DimensionKind)] = row
|
||||
}
|
||||
|
||||
groupRow, ok := byKind["group"]
|
||||
require.True(t, ok, "group dimension must surface")
|
||||
assert.Equal(t, "grp-engineers", groupRow.DimensionId, "group row must carry the source group id as dimension_id")
|
||||
|
||||
userRow, ok := byKind["user"]
|
||||
require.True(t, ok, "user dimension must surface")
|
||||
assert.Equal(t, testUserID, userRow.DimensionId, "user row must carry the user id as dimension_id")
|
||||
|
||||
// Both rows fall in the same aligned window (same length, recorded
|
||||
// together), so window_start_utc must match across them.
|
||||
assert.Equal(t, groupRow.WindowStartUtc, userRow.WindowStartUtc,
|
||||
"rows recorded in the same window must share the aligned window_start_utc")
|
||||
}
|
||||
@@ -0,0 +1,228 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// minWindowSeconds is the floor enforced on enabled token / budget
|
||||
// limit windows. One minute is short enough for fine-grained burst
|
||||
// control without producing untenable consumption-row volume at scale.
|
||||
const minWindowSeconds int64 = 60
|
||||
|
||||
// addPolicyEndpoints registers all Agent Network policy routes on the
|
||||
// shared handler.
|
||||
func (h *handler) addPolicyEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/policies", h.getAllPolicies).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/policies", h.createPolicy).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/policies/{policyId}", h.getPolicy).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/policies/{policyId}", h.updatePolicy).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/policies/{policyId}", h.deletePolicy).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policies, err := h.manager.GetAllPolicies(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*api.AgentNetworkPolicy, 0, len(policies))
|
||||
for _, p := range policies {
|
||||
out = append(out, p.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policyID := mux.Vars(r)["policyId"]
|
||||
if policyID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
policy, err := h.manager.GetPolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policyID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, policy.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkPolicyRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validatePolicy(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policy := types.NewPolicy(userAuth.AccountId)
|
||||
policy.FromAPIRequest(&req)
|
||||
|
||||
created, err := h.manager.CreatePolicy(r.Context(), userAuth.UserId, policy)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policyID := mux.Vars(r)["policyId"]
|
||||
if policyID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkPolicyRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validatePolicy(&req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policy := &types.Policy{
|
||||
ID: policyID,
|
||||
AccountID: userAuth.AccountId,
|
||||
}
|
||||
policy.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdatePolicy(r.Context(), userAuth.UserId, policy)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policyID := mux.Vars(r)["policyId"]
|
||||
if policyID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeletePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policyID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func validatePolicy(req *api.AgentNetworkPolicyRequest) error {
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
if len(req.SourceGroups) == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "source_groups must contain at least one group id")
|
||||
}
|
||||
for _, id := range req.SourceGroups {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "source_groups must not contain empty entries")
|
||||
}
|
||||
}
|
||||
if len(req.DestinationProviderIds) == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "destination_provider_ids must contain at least one provider id")
|
||||
}
|
||||
for _, id := range req.DestinationProviderIds {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "destination_provider_ids must not contain empty entries")
|
||||
}
|
||||
}
|
||||
if req.GuardrailIds != nil {
|
||||
for _, id := range *req.GuardrailIds {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "guardrail_ids must not contain empty entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.Limits != nil {
|
||||
if err := validatePolicyLimits(*req.Limits); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePolicyLimits(l api.AgentNetworkPolicyLimits) error {
|
||||
if l.TokenLimit.Enabled {
|
||||
if l.TokenLimit.WindowSeconds < minWindowSeconds {
|
||||
return status.Errorf(status.InvalidArgument, "limits.token_limit.window_seconds must be at least %d (one minute) when enabled", minWindowSeconds)
|
||||
}
|
||||
if l.TokenLimit.GroupCap < 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.token_limit.group_cap must not be negative")
|
||||
}
|
||||
if l.TokenLimit.UserCap < 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.token_limit.user_cap must not be negative")
|
||||
}
|
||||
if l.TokenLimit.GroupCap == 0 && l.TokenLimit.UserCap == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.token_limit requires group_cap or user_cap to be greater than zero when enabled")
|
||||
}
|
||||
}
|
||||
if l.BudgetLimit.Enabled {
|
||||
if l.BudgetLimit.WindowSeconds < minWindowSeconds {
|
||||
return status.Errorf(status.InvalidArgument, "limits.budget_limit.window_seconds must be at least %d (one minute) when enabled", minWindowSeconds)
|
||||
}
|
||||
if l.BudgetLimit.GroupCapUsd < 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.budget_limit.group_cap_usd must not be negative")
|
||||
}
|
||||
if l.BudgetLimit.UserCapUsd < 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.budget_limit.user_cap_usd must not be negative")
|
||||
}
|
||||
if l.BudgetLimit.GroupCapUsd == 0 && l.BudgetLimit.UserCapUsd == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "limits.budget_limit requires group_cap_usd or user_cap_usd to be greater than zero when enabled")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
// Package handlers serves the Agent Network HTTP API.
|
||||
//
|
||||
// All persistence is delegated to agentnetwork.Manager so this layer only
|
||||
// translates between the wire format (api.AgentNetworkProvider*) and the
|
||||
// domain types.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/catalog"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager agentnetwork.Manager
|
||||
}
|
||||
|
||||
// RegisterEndpoints registers all Agent Network routes.
|
||||
func RegisterEndpoints(manager agentnetwork.Manager, router *mux.Router) {
|
||||
h := &handler{manager: manager}
|
||||
router.HandleFunc("/agent-network/catalog/providers", h.getCatalogProviders).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers", h.getAllProviders).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers", h.createProvider).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers/{providerId}", h.getProvider).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers/{providerId}", h.updateProvider).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/providers/{providerId}", h.deleteProvider).Methods("DELETE", "OPTIONS")
|
||||
h.addPolicyEndpoints(router)
|
||||
h.addGuardrailEndpoints(router)
|
||||
h.addSettingsEndpoints(router)
|
||||
h.addConsumptionEndpoints(router)
|
||||
h.addAccessLogEndpoints(router)
|
||||
h.addBudgetRuleEndpoints(router)
|
||||
}
|
||||
|
||||
func (h *handler) getCatalogProviders(w http.ResponseWriter, r *http.Request) {
|
||||
if _, err := nbcontext.GetUserAuthFromContext(r.Context()); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
entries := catalog.All()
|
||||
out := make([]api.AgentNetworkCatalogProvider, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
out = append(out, e.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getAllProviders(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.manager.GetAllProviders(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*api.AgentNetworkProvider, 0, len(providers))
|
||||
for _, p := range providers {
|
||||
out = append(out, p.ToAPIResponse())
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, out)
|
||||
}
|
||||
|
||||
func (h *handler) getProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
providerID := mux.Vars(r)["providerId"]
|
||||
if providerID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.manager.GetProvider(r.Context(), userAuth.AccountId, userAuth.UserId, providerID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, provider.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) createProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkProviderRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validate(&req, true); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
provider := types.NewProvider(userAuth.AccountId)
|
||||
provider.FromAPIRequest(&req)
|
||||
|
||||
bootstrapCluster := ""
|
||||
if req.BootstrapCluster != nil {
|
||||
bootstrapCluster = *req.BootstrapCluster
|
||||
}
|
||||
|
||||
created, err := h.manager.CreateProvider(r.Context(), userAuth.UserId, provider, bootstrapCluster)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, created.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
providerID := mux.Vars(r)["providerId"]
|
||||
if providerID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkProviderRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validate(&req, false); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
provider := &types.Provider{
|
||||
ID: providerID,
|
||||
AccountID: userAuth.AccountId,
|
||||
}
|
||||
provider.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdateProvider(r.Context(), userAuth.UserId, provider)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
providerID := mux.Vars(r)["providerId"]
|
||||
if providerID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteProvider(r.Context(), userAuth.AccountId, userAuth.UserId, providerID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func validate(req *api.AgentNetworkProviderRequest, requireAPIKey bool) error {
|
||||
if strings.TrimSpace(req.ProviderId) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "provider_id is required")
|
||||
}
|
||||
if !catalog.IsKnown(req.ProviderId) {
|
||||
return status.Errorf(status.InvalidArgument, "provider_id %q is not a known catalog provider", req.ProviderId)
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
if strings.TrimSpace(req.UpstreamUrl) == "" {
|
||||
return status.Errorf(status.InvalidArgument, "upstream_url is required")
|
||||
}
|
||||
u, err := url.Parse(strings.TrimSpace(req.UpstreamUrl))
|
||||
if err != nil || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
|
||||
return status.Errorf(status.InvalidArgument, "upstream_url must be a full http(s) URL")
|
||||
}
|
||||
if requireAPIKey && (req.ApiKey == nil || strings.TrimSpace(*req.ApiKey) == "") {
|
||||
return status.Errorf(status.InvalidArgument, "api_key is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// addSettingsEndpoints registers the Agent Network settings routes. The
|
||||
// settings row is bootstrapped server-side on first provider create; GET reads
|
||||
// it and PUT updates the mutable collection toggles (cluster/subdomain stay
|
||||
// immutable).
|
||||
func (h *handler) addSettingsEndpoints(router *mux.Router) {
|
||||
router.HandleFunc("/agent-network/settings", h.getSettings).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/agent-network/settings", h.updateSettings).Methods("PUT", "OPTIONS")
|
||||
}
|
||||
|
||||
// updateSettings applies the collection toggles to the account's settings row.
|
||||
func (h *handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.AgentNetworkSettingsRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings := &types.Settings{AccountID: userAuth.AccountId}
|
||||
settings.FromAPIRequest(&req)
|
||||
|
||||
updated, err := h.manager.UpdateSettings(r.Context(), userAuth.UserId, settings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse())
|
||||
}
|
||||
|
||||
// getSettings returns the account's agent-network settings. The settings
|
||||
// row is bootstrapped on first provider create, so freshly-onboarded
|
||||
// accounts have nothing to read. Rather than 404-ing in that case (which
|
||||
// the dashboard would have to special-case), return a JSON null with 200
|
||||
// so consumers can branch on the body alone.
|
||||
func (h *handler) getSettings(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.manager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
var sErr *status.Error
|
||||
if errors.As(err, &sErr) && sErr.Type() == status.NotFound {
|
||||
util.WriteJSONObject(r.Context(), w, nil)
|
||||
return
|
||||
}
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, settings.ToAPIResponse())
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
// Package labelgen produces DNS-safe Agent Network subdomain labels.
|
||||
package labelgen
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// pickAttempts caps the random retries before falling back to the
|
||||
// suffixed form. Eight is a soft compromise: with a near-empty taken
|
||||
// set the very first pick almost always succeeds; when the wordlist is
|
||||
// densely populated the fallback eventually fires anyway.
|
||||
const pickAttempts = 8
|
||||
|
||||
var (
|
||||
dedupOnce sync.Once
|
||||
uniqWords []string
|
||||
)
|
||||
|
||||
// uniqueWords returns the wordlist deduplicated and sorted for
|
||||
// deterministic exhaustion behaviour. Lazy-built once per process.
|
||||
func uniqueWords() []string {
|
||||
dedupOnce.Do(func() {
|
||||
seen := make(map[string]struct{}, len(words))
|
||||
uniqWords = make([]string, 0, len(words))
|
||||
for _, w := range words {
|
||||
if _, ok := seen[w]; ok {
|
||||
continue
|
||||
}
|
||||
seen[w] = struct{}{}
|
||||
uniqWords = append(uniqWords, w)
|
||||
}
|
||||
sort.Strings(uniqWords)
|
||||
})
|
||||
return uniqWords
|
||||
}
|
||||
|
||||
// PickUnique selects a label not already in `taken`. It tries up to
|
||||
// pickAttempts random picks; on exhaustion it scans the deduplicated
|
||||
// wordlist for any remaining free entry, and if none is left appends
|
||||
// `-<fallbackSuffix>` to a deterministic word and returns. The caller
|
||||
// is responsible for seeding rng (math/rand).
|
||||
func PickUnique(rng *rand.Rand, taken map[string]struct{}, fallbackSuffix string) string {
|
||||
pool := uniqueWords()
|
||||
if len(pool) == 0 {
|
||||
return fallbackSuffix
|
||||
}
|
||||
|
||||
for i := 0; i < pickAttempts; i++ {
|
||||
w := pool[rng.Intn(len(pool))]
|
||||
if _, ok := taken[w]; !ok {
|
||||
return w
|
||||
}
|
||||
}
|
||||
|
||||
for _, w := range pool {
|
||||
if _, ok := taken[w]; !ok {
|
||||
return w
|
||||
}
|
||||
}
|
||||
|
||||
w := pool[rng.Intn(len(pool))]
|
||||
return fmt.Sprintf("%s-%s", w, fallbackSuffix)
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package labelgen
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestPickUnique_DeterministicWithSeededRng locks the property the
|
||||
// caller relies on: same seed + same taken set → same pick. Without
|
||||
// that, the bootstrap flow can't reproduce a label across retries.
|
||||
func TestPickUnique_DeterministicWithSeededRng(t *testing.T) {
|
||||
taken := map[string]struct{}{}
|
||||
|
||||
rngA := rand.New(rand.NewSource(42))
|
||||
rngB := rand.New(rand.NewSource(42))
|
||||
|
||||
a := PickUnique(rngA, taken, "abcd")
|
||||
b := PickUnique(rngB, taken, "abcd")
|
||||
|
||||
assert.Equal(t, a, b, "Same seed and taken set must produce identical pick")
|
||||
}
|
||||
|
||||
// TestPickUnique_AvoidsTakenWordsWhenMostAreReserved seeds taken with
|
||||
// every word in the pool except a handful and confirms PickUnique
|
||||
// finds one of the remaining free entries instead of returning the
|
||||
// fallback form.
|
||||
func TestPickUnique_AvoidsTakenWordsWhenMostAreReserved(t *testing.T) {
|
||||
pool := uniqueWords()
|
||||
require.NotEmpty(t, pool, "wordlist must be populated for the test to mean anything")
|
||||
|
||||
free := map[string]struct{}{
|
||||
pool[0]: {},
|
||||
pool[len(pool)/2]: {},
|
||||
pool[len(pool)-1]: {},
|
||||
}
|
||||
|
||||
taken := make(map[string]struct{}, len(pool))
|
||||
for _, w := range pool {
|
||||
if _, ok := free[w]; ok {
|
||||
continue
|
||||
}
|
||||
taken[w] = struct{}{}
|
||||
}
|
||||
|
||||
rng := rand.New(rand.NewSource(7))
|
||||
got := PickUnique(rng, taken, "abcd")
|
||||
|
||||
_, isFree := free[got]
|
||||
assert.True(t, isFree, "PickUnique must return one of the free words; got %q", got)
|
||||
assert.NotContains(t, got, "-", "Free pick must not be the suffix fallback form")
|
||||
}
|
||||
|
||||
// TestPickUnique_FallsBackWhenAllReserved exhausts the pool and
|
||||
// confirms PickUnique appends the supplied suffix instead of
|
||||
// returning a duplicate.
|
||||
func TestPickUnique_FallsBackWhenAllReserved(t *testing.T) {
|
||||
pool := uniqueWords()
|
||||
|
||||
taken := make(map[string]struct{}, len(pool))
|
||||
for _, w := range pool {
|
||||
taken[w] = struct{}{}
|
||||
}
|
||||
|
||||
rng := rand.New(rand.NewSource(99))
|
||||
got := PickUnique(rng, taken, "abcd")
|
||||
|
||||
assert.True(t, strings.HasSuffix(got, "-abcd"), "Exhausted pool must produce <word>-<suffix>; got %q", got)
|
||||
|
||||
prefix := strings.TrimSuffix(got, "-abcd")
|
||||
found := false
|
||||
for _, w := range pool {
|
||||
if w == prefix {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Fallback prefix must be drawn from the wordlist; got %q", prefix)
|
||||
}
|
||||
|
||||
// TestUniqueWords_DropsDuplicates guards against authoring slips in
|
||||
// words.go: every entry must be unique and DNS-safe.
|
||||
func TestUniqueWords_DropsDuplicates(t *testing.T) {
|
||||
pool := uniqueWords()
|
||||
seen := make(map[string]struct{}, len(pool))
|
||||
for _, w := range pool {
|
||||
_, dup := seen[w]
|
||||
assert.False(t, dup, "Duplicate entry %q in deduplicated pool", w)
|
||||
seen[w] = struct{}{}
|
||||
assert.GreaterOrEqual(t, len(w), 4, "Word %q is shorter than 4 chars", w)
|
||||
assert.LessOrEqual(t, len(w), 12, "Word %q is longer than 12 chars", w)
|
||||
for _, r := range w {
|
||||
ok := r >= 'a' && r <= 'z'
|
||||
assert.True(t, ok, "Word %q contains non-lowercase-ASCII rune %q", w, r)
|
||||
}
|
||||
}
|
||||
assert.GreaterOrEqual(t, len(pool), 500, "Pool must contain at least 500 unique words")
|
||||
}
|
||||
136
management/internals/modules/agentnetwork/labelgen/words.go
Normal file
136
management/internals/modules/agentnetwork/labelgen/words.go
Normal file
@@ -0,0 +1,136 @@
|
||||
// Package labelgen produces DNS-safe Agent Network subdomain labels.
|
||||
//
|
||||
// The wordlist below is a curated subset drawn from public-domain
|
||||
// nature / common-noun pools (e.g. EFF's diceware lists). Every entry
|
||||
// is lowercase ASCII, 4–12 chars, no hyphens, no digits, and was
|
||||
// hand-checked to avoid offensive, brand, or region-specific terms.
|
||||
package labelgen
|
||||
|
||||
// words is the pool PickUnique selects from. The slice is intentionally
|
||||
// not sorted — random picks distribute across the list naturally.
|
||||
var words = []string{
|
||||
"acorn", "adobe", "agate", "alder", "almond", "alpine", "amber", "amethyst",
|
||||
"anchor", "antler", "apple", "apricot", "arcade", "arctic", "arrow", "ashen",
|
||||
"aspen", "atlas", "atom", "aurora", "autumn", "azure",
|
||||
"badger", "bamboo", "banana", "banjo", "barley", "barn", "basalt", "basil",
|
||||
"basin", "bayou", "beach", "beacon", "beaver", "beech", "beetle", "berry",
|
||||
"birch", "bison", "blossom", "blue", "bobcat", "bonsai", "boulder", "branch",
|
||||
"brass", "breeze", "bridge", "bright", "brook", "broom", "brown", "buffalo",
|
||||
"bumble", "burrow", "butter", "button",
|
||||
"cabin", "cactus", "calm", "camel", "campfire", "canary", "candle", "canoe",
|
||||
"canyon", "cardinal", "carrot", "cascade", "castle", "cedar", "celery", "cello",
|
||||
"cement", "cherry", "chestnut", "chime", "cinnamon", "cinder", "citron", "clay",
|
||||
"clear", "cliff", "clock", "cloud", "clover", "coast", "cobalt", "cobble",
|
||||
"cocoa", "coffee", "comet", "compass", "copper", "coral", "corner", "cosmos",
|
||||
"cotton", "cougar", "country", "coyote", "cove", "crane", "crater", "creek",
|
||||
"crescent", "crimson", "crocus", "crystal", "cypress",
|
||||
"daffodil", "dahlia", "daisy", "dawn", "deer", "delta", "denim", "desert",
|
||||
"dewdrop", "diamond", "dolphin", "doodle", "dove", "dragon", "drift", "drop",
|
||||
"dune", "dusk", "dusty",
|
||||
"eagle", "earth", "echo", "elder", "elkhorn", "ember", "emerald", "emperor",
|
||||
"evergreen", "evening",
|
||||
"falcon", "fawn", "feather", "fern", "fiddle", "field", "fiesta", "finch",
|
||||
"firepit", "firefly", "fjord", "flame", "flax", "fleece", "flint", "floral",
|
||||
"flower", "flute", "foal", "foggy", "forest", "fountain", "foxglove", "fresh",
|
||||
"frost", "fuchsia", "fudge",
|
||||
"gable", "galaxy", "garden", "garnet", "gazelle", "geode", "geyser", "ginger",
|
||||
"glacier", "glade", "glass", "glow", "gold", "goose", "gorge", "gourd",
|
||||
"granite", "grape", "grass", "gravel", "grayling", "greenery", "grizzly", "grove",
|
||||
"gull", "gumdrop", "gust",
|
||||
"hammock", "harbor", "harvest", "hawk", "hazel", "heather", "hedge", "heron",
|
||||
"hibiscus", "hickory", "hideaway", "highland", "hill", "hive", "hollow", "honey",
|
||||
"hopper", "horizon", "hummingbird", "husky",
|
||||
"iceberg", "indigo", "iris", "island", "ivory", "ivybush",
|
||||
"jade", "jasmine", "jasper", "jaybird", "jelly", "jewel", "jonquil", "journey",
|
||||
"juniper", "jupiter", "jute",
|
||||
"kale", "kangaroo", "kayak", "kelp", "kestrel", "kettle", "khaki", "kindling",
|
||||
"kingfisher", "kiwi", "knapweed", "koala",
|
||||
"lagoon", "lake", "lantern", "larch", "lark", "laurel", "lava", "lavender",
|
||||
"leaf", "lemon", "lichen", "light", "lilac", "lily", "lime", "limestone",
|
||||
"linden", "linen", "lion", "lobster", "locust", "loon", "lotus", "lumber",
|
||||
"lunar", "lupine", "lynx",
|
||||
"madrone", "magenta", "magnolia", "mahogany", "mallow", "mango", "manor", "maple",
|
||||
"marble", "marigold", "marina", "marlin", "marsh", "mauve", "meadow", "melody",
|
||||
"melon", "merlin", "metal", "midnight", "milk", "millet", "mineral", "mint",
|
||||
"mirror", "mist", "mitten", "molasses", "moon", "moose", "morning", "moss",
|
||||
"mountain", "mulberry", "muscat", "mustard",
|
||||
"narwhal", "navy", "nectar", "needle", "nest", "nettle", "newt", "nightfall",
|
||||
"noon", "nook", "north", "nova", "nutmeg",
|
||||
"oaken", "oasis", "oatmeal", "ocean", "ochre", "octagon", "olive", "onyx",
|
||||
"opal", "orange", "orbit", "orchard", "orchid", "oregano", "orion", "osprey",
|
||||
"otter", "outpost", "owlet", "oyster",
|
||||
"painter", "palace", "palm", "pansy", "panther", "papaya", "paprika", "parsley",
|
||||
"partridge", "passage", "pastel", "patio", "peach", "peacock", "pear", "pearl",
|
||||
"pebble", "pecan", "pelican", "penguin", "peony", "pepper", "perch", "peridot",
|
||||
"pewter", "phoenix", "pier", "pillar", "pine", "pineapple", "pinto", "piper",
|
||||
"pistachio", "plain", "planet", "plateau", "platinum", "plum", "plume", "polar",
|
||||
"pollen", "pond", "poplar", "poppy", "porcelain", "portal", "portrait", "potato",
|
||||
"prairie", "primrose", "prism", "puffin", "pumpkin",
|
||||
"quail", "quartz", "quaver", "quill", "quince", "quinoa",
|
||||
"rabbit", "raccoon", "radish", "rain", "rainbow", "raindrop", "rapids", "raspberry",
|
||||
"raven", "ravine", "redwood", "reed", "reef", "ridge", "river", "robin",
|
||||
"rocket", "rubyred", "rose", "rosemary", "rosewood", "ruffle", "rugby", "russet",
|
||||
"rustic", "ryefield",
|
||||
"saffron", "sage", "salmon", "sand", "sandstone", "sapphire", "savanna", "scarlet",
|
||||
"scout", "seal", "season", "seaweed", "sequoia", "shadow", "shamrock", "shell",
|
||||
"sherbet", "shore", "silver", "siskin", "skybloom", "skyline", "sleet", "smoke",
|
||||
"snail", "snapdragon", "snow", "snowflake", "snowy", "solar", "song", "sonic",
|
||||
"sorrel", "south", "sparkle", "sparrow", "spice", "spider", "spinach", "spire",
|
||||
"spring", "sprout", "spruce", "squirrel", "starfish", "starlight", "stoat", "stone",
|
||||
"stork", "storm", "stream", "studio", "summer", "sunbeam", "sundew", "sunny",
|
||||
"sunrise", "sunset", "swallow", "swan", "sweet", "sycamore",
|
||||
"tangelo", "tangerine", "tansy", "taupe", "teak", "teal", "thicket", "thistle",
|
||||
"thrush", "thunder", "tide", "tiger", "tinder", "topaz", "torch", "tortoise",
|
||||
"tower", "trail", "tranquil", "tundra", "tulip", "turquoise", "turtle", "twig",
|
||||
"twilight",
|
||||
"umber", "uplands",
|
||||
"valley", "vanilla", "velvet", "venus", "verdant", "verdigris", "vermilion", "violet",
|
||||
"vista", "vivid", "volcano", "vortex",
|
||||
"walnut", "warbler", "watercress", "waterfall", "wave", "waxwing", "weasel", "westwind",
|
||||
"whale", "whisker", "whisper", "wicker", "wildwood", "willow", "winter", "wisp",
|
||||
"wisteria", "wolf", "wombat", "woodland", "woolly", "wren", "wreath",
|
||||
"yarrow", "yellow", "yewtree", "yodel",
|
||||
"zebra", "zenith", "zephyr", "zinnia",
|
||||
"alabaster", "alfalfa", "almanac", "anise", "antelope", "arbor", "arena", "armadillo",
|
||||
"avocet", "azalea", "balsam", "bayou", "beacon", "blizzard", "bluebell", "bluebird",
|
||||
"bluejay", "bobolink", "borage", "boreal", "buckeye", "buckthorn", "buttercup",
|
||||
"cabana", "calico", "canopy", "caraway", "cardamom", "cattail", "celadon", "centaur",
|
||||
"chambray", "chamois", "champlain", "chestnuts", "chickadee", "chinook", "chipmunk", "cinnabar",
|
||||
"cirrus", "citrine", "clematis", "copperhead",
|
||||
"crocodile", "currant", "cuttlebone", "daffy", "dapple", "delphinium", "dervish", "diamondback",
|
||||
"dogwood", "dolphins", "dragonfly", "driftwood", "dusk", "dustpan", "ebony", "edelweiss",
|
||||
"emperor", "endive", "estuary", "everglade", "fairway", "feldspar", "fennel", "fieldstone",
|
||||
"firebrand", "firefly", "fireweed", "firework", "flagstone", "fossil", "frostbite", "galleon",
|
||||
"gardener", "geranium", "gingko", "ginseng", "goldfish", "goldfinch", "goldenrod", "graphite",
|
||||
"greenfinch", "guppy", "haiku", "halibut", "hammerhead", "harbinger", "harvest", "hatchling",
|
||||
"havana", "hawthorn", "hazelnut", "heartwood", "henna", "heron", "highrise", "homestead",
|
||||
"honeycomb", "honeydew", "horseshoe", "hyacinth", "iceland", "icicle", "indigobird", "ironwood",
|
||||
"jacaranda", "jamboree", "javelina", "jellyfish", "junebug", "kaleido", "kayaker", "kerchief",
|
||||
"keystone", "kingdom", "labrador", "lacewing", "ladybug", "lakeside", "lamplight", "leopard",
|
||||
"lighthouse", "lilypad", "lullaby", "magnet", "mahonia", "mandolin", "manzanita", "maraschino",
|
||||
"mariner", "marsupial", "mastodon", "matterhorn", "mayflower", "mayfly", "meadowlark", "merlot",
|
||||
"meteor", "midshipman", "millpond", "mimosa", "minnow", "mockingbird", "molten", "monarch",
|
||||
"monsoon", "moondust", "moonlight", "moorland", "morning", "mossland", "mountain", "mulch",
|
||||
"narcissus", "nautilus", "nettlebush", "northstar", "nuthatch", "obsidian", "okra", "olivine",
|
||||
"opalescent", "orchidea", "orchard", "ornament", "outrigger", "oxalis", "paddler", "paintbrush",
|
||||
"papyrus", "paradise", "pasture", "patchwork", "pathway", "peridot", "periwinkle", "petalbloom",
|
||||
"petrel", "petunia", "phlox", "pikeperch", "pinecone", "pioneer", "pipevine", "platypus",
|
||||
"pomelo", "pondweed", "porpoise", "powder", "promise", "puddle", "pumice", "puzzle",
|
||||
"quetzal", "quicksilver", "raccoon", "ragwort", "rainforest", "ramble", "rapid", "rascal",
|
||||
"raspberry", "redbud", "redfern", "redpoll", "reedling", "ringtail", "riverbed", "riverbird",
|
||||
"riverstone", "rockcress", "roebuck", "rosebay", "rosehip", "rosemary", "rowan", "rumble",
|
||||
"runaway", "rustler", "sagebrush", "sailcloth", "salamander", "salsify", "samphire", "sandbar",
|
||||
"sanddollar", "sandpiper", "santolina", "sapodilla", "sassafras", "scallion", "schooner", "seafoam",
|
||||
"seafrost", "seagrass", "seahorse", "seaport", "seashell", "seaspray", "shamble", "shimmer",
|
||||
"shoreline", "silkmoth", "silverfox", "skylark", "snapdragon", "snowberry", "snowdrop", "snowfall",
|
||||
"snowmelt", "softwood", "songbird", "sorghum", "southwind", "speedwell", "spinnaker", "spruce",
|
||||
"starlight", "starling", "stormcloud", "summit", "sundance", "sundew", "sundial", "sunflower",
|
||||
"surface", "swallowtail", "sweetcorn", "sycamore", "tabletop", "tamarack", "tamarind", "tangerine",
|
||||
"tarragon", "telescope", "thicket", "thrasher", "thunder", "thyme", "tideline", "timberland",
|
||||
"tinderbox", "topiary", "torchwood", "totem", "tradewind", "treasure", "tremolo", "trinket",
|
||||
"trumpetvine", "tugboat", "tundra", "turnstone", "underbrush", "vagabond", "valerian", "vanilla",
|
||||
"velveteen", "vermilion", "vinca", "vineyard", "violet", "voyager", "wagonwheel", "walnutwood",
|
||||
"watermark", "watershed", "waterway", "wavefront", "westerly", "whaleback", "whetstone", "wicker",
|
||||
"wildbloom", "wildflower", "wilderness", "windsong", "windward", "winterberry", "woodbine", "woodfern",
|
||||
"woodland", "woodthrush", "woolgrass", "yellowfin", "zenithal", "zucchini",
|
||||
}
|
||||
896
management/internals/modules/agentnetwork/manager.go
Normal file
896
management/internals/modules/agentnetwork/manager.go
Normal file
@@ -0,0 +1,896 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/labelgen"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// ensureSessionKeys mints an ed25519 session keypair on the provider
|
||||
// when one is missing. Idempotent: skips when both fields are already
|
||||
// populated (e.g. update or migrated rows). The keys are used by the
|
||||
// synthesised reverse-proxy service to sign / verify session JWTs
|
||||
// after a successful OIDC handshake.
|
||||
func ensureSessionKeys(p *types.Provider) error {
|
||||
if p.SessionPrivateKey != "" && p.SessionPublicKey != "" {
|
||||
return nil
|
||||
}
|
||||
pair, err := sessionkey.GenerateKeyPair()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate provider session keys: %w", err)
|
||||
}
|
||||
p.SessionPrivateKey = pair.PrivateKey
|
||||
p.SessionPublicKey = pair.PublicKey
|
||||
return nil
|
||||
}
|
||||
|
||||
// Manager governs the lifecycle of Agent Network providers and policies.
|
||||
type Manager interface {
|
||||
GetAllProviders(ctx context.Context, accountID, userID string) ([]*types.Provider, error)
|
||||
GetProvider(ctx context.Context, accountID, userID, providerID string) (*types.Provider, error)
|
||||
CreateProvider(ctx context.Context, userID string, provider *types.Provider, bootstrapCluster string) (*types.Provider, error)
|
||||
UpdateProvider(ctx context.Context, userID string, provider *types.Provider) (*types.Provider, error)
|
||||
DeleteProvider(ctx context.Context, accountID, userID, providerID string) error
|
||||
|
||||
GetAllPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||
GetPolicy(ctx context.Context, accountID, userID, policyID string) (*types.Policy, error)
|
||||
CreatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error)
|
||||
UpdatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error)
|
||||
DeletePolicy(ctx context.Context, accountID, userID, policyID string) error
|
||||
|
||||
GetAllGuardrails(ctx context.Context, accountID, userID string) ([]*types.Guardrail, error)
|
||||
GetGuardrail(ctx context.Context, accountID, userID, guardrailID string) (*types.Guardrail, error)
|
||||
CreateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error)
|
||||
UpdateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error)
|
||||
DeleteGuardrail(ctx context.Context, accountID, userID, guardrailID string) error
|
||||
|
||||
GetAllBudgetRules(ctx context.Context, accountID, userID string) ([]*types.AccountBudgetRule, error)
|
||||
GetBudgetRule(ctx context.Context, accountID, userID, ruleID string) (*types.AccountBudgetRule, error)
|
||||
CreateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error)
|
||||
UpdateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error)
|
||||
DeleteBudgetRule(ctx context.Context, accountID, userID, ruleID string) error
|
||||
|
||||
GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error)
|
||||
UpdateSettings(ctx context.Context, userID string, settings *types.Settings) (*types.Settings, error)
|
||||
|
||||
ListConsumption(ctx context.Context, accountID, userID string) ([]*types.Consumption, error)
|
||||
ListAccessLogs(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error)
|
||||
GetUsageOverview(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter, granularity types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error)
|
||||
StartAccessLogCleanup(ctx context.Context, cleanupIntervalHours int)
|
||||
RecordConsumption(ctx context.Context, accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds, tokensIn, tokensOut int64, costUSD float64) error
|
||||
RecordAccountBudgetUsage(ctx context.Context, accountID, userID string, groupIDs []string, tokensIn, tokensOut int64, costUSD float64) error
|
||||
RecordUsage(ctx context.Context, in RecordUsageInput) error
|
||||
SelectPolicyForRequest(ctx context.Context, in PolicySelectionInput) (*PolicySelectionResult, error)
|
||||
}
|
||||
|
||||
// PolicySelectionInput is the per-request selection envelope. The
|
||||
// proxy populates it from CapturedData (account, user, groups) plus
|
||||
// the provider llm_router resolved.
|
||||
type PolicySelectionInput struct {
|
||||
AccountID string
|
||||
UserID string
|
||||
GroupIDs []string
|
||||
ProviderID string
|
||||
}
|
||||
|
||||
// PolicySelectionResult names the policy that "pays" for this request
|
||||
// plus the deny envelope when every applicable policy has exhausted
|
||||
// every cap. AttributionGroupID is the lowest group id (string sort)
|
||||
// of caller_groups ∩ selected_policy.source_groups; empty when no
|
||||
// group dimension applies. WindowSeconds is the chosen policy's
|
||||
// effective window length in seconds (token_limit's wins when both
|
||||
// halves are enabled with mismatched windows; budget_limit's
|
||||
// otherwise; 0 when no caps are configured at all).
|
||||
type PolicySelectionResult struct {
|
||||
Allow bool
|
||||
SelectedPolicyID string
|
||||
AttributionGroupID string
|
||||
WindowSeconds int64
|
||||
DenyCode string
|
||||
DenyReason string
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
proxyController proxy.Controller
|
||||
|
||||
// reconcileCache holds the last set of synthesised proxy mappings
|
||||
// per account so reconcile can emit precise Create/Update/Delete
|
||||
// updates instead of a full re-push on every mutation. Keyed by
|
||||
// accountID, then by synthesised service ID.
|
||||
reconcileMu sync.Mutex
|
||||
reconcileCache map[string]map[string]*proto.ProxyMapping
|
||||
|
||||
// labelRngMu guards labelRng. PickUnique consumes math/rand.Source
|
||||
// state; concurrent provider creates would otherwise race.
|
||||
labelRngMu sync.Mutex
|
||||
labelRng *rand.Rand
|
||||
}
|
||||
|
||||
// NewManager constructs the persistent Agent Network manager. The
|
||||
// manager persists provider/policy/guardrail configuration and, on
|
||||
// every mutation, reconciles the in-memory synthesised reverse-proxy
|
||||
// services with the proxy cluster via proxyController. Pass nil for
|
||||
// proxyController to disable the reconcile push (useful in tests).
|
||||
func NewManager(
|
||||
store store.Store,
|
||||
permissionsManager permissions.Manager,
|
||||
accountManager account.Manager,
|
||||
proxyController proxy.Controller,
|
||||
) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
proxyController: proxyController,
|
||||
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
|
||||
labelRng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllProviders(ctx context.Context, accountID, userID string) ([]*types.Provider, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetProvider(ctx context.Context, accountID, userID, providerID string) (*types.Provider, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthNone, accountID, providerID)
|
||||
}
|
||||
|
||||
// CreateProvider persists a new provider for the account. bootstrapCluster
|
||||
// is used only when the per-account agent-network Settings row hasn't
|
||||
// been created yet; otherwise it is ignored (the cluster is pinned on
|
||||
// Settings and every provider in the account routes through it).
|
||||
func (m *managerImpl) CreateProvider(ctx context.Context, userID string, provider *types.Provider, bootstrapCluster string) (*types.Provider, error) {
|
||||
if err := m.requirePermission(ctx, provider.AccountID, userID, operations.Create); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// An empty api_key would silently produce a synthesised service
|
||||
// that 401s on every upstream request. Surface the misconfiguration
|
||||
// at create time instead.
|
||||
if strings.TrimSpace(provider.APIKey) == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "api_key is required when creating an agent network provider")
|
||||
}
|
||||
|
||||
if provider.ID == "" {
|
||||
fresh := types.NewProvider(provider.AccountID)
|
||||
provider.ID = fresh.ID
|
||||
provider.CreatedAt = fresh.CreatedAt
|
||||
provider.UpdatedAt = fresh.UpdatedAt
|
||||
}
|
||||
|
||||
if err := ensureSessionKeys(provider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.store.SaveAgentNetworkProvider(ctx, provider); err != nil {
|
||||
return nil, fmt.Errorf("save agent network provider: %w", err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(bootstrapCluster) != "" {
|
||||
if _, err := m.bootstrapSettingsIfNeeded(ctx, provider.AccountID, bootstrapCluster); err != nil {
|
||||
// The provider create has already succeeded; logging the
|
||||
// bootstrap miss matches the plan's PoC behaviour. The synth
|
||||
// path treats a missing settings row as a no-op, and the next
|
||||
// provider create retries the bootstrap.
|
||||
log.WithContext(ctx).Debugf("agent-network bootstrap settings for account %s on cluster %s: %v", provider.AccountID, bootstrapCluster, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, provider.ID, provider.AccountID, activity.AgentNetworkProviderCreated, provider.EventMeta())
|
||||
m.reconcile(ctx, provider.AccountID)
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateProvider(ctx context.Context, userID string, provider *types.Provider) (*types.Provider, error) {
|
||||
if err := m.requirePermission(ctx, provider.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthUpdate, provider.AccountID, provider.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get agent network provider: %w", err)
|
||||
}
|
||||
|
||||
// Preserve the API key if the caller didn't rotate it. A
|
||||
// whitespace-only value is treated as "not rotated" rather than a
|
||||
// real key, but it must not silently overwrite a valid stored key.
|
||||
if provider.APIKey == "" {
|
||||
provider.APIKey = existing.APIKey
|
||||
} else if strings.TrimSpace(provider.APIKey) == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "api_key must be non-blank when rotating an agent network provider")
|
||||
}
|
||||
// Always preserve the session keypair across updates so existing
|
||||
// session cookies stay valid. The keys are server-managed and
|
||||
// never surfaced through the API.
|
||||
provider.SessionPrivateKey = existing.SessionPrivateKey
|
||||
provider.SessionPublicKey = existing.SessionPublicKey
|
||||
if err := ensureSessionKeys(provider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
provider.CreatedAt = existing.CreatedAt
|
||||
provider.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkProvider(ctx, provider); err != nil {
|
||||
return nil, fmt.Errorf("save agent network provider: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, provider.ID, provider.AccountID, activity.AgentNetworkProviderUpdated, provider.EventMeta())
|
||||
m.reconcile(ctx, provider.AccountID)
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteProvider(ctx context.Context, accountID, userID, providerID string) error {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
provider, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthUpdate, accountID, providerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent network provider: %w", err)
|
||||
}
|
||||
|
||||
// Refuse to delete while any policy still references this provider.
|
||||
// The operator must detach it first.
|
||||
policies, err := m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent network policies: %w", err)
|
||||
}
|
||||
var blocking []string
|
||||
for _, p := range policies {
|
||||
if slices.Contains(p.DestinationProviderIDs, providerID) {
|
||||
blocking = append(blocking, p.Name)
|
||||
}
|
||||
}
|
||||
if len(blocking) > 0 {
|
||||
return status.Errorf(
|
||||
status.InvalidArgument,
|
||||
"provider is in use by %d %s (%s); detach it before deleting",
|
||||
len(blocking),
|
||||
pluralize(len(blocking), "policy", "policies"),
|
||||
strings.Join(blocking, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
if err := m.store.DeleteAgentNetworkProvider(ctx, accountID, providerID); err != nil {
|
||||
return fmt.Errorf("failed to delete agent network provider: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, providerID, accountID, activity.AgentNetworkProviderDeleted, provider.EventMeta())
|
||||
m.reconcile(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func pluralize(n int, singular, plural string) string {
|
||||
if n == 1 {
|
||||
return singular
|
||||
}
|
||||
return plural
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPolicy(ctx context.Context, accountID, userID, policyID string) (*types.Policy, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error) {
|
||||
if err := m.requirePermission(ctx, policy.AccountID, userID, operations.Create); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if policy.ID == "" {
|
||||
fresh := types.NewPolicy(policy.AccountID)
|
||||
policy.ID = fresh.ID
|
||||
policy.CreatedAt = fresh.CreatedAt
|
||||
policy.UpdatedAt = fresh.UpdatedAt
|
||||
}
|
||||
|
||||
if err := m.validateProviderRefs(ctx, policy.AccountID, policy.DestinationProviderIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.store.SaveAgentNetworkPolicy(ctx, policy); err != nil {
|
||||
return nil, fmt.Errorf("failed to save agent network policy: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, policy.ID, policy.AccountID, activity.AgentNetworkPolicyCreated, policy.EventMeta())
|
||||
m.reconcile(ctx, policy.AccountID)
|
||||
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error) {
|
||||
if err := m.requirePermission(ctx, policy.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthUpdate, policy.AccountID, policy.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get agent network policy: %w", err)
|
||||
}
|
||||
|
||||
if err := m.validateProviderRefs(ctx, policy.AccountID, policy.DestinationProviderIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
policy.CreatedAt = existing.CreatedAt
|
||||
policy.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkPolicy(ctx, policy); err != nil {
|
||||
return nil, fmt.Errorf("failed to save agent network policy: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, policy.ID, policy.AccountID, activity.AgentNetworkPolicyUpdated, policy.EventMeta())
|
||||
m.reconcile(ctx, policy.AccountID)
|
||||
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeletePolicy(ctx context.Context, accountID, userID, policyID string) error {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
policy, err := m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent network policy: %w", err)
|
||||
}
|
||||
|
||||
if err := m.store.DeleteAgentNetworkPolicy(ctx, accountID, policyID); err != nil {
|
||||
return fmt.Errorf("failed to delete agent network policy: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, policyID, accountID, activity.AgentNetworkPolicyDeleted, policy.EventMeta())
|
||||
m.reconcile(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllGuardrails(ctx context.Context, accountID, userID string) ([]*types.Guardrail, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetGuardrail(ctx context.Context, accountID, userID, guardrailID string) (*types.Guardrail, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthNone, accountID, guardrailID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error) {
|
||||
if err := m.requirePermission(ctx, guardrail.AccountID, userID, operations.Create); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if guardrail.ID == "" {
|
||||
fresh := types.NewGuardrail(guardrail.AccountID)
|
||||
guardrail.ID = fresh.ID
|
||||
guardrail.CreatedAt = fresh.CreatedAt
|
||||
guardrail.UpdatedAt = fresh.UpdatedAt
|
||||
}
|
||||
|
||||
if err := m.store.SaveAgentNetworkGuardrail(ctx, guardrail); err != nil {
|
||||
return nil, fmt.Errorf("failed to save agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, guardrail.ID, guardrail.AccountID, activity.AgentNetworkGuardrailCreated, guardrail.EventMeta())
|
||||
m.reconcile(ctx, guardrail.AccountID)
|
||||
|
||||
return guardrail, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error) {
|
||||
if err := m.requirePermission(ctx, guardrail.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthUpdate, guardrail.AccountID, guardrail.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
guardrail.CreatedAt = existing.CreatedAt
|
||||
guardrail.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkGuardrail(ctx, guardrail); err != nil {
|
||||
return nil, fmt.Errorf("failed to save agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, guardrail.ID, guardrail.AccountID, activity.AgentNetworkGuardrailUpdated, guardrail.EventMeta())
|
||||
m.reconcile(ctx, guardrail.AccountID)
|
||||
|
||||
return guardrail, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteGuardrail(ctx context.Context, accountID, userID, guardrailID string) error {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
guardrail, err := m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthUpdate, accountID, guardrailID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
if err := m.store.DeleteAgentNetworkGuardrail(ctx, accountID, guardrailID); err != nil {
|
||||
return fmt.Errorf("failed to delete agent network guardrail: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, guardrailID, accountID, activity.AgentNetworkGuardrailDeleted, guardrail.EventMeta())
|
||||
m.reconcile(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllBudgetRules returns every account-level budget rule for the account.
|
||||
func (m *managerImpl) GetAllBudgetRules(ctx context.Context, accountID, userID string) ([]*types.AccountBudgetRule, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// GetBudgetRule returns a single account-level budget rule.
|
||||
func (m *managerImpl) GetBudgetRule(ctx context.Context, accountID, userID, ruleID string) (*types.AccountBudgetRule, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthNone, accountID, ruleID)
|
||||
}
|
||||
|
||||
// CreateBudgetRule persists a new account-level budget rule. Budget rules are
|
||||
// enforced at request time (CheckLLMPolicyLimits), not baked into the synth
|
||||
// proxy config, so no reconcile is needed.
|
||||
func (m *managerImpl) CreateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
|
||||
if err := m.requirePermission(ctx, rule.AccountID, userID, operations.Create); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rule.ID == "" {
|
||||
fresh := types.NewAccountBudgetRule(rule.AccountID)
|
||||
rule.ID = fresh.ID
|
||||
rule.CreatedAt = fresh.CreatedAt
|
||||
rule.UpdatedAt = fresh.UpdatedAt
|
||||
}
|
||||
|
||||
if err := m.store.SaveAgentNetworkBudgetRule(ctx, rule); err != nil {
|
||||
return nil, fmt.Errorf("save agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, rule.ID, rule.AccountID, activity.AgentNetworkBudgetRuleCreated, rule.EventMeta())
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// UpdateBudgetRule updates an existing account-level budget rule.
|
||||
func (m *managerImpl) UpdateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
|
||||
if err := m.requirePermission(ctx, rule.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthUpdate, rule.AccountID, rule.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
rule.CreatedAt = existing.CreatedAt
|
||||
rule.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkBudgetRule(ctx, rule); err != nil {
|
||||
return nil, fmt.Errorf("save agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, rule.ID, rule.AccountID, activity.AgentNetworkBudgetRuleUpdated, rule.EventMeta())
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// DeleteBudgetRule removes an account-level budget rule.
|
||||
func (m *managerImpl) DeleteBudgetRule(ctx context.Context, accountID, userID, ruleID string) error {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rule, err := m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthUpdate, accountID, ruleID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.store.DeleteAgentNetworkBudgetRule(ctx, accountID, ruleID); err != nil {
|
||||
return fmt.Errorf("delete agent network budget rule: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, ruleID, accountID, activity.AgentNetworkBudgetRuleDeleted, rule.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateSettings applies the mutable account-level settings — the collection
|
||||
// toggles — onto the existing row. Cluster and Subdomain are immutable and are
|
||||
// preserved from the persisted row regardless of the input. Because the
|
||||
// collection toggles change the synthesised service config (prompt-capture
|
||||
// gating, access-log emission), a reconcile is triggered so the proxy and peer
|
||||
// network maps converge on the new state.
|
||||
func (m *managerImpl) UpdateSettings(ctx context.Context, userID string, settings *types.Settings) (*types.Settings, error) {
|
||||
if err := m.requirePermission(ctx, settings.AccountID, userID, operations.Update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthUpdate, settings.AccountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get agent network settings: %w", err)
|
||||
}
|
||||
|
||||
existing.EnableLogCollection = settings.EnableLogCollection
|
||||
existing.EnablePromptCollection = settings.EnablePromptCollection
|
||||
existing.RedactPii = settings.RedactPii
|
||||
existing.AccessLogRetentionDays = settings.AccessLogRetentionDays
|
||||
existing.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err := m.store.SaveAgentNetworkSettings(ctx, existing); err != nil {
|
||||
return nil, fmt.Errorf("save agent network settings: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, settings.AccountID, settings.AccountID, activity.AgentNetworkSettingsUpdated, map[string]any{
|
||||
"log_collection": existing.EnableLogCollection,
|
||||
"prompt_collection": existing.EnablePromptCollection,
|
||||
"redact_pii": existing.RedactPii,
|
||||
})
|
||||
m.reconcile(ctx, settings.AccountID)
|
||||
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
// validateProviderRefs ensures every destination provider id refers to a
|
||||
// provider that exists in the same account.
|
||||
func (m *managerImpl) validateProviderRefs(ctx context.Context, accountID string, providerIDs []string) error {
|
||||
if len(providerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, id := range providerIDs {
|
||||
if _, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthNone, accountID, id); err != nil {
|
||||
// Only a genuine not-found means the reference is invalid; a
|
||||
// store/runtime error must propagate as-is rather than be
|
||||
// masked as a client validation error.
|
||||
var sErr *status.Error
|
||||
if errors.As(err, &sErr) && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "destination_provider_ids: provider %s does not exist", id)
|
||||
}
|
||||
return fmt.Errorf("get destination provider %s: %w", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSettings returns the agent-network settings row for the account.
|
||||
// Returns the underlying status.NotFound when no row has been
|
||||
// bootstrapped yet (i.e. the account has no providers).
|
||||
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// bootstrapSettingsIfNeeded creates the per-account agent-network
|
||||
// settings row when missing. The cluster comes from the create-time
|
||||
// hint the dashboard sends (auto-picked from the active cluster list);
|
||||
// the subdomain is picked from the curated wordlist avoiding
|
||||
// collisions on the same cluster. Idempotent: if a row already exists
|
||||
// it is returned untouched and the hint is ignored.
|
||||
func (m *managerImpl) bootstrapSettingsIfNeeded(ctx context.Context, accountID, providerCluster string) (*types.Settings, error) {
|
||||
if accountID == "" {
|
||||
return nil, fmt.Errorf("bootstrap settings: account id is required")
|
||||
}
|
||||
if strings.TrimSpace(providerCluster) == "" {
|
||||
return nil, fmt.Errorf("bootstrap settings: provider cluster is required")
|
||||
}
|
||||
|
||||
existing, err := m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err == nil {
|
||||
return existing, nil
|
||||
}
|
||||
var sErr *status.Error
|
||||
if !errors.As(err, &sErr) || sErr.Type() != status.NotFound {
|
||||
return nil, fmt.Errorf("get agent network settings: %w", err)
|
||||
}
|
||||
|
||||
siblings, err := m.store.GetAgentNetworkSettingsByCluster(ctx, store.LockingStrengthNone, providerCluster)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list agent network settings on cluster: %w", err)
|
||||
}
|
||||
taken := make(map[string]struct{}, len(siblings))
|
||||
for _, s := range siblings {
|
||||
taken[s.Subdomain] = struct{}{}
|
||||
}
|
||||
|
||||
suffix := accountID
|
||||
if len(suffix) > 4 {
|
||||
suffix = suffix[:4]
|
||||
}
|
||||
|
||||
m.labelRngMu.Lock()
|
||||
subdomain := labelgen.PickUnique(m.labelRng, taken, suffix)
|
||||
m.labelRngMu.Unlock()
|
||||
|
||||
now := time.Now().UTC()
|
||||
settings := &types.Settings{
|
||||
AccountID: accountID,
|
||||
Cluster: providerCluster,
|
||||
Subdomain: subdomain,
|
||||
// Logs on by default; usage is collected regardless. Retention bounds
|
||||
// how long full log rows are kept.
|
||||
EnableLogCollection: true,
|
||||
AccessLogRetentionDays: types.DefaultAccessLogRetentionDays,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := m.store.SaveAgentNetworkSettings(ctx, settings); err != nil {
|
||||
return nil, fmt.Errorf("save agent network settings: %w", err)
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// ListConsumption returns every consumption row recorded for the
|
||||
// account, ordered window-newest-first. Backs the dashboard's basic
|
||||
// counter view; permission gate is the same Read role that gates
|
||||
// every other agent-network surface.
|
||||
func (m *managerImpl) ListConsumption(ctx context.Context, accountID, userID string) ([]*types.Consumption, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.store.ListAgentNetworkConsumption(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// ListAccessLogs returns a paginated, server-side-filtered page of
|
||||
// agent-network access logs plus the total count matching the filter.
|
||||
func (m *managerImpl) ListAccessLogs(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return m.store.GetAgentNetworkAccessLogs(ctx, store.LockingStrengthNone, accountID, filter)
|
||||
}
|
||||
|
||||
// GetUsageOverview returns the filtered usage rows aggregated into time buckets
|
||||
// at the requested granularity, oldest-first.
|
||||
func (m *managerImpl) GetUsageOverview(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter, granularity types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error) {
|
||||
if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := m.store.GetAgentNetworkUsageRows(ctx, store.LockingStrengthNone, accountID, filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return types.AggregateUsageByGranularity(rows, granularity), nil
|
||||
}
|
||||
|
||||
// StartAccessLogCleanup launches a background sweep that periodically deletes
|
||||
// each account's agent-network access-log rows older than that account's
|
||||
// AccessLogRetentionDays. Usage records are never swept. A non-positive
|
||||
// interval defaults to 24h.
|
||||
func (m *managerImpl) StartAccessLogCleanup(ctx context.Context, cleanupIntervalHours int) {
|
||||
if cleanupIntervalHours <= 0 {
|
||||
cleanupIntervalHours = 24
|
||||
}
|
||||
interval := time.Duration(cleanupIntervalHours) * time.Hour
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
m.cleanupAccessLogsOnce(ctx) // run once on startup
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.cleanupAccessLogsOnce(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// cleanupAccessLogsOnce sweeps every account's expired access-log rows against
|
||||
// its configured retention. Best-effort: a per-account failure is logged and
|
||||
// the sweep continues.
|
||||
func (m *managerImpl) cleanupAccessLogsOnce(ctx context.Context) {
|
||||
settings, err := m.store.GetAllAgentNetworkSettings(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("agent-network access-log cleanup: list settings: %v", err)
|
||||
return
|
||||
}
|
||||
for _, s := range settings {
|
||||
if s.AccessLogRetentionDays <= 0 {
|
||||
continue // keep indefinitely
|
||||
}
|
||||
cutoff := time.Now().UTC().AddDate(0, 0, -s.AccessLogRetentionDays)
|
||||
deleted, err := m.store.DeleteOldAgentNetworkAccessLogs(ctx, s.AccountID, cutoff)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("agent-network access-log cleanup for account %s: %v", s.AccountID, err)
|
||||
continue
|
||||
}
|
||||
if deleted > 0 {
|
||||
log.WithContext(ctx).Infof("agent-network access-log cleanup: deleted %d rows for account %s (retention %d days)", deleted, s.AccountID, s.AccessLogRetentionDays)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecordConsumption increments the (dim, window) counter by the
|
||||
// supplied deltas. The window_start is computed from time.Now under
|
||||
// the supplied window_seconds so callers don't have to pre-align —
|
||||
// the proxy's post-flight path simply hands us tokens + cost and
|
||||
// which dimension we're booking against.
|
||||
func (m *managerImpl) RecordConsumption(ctx context.Context, accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds, tokensIn, tokensOut int64, costUSD float64) error {
|
||||
if accountID == "" || dimID == "" || windowSeconds <= 0 {
|
||||
return status.Errorf(status.InvalidArgument, "account_id, dim_id and window_seconds must be set")
|
||||
}
|
||||
windowStart := types.WindowStart(time.Now(), windowSeconds)
|
||||
return m.store.IncrementAgentNetworkConsumption(ctx, accountID, kind, dimID, windowSeconds, windowStart, tokensIn, tokensOut, costUSD)
|
||||
}
|
||||
|
||||
func (m *managerImpl) requirePermission(ctx context.Context, accountID, userID string, op operations.Operation) error {
|
||||
ok, _, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.AgentNetwork, op)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockManager struct{}
|
||||
|
||||
// NewManagerMock returns a no-op manager useful for tests.
|
||||
func NewManagerMock() Manager {
|
||||
return &mockManager{}
|
||||
}
|
||||
|
||||
func (*mockManager) GetAllProviders(_ context.Context, _, _ string) ([]*types.Provider, error) {
|
||||
return []*types.Provider{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetProvider(_ context.Context, _, _, _ string) (*types.Provider, error) {
|
||||
return &types.Provider{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) CreateProvider(_ context.Context, _ string, p *types.Provider, _ string) (*types.Provider, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (*mockManager) UpdateProvider(_ context.Context, _ string, p *types.Provider) (*types.Provider, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (*mockManager) DeleteProvider(_ context.Context, _, _, _ string) error { return nil }
|
||||
|
||||
func (*mockManager) GetAllPolicies(_ context.Context, _, _ string) ([]*types.Policy, error) {
|
||||
return []*types.Policy{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetPolicy(_ context.Context, _, _, _ string) (*types.Policy, error) {
|
||||
return &types.Policy{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) CreatePolicy(_ context.Context, _ string, p *types.Policy) (*types.Policy, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (*mockManager) UpdatePolicy(_ context.Context, _ string, p *types.Policy) (*types.Policy, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (*mockManager) DeletePolicy(_ context.Context, _, _, _ string) error { return nil }
|
||||
|
||||
func (*mockManager) GetAllGuardrails(_ context.Context, _, _ string) ([]*types.Guardrail, error) {
|
||||
return []*types.Guardrail{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetGuardrail(_ context.Context, _, _, _ string) (*types.Guardrail, error) {
|
||||
return &types.Guardrail{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) CreateGuardrail(_ context.Context, _ string, g *types.Guardrail) (*types.Guardrail, error) {
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (*mockManager) UpdateGuardrail(_ context.Context, _ string, g *types.Guardrail) (*types.Guardrail, error) {
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (*mockManager) DeleteGuardrail(_ context.Context, _, _, _ string) error { return nil }
|
||||
|
||||
func (*mockManager) GetAllBudgetRules(_ context.Context, _, _ string) ([]*types.AccountBudgetRule, error) {
|
||||
return []*types.AccountBudgetRule{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetBudgetRule(_ context.Context, _, _, _ string) (*types.AccountBudgetRule, error) {
|
||||
return &types.AccountBudgetRule{}, nil
|
||||
}
|
||||
|
||||
func (*mockManager) CreateBudgetRule(_ context.Context, _ string, r *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (*mockManager) UpdateBudgetRule(_ context.Context, _ string, r *types.AccountBudgetRule) (*types.AccountBudgetRule, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (*mockManager) DeleteBudgetRule(_ context.Context, _, _, _ string) error { return nil }
|
||||
|
||||
func (*mockManager) GetSettings(_ context.Context, _, _ string) (*types.Settings, error) {
|
||||
return nil, status.Errorf(status.NotFound, "agent network settings not found")
|
||||
}
|
||||
|
||||
func (*mockManager) UpdateSettings(_ context.Context, _ string, s *types.Settings) (*types.Settings, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (*mockManager) ListConsumption(_ context.Context, _, _ string) ([]*types.Consumption, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*mockManager) ListAccessLogs(_ context.Context, _, _ string, _ types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (*mockManager) GetUsageOverview(_ context.Context, _, _ string, _ types.AgentNetworkAccessLogFilter, _ types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*mockManager) StartAccessLogCleanup(_ context.Context, _ int) {}
|
||||
|
||||
func (*mockManager) RecordConsumption(_ context.Context, _ string, _ types.ConsumptionDimension, _ string, _, _, _ int64, _ float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*mockManager) RecordAccountBudgetUsage(_ context.Context, _, _ string, _ []string, _, _ int64, _ float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*mockManager) RecordUsage(_ context.Context, _ RecordUsageInput) error {
|
||||
return nil
|
||||
}
|
||||
660
management/internals/modules/agentnetwork/policyselect.go
Normal file
660
management/internals/modules/agentnetwork/policyselect.go
Normal file
@@ -0,0 +1,660 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// validateUsageDeltas rejects negative or non-finite usage counters before they
|
||||
// reach the consumption store, so a bad delta can't decrement or poison totals.
|
||||
// The store batch method enforces the same invariant; this is the manager-level
|
||||
// guard so direct callers fail fast with a clear error.
|
||||
func validateUsageDeltas(tokensIn, tokensOut int64, costUSD float64) error {
|
||||
if tokensIn < 0 || tokensOut < 0 || costUSD < 0 || math.IsNaN(costUSD) || math.IsInf(costUSD, 0) {
|
||||
return status.Errorf(status.InvalidArgument, "usage deltas must be non-negative and finite")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deny codes the proxy surfaces back to the caller when every
|
||||
// applicable policy is exhausted. The proxy converts these into
|
||||
// upstream-shaped error responses.
|
||||
const (
|
||||
//nolint:gosec // policy deny code label, not a credential
|
||||
denyCodeTokenCapExceeded = "llm_policy.token_cap_exceeded"
|
||||
//nolint:gosec // policy deny code label, not a credential
|
||||
denyCodeBudgetCapExceeded = "llm_policy.budget_cap_exceeded"
|
||||
//nolint:gosec // account deny code label, not a credential
|
||||
denyCodeAccountTokenCapExceeded = "llm_account.token_cap_exceeded"
|
||||
//nolint:gosec // account deny code label, not a credential
|
||||
denyCodeAccountBudgetCapExceeded = "llm_account.budget_cap_exceeded"
|
||||
)
|
||||
|
||||
// consumptionCache holds the consumption counters prefetched for one
|
||||
// policy-selection request, keyed by ConsumptionKey. A miss returns a zero
|
||||
// counter — the same contract the store's single-row getter uses for absent
|
||||
// rows — so the eval logic is identical whether a counter exists yet or not.
|
||||
type consumptionCache map[types.ConsumptionKey]*types.Consumption
|
||||
|
||||
func (c consumptionCache) get(accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds int64, windowStart time.Time) *types.Consumption {
|
||||
key := types.ConsumptionKey{Kind: kind, DimID: dimID, WindowSeconds: windowSeconds, WindowStartUTC: windowStart.UTC()}
|
||||
if row, ok := c[key]; ok && row != nil {
|
||||
return row
|
||||
}
|
||||
return &types.Consumption{
|
||||
AccountID: accountID,
|
||||
DimensionKind: kind,
|
||||
DimensionID: dimID,
|
||||
WindowSeconds: windowSeconds,
|
||||
WindowStartUTC: windowStart.UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// addLimitKeys records the user/group consumption keys a single enabled (token
|
||||
// or budget) limit window reads for the given attribution group, into a dedup
|
||||
// set. attrGroup may be empty (no group dimension applies).
|
||||
func addLimitKeys(set map[types.ConsumptionKey]struct{}, userID, attrGroup string, windowSeconds int64, now time.Time) {
|
||||
if windowSeconds <= 0 {
|
||||
return
|
||||
}
|
||||
ws := types.WindowStart(now, windowSeconds)
|
||||
if userID != "" {
|
||||
set[types.ConsumptionKey{Kind: types.DimensionUser, DimID: userID, WindowSeconds: windowSeconds, WindowStartUTC: ws}] = struct{}{}
|
||||
}
|
||||
if attrGroup != "" {
|
||||
set[types.ConsumptionKey{Kind: types.DimensionGroup, DimID: attrGroup, WindowSeconds: windowSeconds, WindowStartUTC: ws}] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// prefetchConsumption loads, in one store round-trip, every consumption counter
|
||||
// that the account-budget ceiling and the candidate policies will read while
|
||||
// scoring this request. This replaces the per-cap point reads the selector
|
||||
// previously issued one at a time (the N+1 on the hot path).
|
||||
func (m *managerImpl) prefetchConsumption(ctx context.Context, in PolicySelectionInput, rules []*types.AccountBudgetRule, candidates []*types.Policy, now time.Time) (consumptionCache, error) {
|
||||
set := make(map[types.ConsumptionKey]struct{})
|
||||
for _, p := range candidates {
|
||||
attr := lowestIntersect(p.SourceGroups, in.GroupIDs)
|
||||
if p.Limits.TokenLimit.Enabled {
|
||||
addLimitKeys(set, in.UserID, attr, p.Limits.TokenLimit.WindowSeconds, now)
|
||||
}
|
||||
if p.Limits.BudgetLimit.Enabled {
|
||||
addLimitKeys(set, in.UserID, attr, p.Limits.BudgetLimit.WindowSeconds, now)
|
||||
}
|
||||
}
|
||||
for _, r := range rules {
|
||||
if r == nil || !r.Enabled || !budgetRuleApplies(r, in) {
|
||||
continue
|
||||
}
|
||||
attr := lowestIntersect(r.TargetGroups, in.GroupIDs)
|
||||
if r.Limits.TokenLimit.Enabled {
|
||||
addLimitKeys(set, in.UserID, attr, r.Limits.TokenLimit.WindowSeconds, now)
|
||||
}
|
||||
if r.Limits.BudgetLimit.Enabled {
|
||||
addLimitKeys(set, in.UserID, attr, r.Limits.BudgetLimit.WindowSeconds, now)
|
||||
}
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return consumptionCache{}, nil
|
||||
}
|
||||
keys := make([]types.ConsumptionKey, 0, len(set))
|
||||
for k := range set {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
rows, err := m.store.GetAgentNetworkConsumptionBatch(ctx, store.LockingStrengthNone, in.AccountID, keys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch read consumption: %w", err)
|
||||
}
|
||||
return consumptionCache(rows), nil
|
||||
}
|
||||
|
||||
// SelectPolicyForRequest picks the policy that "pays" for the
|
||||
// incoming request. The chosen policy is the one with the largest
|
||||
// pool that still has headroom — drain the bigger bucket first,
|
||||
// fall through to the next-biggest only when the current one's
|
||||
// group cap or shared per-user cap is exhausted. This matches
|
||||
// operator intuition for layered tiers ("privileged group has the
|
||||
// 10k budget, regular group has 1k as the safety net") and avoids
|
||||
// the load-balancer flapping that fraction-based scoring produces
|
||||
// once any cap has been touched.
|
||||
//
|
||||
// Ordering across non-exhausted candidates:
|
||||
// 1. Policies with NO enabled caps (catch-all-allow) win over any
|
||||
// capped policy — operators who configure unlimited access
|
||||
// expect requests to attribute there until they explicitly add
|
||||
// caps.
|
||||
// 2. Larger group token cap wins.
|
||||
// 3. Larger group budget USD cap wins.
|
||||
// 4. Larger user token cap wins.
|
||||
// 5. Larger user budget USD cap wins.
|
||||
// 6. Older created_at wins (deterministic final tiebreak so
|
||||
// multi-node selection converges).
|
||||
//
|
||||
// Returns Allow=true with empty SelectedPolicyID when no policy in
|
||||
// the account targets the (provider, caller-groups) combination —
|
||||
// llm_router is the gate that owns "no policy authorises this
|
||||
// request" semantics; this function trusts that authorisation has
|
||||
// already happened upstream and only does the limit-aware
|
||||
// attribution.
|
||||
func (m *managerImpl) SelectPolicyForRequest(ctx context.Context, in PolicySelectionInput) (*PolicySelectionResult, error) {
|
||||
if in.AccountID == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "account_id is required")
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, in.AccountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list account budget rules: %w", err)
|
||||
}
|
||||
policies, err := m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, in.AccountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list account policies: %w", err)
|
||||
}
|
||||
candidates := filterApplicablePolicies(policies, in)
|
||||
|
||||
// Prefetch every consumption counter the ceiling + candidate policies will
|
||||
// read, in a single store round-trip, then score against the cache.
|
||||
cache, err := m.prefetchConsumption(ctx, in, rules, candidates, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Account-level budget rules are an always-on ceiling, evaluated
|
||||
// independently of policy selection (they bind even for catch-all-allow
|
||||
// policies or requests that match no policy). All applicable rules must
|
||||
// pass — this is where min-wins lives.
|
||||
if deny, code, reason := checkAccountBudget(in, rules, cache, now); deny {
|
||||
return &PolicySelectionResult{Allow: false, DenyCode: code, DenyReason: reason}, nil
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return &PolicySelectionResult{Allow: true}, nil
|
||||
}
|
||||
scored, lastDenyCode, lastDenyReason := scoreCandidates(in, candidates, cache, now)
|
||||
if len(scored) == 0 {
|
||||
return &PolicySelectionResult{
|
||||
Allow: false,
|
||||
DenyCode: lastDenyCode,
|
||||
DenyReason: lastDenyReason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sort.SliceStable(scored, func(i, j int) bool {
|
||||
// Catch-all-allow (no caps configured) wins outright over
|
||||
// any capped policy.
|
||||
iNoCap := isUncapped(scored[i].policy)
|
||||
jNoCap := isUncapped(scored[j].policy)
|
||||
if iNoCap != jNoCap {
|
||||
return iNoCap
|
||||
}
|
||||
// Bigger pool drains first. Group caps dominate (shared
|
||||
// across the group) before individual caps.
|
||||
if a, b := groupCapTokens(scored[i].policy), groupCapTokens(scored[j].policy); a != b {
|
||||
return a > b
|
||||
}
|
||||
if a, b := groupCapBudgetUsd(scored[i].policy), groupCapBudgetUsd(scored[j].policy); a != b {
|
||||
return a > b
|
||||
}
|
||||
if a, b := userCapTokens(scored[i].policy), userCapTokens(scored[j].policy); a != b {
|
||||
return a > b
|
||||
}
|
||||
if a, b := userCapBudgetUsd(scored[i].policy), userCapBudgetUsd(scored[j].policy); a != b {
|
||||
return a > b
|
||||
}
|
||||
return scored[i].policy.CreatedAt.Before(scored[j].policy.CreatedAt)
|
||||
})
|
||||
|
||||
winner := scored[0]
|
||||
return &PolicySelectionResult{
|
||||
Allow: true,
|
||||
SelectedPolicyID: winner.policy.ID,
|
||||
AttributionGroupID: winner.attributionGroup,
|
||||
WindowSeconds: winner.windowSeconds,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// filterApplicablePolicies returns the enabled policies that target
|
||||
// the requested provider and have at least one of the caller's groups
|
||||
// in their source_groups. Caller's group set is matched
|
||||
// case-sensitively against policy.SourceGroups.
|
||||
func filterApplicablePolicies(policies []*types.Policy, in PolicySelectionInput) []*types.Policy {
|
||||
if len(policies) == 0 {
|
||||
return nil
|
||||
}
|
||||
groupSet := make(map[string]struct{}, len(in.GroupIDs))
|
||||
for _, g := range in.GroupIDs {
|
||||
if g != "" {
|
||||
groupSet[g] = struct{}{}
|
||||
}
|
||||
}
|
||||
out := make([]*types.Policy, 0, len(policies))
|
||||
for _, p := range policies {
|
||||
if p == nil || !p.Enabled {
|
||||
continue
|
||||
}
|
||||
if !sliceContains(p.DestinationProviderIDs, in.ProviderID) {
|
||||
continue
|
||||
}
|
||||
if !anyGroupMatches(p.SourceGroups, groupSet) {
|
||||
continue
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// candidate is the per-policy intermediate the selector ranks. A
|
||||
// policy that's been exhausted on any enabled cap never makes it
|
||||
// into this slice; the selector's deny envelope carries the latest
|
||||
// exhaustion's reason out separately.
|
||||
type candidate struct {
|
||||
policy *types.Policy
|
||||
attributionGroup string
|
||||
windowSeconds int64
|
||||
}
|
||||
|
||||
// scoreCandidates evaluates every applicable policy against the
|
||||
// caller's current consumption. Exhausted policies are filtered out
|
||||
// of the returned slice; the most recent exhaustion's deny code +
|
||||
// human reason is returned alongside so the caller can surface it
|
||||
// when no candidate survives.
|
||||
func scoreCandidates(
|
||||
in PolicySelectionInput,
|
||||
candidates []*types.Policy,
|
||||
cache consumptionCache,
|
||||
now time.Time,
|
||||
) ([]candidate, string, string) {
|
||||
out := make([]candidate, 0, len(candidates))
|
||||
var lastDenyCode, lastDenyReason string
|
||||
|
||||
for _, p := range candidates {
|
||||
c, exhausted, denyCode, denyReason := scoreOne(in, p, cache, now)
|
||||
if exhausted {
|
||||
lastDenyCode = denyCode
|
||||
lastDenyReason = denyReason
|
||||
continue
|
||||
}
|
||||
out = append(out, c)
|
||||
}
|
||||
return out, lastDenyCode, lastDenyReason
|
||||
}
|
||||
|
||||
// scoreOne checks a single policy for cap exhaustion. Returns the
|
||||
// candidate envelope when the policy still has headroom on every
|
||||
// enabled cap; reports exhausted=true with a deny code naming the
|
||||
// offending cap kind otherwise.
|
||||
func scoreOne(
|
||||
in PolicySelectionInput,
|
||||
p *types.Policy,
|
||||
cache consumptionCache,
|
||||
now time.Time,
|
||||
) (candidate, bool, string, string) {
|
||||
attrGroup := lowestIntersect(p.SourceGroups, in.GroupIDs)
|
||||
c := candidate{
|
||||
policy: p,
|
||||
attributionGroup: attrGroup,
|
||||
windowSeconds: effectiveWindowSeconds(p),
|
||||
}
|
||||
|
||||
if p.Limits.TokenLimit.Enabled && p.Limits.TokenLimit.WindowSeconds > 0 {
|
||||
if exhausted, reason := evalTokenCap(cache, in.AccountID, in.UserID, attrGroup, p.Limits.TokenLimit, now, "policy "+p.ID); exhausted {
|
||||
return candidate{}, true, denyCodeTokenCapExceeded, reason
|
||||
}
|
||||
}
|
||||
|
||||
if p.Limits.BudgetLimit.Enabled && p.Limits.BudgetLimit.WindowSeconds > 0 {
|
||||
if exhausted, reason := evalBudgetCap(cache, in.AccountID, in.UserID, attrGroup, p.Limits.BudgetLimit, now, "policy "+p.ID); exhausted {
|
||||
return candidate{}, true, denyCodeBudgetCapExceeded, reason
|
||||
}
|
||||
}
|
||||
|
||||
return c, false, "", ""
|
||||
}
|
||||
|
||||
// evalTokenCap reports whether the token limit is already exhausted for the
|
||||
// caller in its own window. attrGroup may be empty (no group dimension applies).
|
||||
// label identifies the cap source ("policy <id>" or "account rule <id>") for the
|
||||
// deny reason. It is the shared primitive behind both policy and account-rule
|
||||
// enforcement.
|
||||
func evalTokenCap(
|
||||
cache consumptionCache,
|
||||
accountID, userID, attrGroup string,
|
||||
tl types.PolicyTokenLimit,
|
||||
now time.Time,
|
||||
label string,
|
||||
) (bool, string) {
|
||||
windowStart := types.WindowStart(now, tl.WindowSeconds)
|
||||
|
||||
if tl.UserCap > 0 && userID != "" {
|
||||
row := cache.get(accountID, types.DimensionUser, userID, tl.WindowSeconds, windowStart)
|
||||
used := row.TokensInput + row.TokensOutput
|
||||
if used >= tl.UserCap {
|
||||
return true, fmt.Sprintf("user token cap exhausted on %s (used %d of %d)", label, used, tl.UserCap)
|
||||
}
|
||||
}
|
||||
|
||||
if tl.GroupCap > 0 && attrGroup != "" {
|
||||
row := cache.get(accountID, types.DimensionGroup, attrGroup, tl.WindowSeconds, windowStart)
|
||||
used := row.TokensInput + row.TokensOutput
|
||||
if used >= tl.GroupCap {
|
||||
return true, fmt.Sprintf("group token cap exhausted on %s (used %d of %d)", label, used, tl.GroupCap)
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// evalBudgetCap is the budget (USD) counterpart of evalTokenCap.
|
||||
func evalBudgetCap(
|
||||
cache consumptionCache,
|
||||
accountID, userID, attrGroup string,
|
||||
bl types.PolicyBudgetLimit,
|
||||
now time.Time,
|
||||
label string,
|
||||
) (bool, string) {
|
||||
windowStart := types.WindowStart(now, bl.WindowSeconds)
|
||||
|
||||
if bl.UserCapUsd > 0 && userID != "" {
|
||||
row := cache.get(accountID, types.DimensionUser, userID, bl.WindowSeconds, windowStart)
|
||||
if row.CostUSD >= bl.UserCapUsd {
|
||||
return true, fmt.Sprintf("user budget cap exhausted on %s (used $%.4f of $%.4f)", label, row.CostUSD, bl.UserCapUsd)
|
||||
}
|
||||
}
|
||||
|
||||
if bl.GroupCapUsd > 0 && attrGroup != "" {
|
||||
row := cache.get(accountID, types.DimensionGroup, attrGroup, bl.WindowSeconds, windowStart)
|
||||
if row.CostUSD >= bl.GroupCapUsd {
|
||||
return true, fmt.Sprintf("group budget cap exhausted on %s (used $%.4f of $%.4f)", label, row.CostUSD, bl.GroupCapUsd)
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// checkAccountBudget evaluates every applicable account-level budget rule as an
|
||||
// all-must-pass ceiling. A rule applies when the caller is in its TargetUsers,
|
||||
// one of its TargetGroups, or it has no targets at all (account-wide). Returns
|
||||
// deny=true with an llm_account.* code on the first exhausted rule. Group caps
|
||||
// attribute to the lowest intersecting group (the same model policies use), so
|
||||
// multi-group behavior is unchanged.
|
||||
func checkAccountBudget(in PolicySelectionInput, rules []*types.AccountBudgetRule, cache consumptionCache, now time.Time) (bool, string, string) {
|
||||
for _, r := range rules {
|
||||
if r == nil || !r.Enabled || !budgetRuleApplies(r, in) {
|
||||
continue
|
||||
}
|
||||
attrGroup := lowestIntersect(r.TargetGroups, in.GroupIDs)
|
||||
label := "account rule " + r.ID
|
||||
|
||||
if r.Limits.TokenLimit.Enabled && r.Limits.TokenLimit.WindowSeconds > 0 {
|
||||
if exhausted, reason := evalTokenCap(cache, in.AccountID, in.UserID, attrGroup, r.Limits.TokenLimit, now, label); exhausted {
|
||||
return true, denyCodeAccountTokenCapExceeded, reason
|
||||
}
|
||||
}
|
||||
|
||||
if r.Limits.BudgetLimit.Enabled && r.Limits.BudgetLimit.WindowSeconds > 0 {
|
||||
if exhausted, reason := evalBudgetCap(cache, in.AccountID, in.UserID, attrGroup, r.Limits.BudgetLimit, now, label); exhausted {
|
||||
return true, denyCodeAccountBudgetCapExceeded, reason
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
// budgetRuleApplies reports whether an account budget rule binds the caller:
|
||||
// a direct user match, a group intersection, or an untargeted (account-wide)
|
||||
// rule.
|
||||
func budgetRuleApplies(r *types.AccountBudgetRule, in PolicySelectionInput) bool {
|
||||
if len(r.TargetUsers) == 0 && len(r.TargetGroups) == 0 {
|
||||
return true
|
||||
}
|
||||
if in.UserID != "" && sliceContains(r.TargetUsers, in.UserID) {
|
||||
return true
|
||||
}
|
||||
groupSet := make(map[string]struct{}, len(in.GroupIDs))
|
||||
for _, g := range in.GroupIDs {
|
||||
if g != "" {
|
||||
groupSet[g] = struct{}{}
|
||||
}
|
||||
}
|
||||
return anyGroupMatches(r.TargetGroups, groupSet)
|
||||
}
|
||||
|
||||
// RecordAccountBudgetUsage fans the served request's usage out to every
|
||||
// applicable account budget rule's own (dimension, window) counter. The user
|
||||
// dimension is always booked when a rule has a user-applicable cap; the group
|
||||
// dimension books against the rule's lowest intersecting group. This runs
|
||||
// alongside the policy-window record so account ceilings accumulate in their own
|
||||
// windows (commonly monthly) independently of the per-policy window.
|
||||
func (m *managerImpl) RecordAccountBudgetUsage(ctx context.Context, accountID, userID string, groupIDs []string, tokensIn, tokensOut int64, costUSD float64) error {
|
||||
if accountID == "" {
|
||||
return status.Errorf(status.InvalidArgument, "account_id is required")
|
||||
}
|
||||
if err := validateUsageDeltas(tokensIn, tokensOut, costUSD); err != nil {
|
||||
return err
|
||||
}
|
||||
rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list account budget rules: %w", err)
|
||||
}
|
||||
set := make(map[types.ConsumptionKey]struct{})
|
||||
addAccountBudgetKeys(set, PolicySelectionInput{AccountID: accountID, UserID: userID, GroupIDs: groupIDs}, rules, time.Now().UTC())
|
||||
if len(set) == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.store.IncrementAgentNetworkConsumptionBatch(ctx, accountID, keysSlice(set), tokensIn, tokensOut, costUSD)
|
||||
}
|
||||
|
||||
// RecordUsageInput carries everything RecordUsage books for one served request.
|
||||
type RecordUsageInput struct {
|
||||
AccountID string
|
||||
UserID string
|
||||
AttributionGroupID string // selected policy's attribution group (policy window)
|
||||
GroupIDs []string
|
||||
WindowSeconds int64 // selected policy's window; 0 means no policy cap
|
||||
TokensIn int64
|
||||
TokensOut int64
|
||||
CostUSD float64
|
||||
}
|
||||
|
||||
// RecordUsage books a served request's usage against every counter it touches —
|
||||
// the selected policy's per-(user, group) window plus every applicable account
|
||||
// budget rule's own window — deduplicated and written in a single transaction.
|
||||
// Two counters that collapse to the same (dimension, window) tuple are booked
|
||||
// once, so a single request can never double-count against one cap.
|
||||
func (m *managerImpl) RecordUsage(ctx context.Context, in RecordUsageInput) error {
|
||||
if in.AccountID == "" {
|
||||
return status.Errorf(status.InvalidArgument, "account_id is required")
|
||||
}
|
||||
if err := validateUsageDeltas(in.TokensIn, in.TokensOut, in.CostUSD); err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
set := make(map[types.ConsumptionKey]struct{})
|
||||
|
||||
// Policy-window dimensions are booked only when a policy cap bound this
|
||||
// request (window > 0). A zero window means catch-all-allow / no policy cap;
|
||||
// the account fan-out below still books against the budget rules' windows.
|
||||
if in.WindowSeconds > 0 {
|
||||
addLimitKeys(set, in.UserID, in.AttributionGroupID, in.WindowSeconds, now)
|
||||
}
|
||||
|
||||
rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, in.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list account budget rules: %w", err)
|
||||
}
|
||||
addAccountBudgetKeys(set, PolicySelectionInput{AccountID: in.AccountID, UserID: in.UserID, GroupIDs: in.GroupIDs}, rules, now)
|
||||
|
||||
if len(set) == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.store.IncrementAgentNetworkConsumptionBatch(ctx, in.AccountID, keysSlice(set), in.TokensIn, in.TokensOut, in.CostUSD)
|
||||
}
|
||||
|
||||
// addAccountBudgetKeys adds the (dimension, window) keys a served request books
|
||||
// against every applicable account budget rule into the dedup set.
|
||||
func addAccountBudgetKeys(set map[types.ConsumptionKey]struct{}, in PolicySelectionInput, rules []*types.AccountBudgetRule, now time.Time) {
|
||||
for _, r := range rules {
|
||||
if r == nil || !r.Enabled || !budgetRuleApplies(r, in) {
|
||||
continue
|
||||
}
|
||||
attrGroup := lowestIntersect(r.TargetGroups, in.GroupIDs)
|
||||
for _, window := range ruleWindows(r) {
|
||||
addLimitKeys(set, in.UserID, attrGroup, window, now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// keysSlice flattens a ConsumptionKey set into a slice.
|
||||
func keysSlice(set map[types.ConsumptionKey]struct{}) []types.ConsumptionKey {
|
||||
keys := make([]types.ConsumptionKey, 0, len(set))
|
||||
for k := range set {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// ruleWindows returns the distinct enabled window lengths a budget rule books
|
||||
// against (token window and/or budget window, deduplicated).
|
||||
func ruleWindows(r *types.AccountBudgetRule) []int64 {
|
||||
var windows []int64
|
||||
if r.Limits.TokenLimit.Enabled && r.Limits.TokenLimit.WindowSeconds > 0 {
|
||||
windows = append(windows, r.Limits.TokenLimit.WindowSeconds)
|
||||
}
|
||||
if r.Limits.BudgetLimit.Enabled && r.Limits.BudgetLimit.WindowSeconds > 0 {
|
||||
bw := r.Limits.BudgetLimit.WindowSeconds
|
||||
if len(windows) == 0 || windows[0] != bw {
|
||||
windows = append(windows, bw)
|
||||
}
|
||||
}
|
||||
return windows
|
||||
}
|
||||
|
||||
// effectiveWindowSeconds returns the window length the proxy should
|
||||
// hand back to RecordLLMUsage. When both halves are enabled with
|
||||
// different windows, token_limit wins (the more common config); when
|
||||
// only one is enabled that one wins; when neither is enabled the
|
||||
// returned value is 0 — RecordLLMUsage treats 0 as "no limit
|
||||
// tracking" and skips the increment, which is the right pass-through
|
||||
// for catch-all-allow policies with no caps configured.
|
||||
func effectiveWindowSeconds(p *types.Policy) int64 {
|
||||
if p.Limits.TokenLimit.Enabled && p.Limits.TokenLimit.WindowSeconds > 0 {
|
||||
return p.Limits.TokenLimit.WindowSeconds
|
||||
}
|
||||
if p.Limits.BudgetLimit.Enabled && p.Limits.BudgetLimit.WindowSeconds > 0 {
|
||||
return p.Limits.BudgetLimit.WindowSeconds
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// lowestIntersect returns the lowest-by-string-sort element of
|
||||
// callerGroups ∩ sourceGroups. Empty when the intersection is empty.
|
||||
// Lowest is deterministic so multi-node selection converges.
|
||||
func lowestIntersect(sourceGroups, callerGroups []string) string {
|
||||
if len(sourceGroups) == 0 || len(callerGroups) == 0 {
|
||||
return ""
|
||||
}
|
||||
srcSet := make(map[string]struct{}, len(sourceGroups))
|
||||
for _, g := range sourceGroups {
|
||||
srcSet[g] = struct{}{}
|
||||
}
|
||||
var best string
|
||||
for _, g := range callerGroups {
|
||||
if _, ok := srcSet[g]; !ok {
|
||||
continue
|
||||
}
|
||||
if best == "" || g < best {
|
||||
best = g
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func anyGroupMatches(sourceGroups []string, callerSet map[string]struct{}) bool {
|
||||
for _, g := range sourceGroups {
|
||||
if _, ok := callerSet[g]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isUncapped reports whether a policy has any enabled cap with a
|
||||
// positive limit value. Mirrors the eval functions' guards: a policy
|
||||
// with token_limit.enabled=true but every cap value at 0 still
|
||||
// counts as uncapped because the eval would query nothing and bind
|
||||
// nothing.
|
||||
func isUncapped(p *types.Policy) bool {
|
||||
tl := p.Limits.TokenLimit
|
||||
if tl.Enabled && tl.WindowSeconds > 0 && (tl.GroupCap > 0 || tl.UserCap > 0) {
|
||||
return false
|
||||
}
|
||||
bl := p.Limits.BudgetLimit
|
||||
if bl.Enabled && bl.WindowSeconds > 0 && (bl.GroupCapUsd > 0 || bl.UserCapUsd > 0) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// groupCapTokens returns the policy's group-token cap when the token
|
||||
// limit is enabled, zero otherwise. Drives the primary "bigger pool
|
||||
// first" sort.
|
||||
func groupCapTokens(p *types.Policy) int64 {
|
||||
if p.Limits.TokenLimit.Enabled {
|
||||
return p.Limits.TokenLimit.GroupCap
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// groupCapBudgetUsd returns the policy's group-budget cap in USD
|
||||
// when the budget limit is enabled, zero otherwise. Secondary sort
|
||||
// key after token group cap so budget-only policies still order
|
||||
// predictably.
|
||||
func groupCapBudgetUsd(p *types.Policy) float64 {
|
||||
if p.Limits.BudgetLimit.Enabled {
|
||||
return p.Limits.BudgetLimit.GroupCapUsd
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// userCapTokens returns the policy's per-user token cap when the
|
||||
// token limit is enabled, zero otherwise. Tertiary sort key, used
|
||||
// when group caps tie or are absent.
|
||||
func userCapTokens(p *types.Policy) int64 {
|
||||
if p.Limits.TokenLimit.Enabled {
|
||||
return p.Limits.TokenLimit.UserCap
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// userCapBudgetUsd returns the policy's per-user budget cap in USD
|
||||
// when the budget limit is enabled, zero otherwise. Quaternary sort
|
||||
// key for budget-only policies whose group caps tie or are absent.
|
||||
func userCapBudgetUsd(p *types.Policy) float64 {
|
||||
if p.Limits.BudgetLimit.Enabled {
|
||||
return p.Limits.BudgetLimit.UserCapUsd
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func sliceContains(haystack []string, needle string) bool {
|
||||
for _, v := range haystack {
|
||||
if v == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// mockManager fallback so tests that don't care about selection still
|
||||
// compile.
|
||||
func (*mockManager) SelectPolicyForRequest(_ context.Context, _ PolicySelectionInput) (*PolicySelectionResult, error) {
|
||||
return &PolicySelectionResult{Allow: true}, nil
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// GC-2 no-mock enforcement tests for the account-budget ceiling. They drive the
|
||||
// real store + real consumption accounting through SelectPolicyForRequest and
|
||||
// RecordAccountBudgetUsage, asserting min-wins (account binds independently of
|
||||
// policy), targeting (groups + direct users), and the record fan-out.
|
||||
|
||||
func accountWideUserTokenRule(id string, userCap, window int64) *types.AccountBudgetRule {
|
||||
r := types.NewAccountBudgetRule(realSelectAccount)
|
||||
r.ID = id
|
||||
r.Limits.TokenLimit = types.PolicyTokenLimit{Enabled: true, UserCap: userCap, WindowSeconds: window}
|
||||
return r
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AccountCeilingBindsEvenWithUncappedPolicy proves
|
||||
// min-wins: the account user ceiling denies once exhausted even though a
|
||||
// catch-all-allow (uncapped) policy would otherwise pass the request. The
|
||||
// account gate runs independently of and ahead of policy selection.
|
||||
func TestSelectPolicy_RealStore_AccountCeilingBindsEvenWithUncappedPolicy(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// An uncapped (catch-all-allow) policy: enabled token limit, zero caps.
|
||||
uncapped := capPolicy("pol-open", realSelectAccount, []string{"grp-eng"}, "prov-1", 0, 86_400)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, uncapped))
|
||||
|
||||
// Account-wide user ceiling of 100 tokens in an hourly window.
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-1", 100, 3_600)))
|
||||
|
||||
in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"}
|
||||
|
||||
// Fresh: account ceiling has headroom, uncapped policy wins.
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "fresh account ceiling must allow")
|
||||
|
||||
// Drain the account user ceiling via the fan-out path.
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", []string{"grp-eng"}, 100, 0, 0))
|
||||
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "account ceiling must deny even though the policy is uncapped (min-wins)")
|
||||
assert.Equal(t, denyCodeAccountTokenCapExceeded, res.DenyCode, "deny must carry the llm_account.* code")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AccountGroupCeiling proves a group-targeted rule
|
||||
// binds the caller's group dimension.
|
||||
func TestSelectPolicy_RealStore_AccountGroupCeiling(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
rule := types.NewAccountBudgetRule(realSelectAccount)
|
||||
rule.ID = "ainbud-grp"
|
||||
rule.TargetGroups = []string{"grp-eng"}
|
||||
rule.Limits.BudgetLimit = types.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 5.0, WindowSeconds: 2_592_000}
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, rule))
|
||||
|
||||
in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"}
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "fresh group ceiling must allow")
|
||||
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", []string{"grp-eng"}, 0, 0, 5.0))
|
||||
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "group budget ceiling must deny once spent")
|
||||
assert.Equal(t, denyCodeAccountBudgetCapExceeded, res.DenyCode, "account budget deny code")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AccountTargetUsersBindsOnlyThatUser proves a
|
||||
// TargetUsers rule tightens only the named user, leaving others unbound.
|
||||
func TestSelectPolicy_RealStore_AccountTargetUsersBindsOnlyThatUser(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
rule := types.NewAccountBudgetRule(realSelectAccount)
|
||||
rule.ID = "ainbud-alice"
|
||||
rule.TargetUsers = []string{"alice"}
|
||||
rule.Limits.TokenLimit = types.PolicyTokenLimit{Enabled: true, UserCap: 100, WindowSeconds: 3_600}
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, rule))
|
||||
|
||||
// Record alice's usage to the rule window.
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "alice", nil, 100, 0, 0))
|
||||
|
||||
aliceIn := PolicySelectionInput{AccountID: realSelectAccount, UserID: "alice", ProviderID: "prov-1"}
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, aliceIn)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "alice is bound by the TargetUsers rule and is exhausted")
|
||||
|
||||
bobIn := PolicySelectionInput{AccountID: realSelectAccount, UserID: "bob", ProviderID: "prov-1"}
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, bobIn)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "bob is not in TargetUsers, so the rule must not bind him")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AccountRuleRecordsToOwnWindow proves the record
|
||||
// fan-out books usage in the rule's own window (distinct from any policy
|
||||
// window), so the account ceiling accumulates independently.
|
||||
func TestSelectPolicy_RealStore_AccountRuleRecordsToOwnWindow(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-w", 100, 3_600)))
|
||||
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", nil, 60, 0, 0))
|
||||
|
||||
// Same user, a policy-style daily window must NOT see the account-window
|
||||
// usage — windows are independent counters.
|
||||
dailyRow, err := s.GetAgentNetworkConsumption(ctx, store.LockingStrengthNone, realSelectAccount, types.DimensionUser, "user-1", 86_400, types.WindowStart(time.Now().UTC(), 86_400))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), dailyRow.TokensInput+dailyRow.TokensOutput, "daily window must be untouched by the hourly account-rule record")
|
||||
|
||||
// A second record pushes the hourly account window to its cap → deny.
|
||||
require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", nil, 40, 0, 0))
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", ProviderID: "prov-1"})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "100 tokens recorded in the rule's hourly window must exhaust the 100-token ceiling")
|
||||
assert.Equal(t, denyCodeAccountTokenCapExceeded, res.DenyCode, "account token deny code")
|
||||
}
|
||||
|
||||
// TestRecordUsage_RealStore_BooksPolicyAndAccountWindows proves the batched
|
||||
// post-flight write books the selected policy's window AND every applicable
|
||||
// account rule's (independent) window in a single call — the #6 batched-write
|
||||
// path the proxy's RecordLLMUsage RPC now uses.
|
||||
func TestRecordUsage_RealStore_BooksPolicyAndAccountWindows(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Policy: 100-token group cap on a daily window. Account rule: 100-token
|
||||
// user ceiling on an hourly window — an independent counter.
|
||||
policy := capPolicy("pol-1", realSelectAccount, []string{"grp-eng"}, "prov-1", 100, 86_400)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, policy))
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-1", 100, 3_600)))
|
||||
|
||||
in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"}
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
require.True(t, res.Allow)
|
||||
require.Equal(t, "pol-1", res.SelectedPolicyID)
|
||||
|
||||
// One batched record books the policy window (group + user @86400) and the
|
||||
// account rule window (user @3600) atomically.
|
||||
require.NoError(t, mgr.RecordUsage(ctx, RecordUsageInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
AttributionGroupID: res.AttributionGroupID,
|
||||
GroupIDs: []string{"grp-eng"},
|
||||
WindowSeconds: res.WindowSeconds,
|
||||
TokensIn: 100,
|
||||
}))
|
||||
|
||||
// The next selection denies — the account hourly ceiling binds first.
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "usage booked by RecordUsage must enforce on the next request")
|
||||
|
||||
// Prove BOTH windows were booked in the one call via a direct batch read.
|
||||
now := time.Now().UTC()
|
||||
userKey := types.ConsumptionKey{Kind: types.DimensionUser, DimID: "user-1", WindowSeconds: 3_600, WindowStartUTC: types.WindowStart(now, 3_600)}
|
||||
groupKey := types.ConsumptionKey{Kind: types.DimensionGroup, DimID: "grp-eng", WindowSeconds: 86_400, WindowStartUTC: types.WindowStart(now, 86_400)}
|
||||
rows, err := s.GetAgentNetworkConsumptionBatch(ctx, store.LockingStrengthNone, realSelectAccount, []types.ConsumptionKey{userKey, groupKey})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, rows, userKey, "account rule user/hourly window booked")
|
||||
require.Contains(t, rows, groupKey, "policy group/daily window booked")
|
||||
assert.Equal(t, int64(100), rows[userKey].TokensInput, "account hourly user counter")
|
||||
assert.Equal(t, int64(100), rows[groupKey].TokensInput, "policy daily group counter")
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// This file is the no-mock regression guard for policy limit enforcement.
|
||||
// policyselect_test.go pins the same behavior through a gomock store with
|
||||
// explicit call-sequence expectations — brittle precisely where the upcoming
|
||||
// account-budget work (GC-2) refactors the cap-eval primitive and adds an
|
||||
// account-level gate. These tests drive the REAL sqlite store + REAL
|
||||
// consumption accounting and assert observable behavior (allow / deny /
|
||||
// selection / attribution), not which store methods get called. They must keep
|
||||
// passing unchanged after GC-2 lands, which is what proves "current behavior is
|
||||
// not changed."
|
||||
|
||||
const realSelectAccount = "acc-realselect-1"
|
||||
|
||||
// newRealSelectorMgr builds a managerImpl backed by a real sqlite test store.
|
||||
func newRealSelectorMgr(t *testing.T) (*managerImpl, store.Store) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
t.Cleanup(cleanup)
|
||||
return &managerImpl{store: s}, s
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_NoApplicablePolicies pins the pass-through:
|
||||
// nothing targets the (provider, groups) combination, so the selector allows
|
||||
// without attribution or consumption tracking.
|
||||
func TestSelectPolicy_RealStore_NoApplicablePolicies(t *testing.T) {
|
||||
mgr, _ := newRealSelectorMgr(t)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-x"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "no applicable policy must pass through as allow")
|
||||
assert.Empty(t, res.SelectedPolicyID, "no selection when nothing applies")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_AllowAndLowestGroupAttribution pins the v1
|
||||
// attribution rule (lowest intersecting group by string sort) through the
|
||||
// real store, with a fresh (zero) consumption row.
|
||||
func TestSelectPolicy_RealStore_AllowAndLowestGroupAttribution(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
p := capPolicy("pol-A", realSelectAccount, []string{"grp-zz", "grp-aa", "grp-mm"}, "prov-1", 10_000, 86_400)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p))
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-zz", "grp-aa", "grp-mm"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "fresh state under cap must allow")
|
||||
assert.Equal(t, "pol-A", res.SelectedPolicyID, "only applicable policy must be selected")
|
||||
assert.Equal(t, "grp-aa", res.AttributionGroupID, "lowest-by-sort intersecting group must win")
|
||||
assert.Equal(t, int64(86_400), res.WindowSeconds, "selected policy's window must be returned")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_LargerPoolWins_FallsThroughWhenExhausted pins the
|
||||
// core selection behavior end to end. The two policies bind DISTINCT groups so
|
||||
// they read separate counters — the only shape where fall-through actually
|
||||
// yields headroom (policies on the same group share one counter, as
|
||||
// policyselect_test.go notes). Larger pool wins fresh; after real consumption
|
||||
// drains the larger group, selection falls through to the smaller; once both
|
||||
// counters are exhausted the request is denied.
|
||||
func TestSelectPolicy_RealStore_LargerPoolWins_FallsThroughWhenExhausted(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
tight := capPolicy("pol-tight", realSelectAccount, []string{"grp-tight"}, "prov-1", 100, 86_400)
|
||||
tight.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
wide := capPolicy("pol-wide", realSelectAccount, []string{"grp-wide"}, "prov-1", 10_000, 86_400)
|
||||
wide.CreatedAt = time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, tight))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, wide))
|
||||
|
||||
// Caller is in both groups, so both policies apply with independent counters.
|
||||
in := PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-tight", "grp-wide"},
|
||||
ProviderID: "prov-1",
|
||||
}
|
||||
|
||||
// Fresh: larger pool wins.
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-wide", res.SelectedPolicyID, "larger pool drains first")
|
||||
|
||||
// Drain only the wide group's counter to its cap.
|
||||
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-wide", 86_400, 10_000, 0, 0))
|
||||
|
||||
// Wide exhausted, tight's separate counter is fresh → fall through to tight.
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "tight pool has its own untouched counter")
|
||||
assert.Equal(t, "pol-tight", res.SelectedPolicyID, "selection falls through to the smaller pool once the larger is exhausted")
|
||||
|
||||
// Drain the tight group's counter too → both exhausted → deny.
|
||||
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-tight", 86_400, 100, 0, 0))
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "both group counters exhausted must deny")
|
||||
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode, "deny code names the offending cap kind")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_BudgetCapDenies pins budget (USD) enforcement
|
||||
// through the real store: once recorded cost reaches the cap, deny.
|
||||
func TestSelectPolicy_RealStore_BudgetCapDenies(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
p := &types.Policy{
|
||||
ID: "pol-budget",
|
||||
AccountID: realSelectAccount,
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-eng"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
BudgetLimit: types.PolicyBudgetLimit{
|
||||
Enabled: true,
|
||||
GroupCapUsd: 5.0,
|
||||
WindowSeconds: 86_400,
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p))
|
||||
|
||||
in := PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-eng"},
|
||||
ProviderID: "prov-1",
|
||||
}
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "fresh budget must allow")
|
||||
|
||||
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-eng", 86_400, 0, 0, 5.0))
|
||||
|
||||
res, err = mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "cost at the cap must deny")
|
||||
assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode, "budget deny code must be surfaced")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_GroupCounterSharedAcrossPolicies pins that two
|
||||
// policies on the same group+window read one shared consumption counter: usage
|
||||
// recorded once is visible to both, so exhausting the group budget denies
|
||||
// regardless of which policy would attribute.
|
||||
func TestSelectPolicy_RealStore_GroupCounterSharedAcrossPolicies(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
a := capPolicy("pol-a", realSelectAccount, []string{"grp-eng"}, "prov-1", 1_000, 86_400)
|
||||
b := capPolicy("pol-b", realSelectAccount, []string{"grp-eng"}, "prov-1", 1_000, 86_400)
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, a))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, b))
|
||||
|
||||
in := PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-eng"},
|
||||
ProviderID: "prov-1",
|
||||
}
|
||||
|
||||
require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-eng", 86_400, 1_000, 0, 0))
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, in)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "shared group counter at cap denies both equal policies")
|
||||
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode, "token deny code on the shared counter")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RealStore_DisabledPolicyIgnored pins that a disabled policy
|
||||
// is invisible to selection even when it otherwise matches.
|
||||
func TestSelectPolicy_RealStore_DisabledPolicyIgnored(t *testing.T) {
|
||||
mgr, s := newRealSelectorMgr(t)
|
||||
ctx := context.Background()
|
||||
|
||||
p := capPolicy("pol-disabled", realSelectAccount, []string{"grp-eng"}, "prov-1", 10_000, 86_400)
|
||||
p.Enabled = false
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p))
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{
|
||||
AccountID: realSelectAccount,
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-eng"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "no enabled policy applies → pass-through allow")
|
||||
assert.Empty(t, res.SelectedPolicyID, "disabled policy must not be selected")
|
||||
}
|
||||
641
management/internals/modules/agentnetwork/policyselect_test.go
Normal file
641
management/internals/modules/agentnetwork/policyselect_test.go
Normal file
@@ -0,0 +1,641 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
nbstatus "github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func newSelectorMgr(t *testing.T, ctrl *gomock.Controller) (*managerImpl, *store.MockStore) {
|
||||
t.Helper()
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
// SelectPolicyForRequest evaluates the account-budget ceiling before policy
|
||||
// selection. These policy-selection tests don't exercise account rules, so
|
||||
// default to "no rules" — the no-mock policyselect_realstore_test.go covers
|
||||
// the account gate's behavior end to end.
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkBudgetRules(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nil, nil).
|
||||
AnyTimes()
|
||||
return &managerImpl{store: mockStore}, mockStore
|
||||
}
|
||||
|
||||
type usedKey struct {
|
||||
kind types.ConsumptionDimension
|
||||
dimID string
|
||||
window int64
|
||||
}
|
||||
|
||||
// expectConsumptionBatch stubs the batched consumption read to return the
|
||||
// supplied per-(kind, dim, window) counters, filling each row's window start
|
||||
// from the actual request keys so it always matches what the selector computed.
|
||||
// Keys absent from used resolve to zero counters.
|
||||
func expectConsumptionBatch(mockStore *store.MockStore, used map[usedKey]*types.Consumption) {
|
||||
mockStore.EXPECT().
|
||||
GetAgentNetworkConsumptionBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, _ store.LockingStrength, _ string, keys []types.ConsumptionKey) (map[types.ConsumptionKey]*types.Consumption, error) {
|
||||
out := make(map[types.ConsumptionKey]*types.Consumption)
|
||||
for _, k := range keys {
|
||||
if row, ok := used[usedKey{k.Kind, k.DimID, k.WindowSeconds}]; ok {
|
||||
rc := *row
|
||||
rc.WindowStartUTC = k.WindowStartUTC
|
||||
out[k] = &rc
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}).
|
||||
AnyTimes()
|
||||
}
|
||||
|
||||
func capPolicy(id, account string, sourceGroups []string, providerID string, tokenCap int64, windowSec int64) *types.Policy {
|
||||
return &types.Policy{
|
||||
ID: id,
|
||||
AccountID: account,
|
||||
Enabled: true,
|
||||
SourceGroups: sourceGroups,
|
||||
DestinationProviderIDs: []string{providerID},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{
|
||||
Enabled: true,
|
||||
GroupCap: tokenCap,
|
||||
WindowSeconds: windowSec,
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectPolicy_NoApplicablePolicies covers the pass-through path:
|
||||
// llm_router authorisation is upstream of selection; when the
|
||||
// selector finds no policy targeting the (provider, caller-groups)
|
||||
// combination, it returns Allow with no attribution and lets the
|
||||
// request continue without consumption tracking.
|
||||
func TestSelectPolicy_NoApplicablePolicies(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{}, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-x"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow, "no applicable policies = pass-through allow")
|
||||
assert.Empty(t, res.SelectedPolicyID, "no selection when nothing applies")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_AllowWithLowestGroupAttribution proves the v1
|
||||
// attribution rule: when the caller's groups intersect a policy's
|
||||
// source_groups in multiple positions, the selector picks the lowest
|
||||
// group id by string sort so multi-node selection converges.
|
||||
func TestSelectPolicy_AllowWithLowestGroupAttribution(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policy := capPolicy("pol-A", "acc-1", []string{"grp-zz", "grp-aa", "grp-mm"}, "prov-1", 10_000, 86_400)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policy}, nil)
|
||||
// Fresh: zero consumption across the board.
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-zz", "grp-aa", "grp-mm"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow)
|
||||
assert.Equal(t, "pol-A", res.SelectedPolicyID)
|
||||
assert.Equal(t, "grp-aa", res.AttributionGroupID,
|
||||
"lowest-by-sort intersection wins so multi-node selection converges")
|
||||
assert.Equal(t, int64(86_400), res.WindowSeconds)
|
||||
}
|
||||
|
||||
// TestSelectPolicy_LargerPoolWinsAcrossUsageLevels proves the core
|
||||
// selection rule: among multiple applicable policies with caps, the
|
||||
// selector picks the one with the larger absolute pool — at every
|
||||
// usage level, not just at fresh state. The smaller-pool policy is
|
||||
// only reached when the larger one is exhausted. This is the
|
||||
// "drain biggest first" semantic operators expect for layered
|
||||
// tiers; a fraction-based score would flap between the two as
|
||||
// soon as one is partially used.
|
||||
func TestSelectPolicy_LargerPoolWinsAcrossUsageLevels(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
tight := capPolicy("pol-tight", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400)
|
||||
tight.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
wide := capPolicy("pol-wide", "acc-1", []string{"grp-engineers"}, "prov-1", 10_000, 86_400)
|
||||
wide.CreatedAt = time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{tight, wide}, nil)
|
||||
|
||||
// Both partially used. tight at 50/100 (50% used); wide at
|
||||
// 50/10000 (0.5% used). Old fraction-based algo would pick wide
|
||||
// here too — but for the wrong reason ("more relative slack").
|
||||
// New algo picks wide because its initial group cap is bigger
|
||||
// (10000 > 100), and that decision is stable as wide drains.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 50},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-wide", res.SelectedPolicyID,
|
||||
"the policy with the bigger initial pool wins — operators expect 'drain the privileged tier first', not load-balance across tiers")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_StaysOnLargerPoolAfterPartialDrain locks the
|
||||
// stickiness contract reported by operators: with two policies
|
||||
// where A has a 200-token group cap and B has 150, the very first
|
||||
// request goes to A AND every subsequent request continues to land
|
||||
// on A until A's group cap is exhausted — at which point B becomes
|
||||
// the only candidate. A fraction-based score would flap to B as
|
||||
// soon as A had any consumption (B's 1.0 fraction beats A's 0.75)
|
||||
// even though A still has more absolute headroom; that produced
|
||||
// confusing per-policy attribution ledger entries and stranded
|
||||
// A's remaining capacity behind B's exhaustion.
|
||||
func TestSelectPolicy_StaysOnLargerPoolAfterPartialDrain(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policyA := capPolicy("pol-A-200", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400)
|
||||
policyB := capPolicy("pol-B-150", "acc-1", []string{"grp-engineers"}, "prov-1", 150, 86_400)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policyA, policyB}, nil)
|
||||
|
||||
// A is partially drained (50/200 used = 25% used; 75% headroom
|
||||
// remaining). B is fresh (0/150). The old fraction-based score
|
||||
// would pick B here (1.0 > 0.75 fraction); the new pool-size
|
||||
// score sticks with A (200 > 150 absolute cap).
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 50},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-A-200", res.SelectedPolicyID,
|
||||
"once attribution lands on the bigger pool it must STAY there until exhausted — operators expect 'drain A then B', not 'flip to B as soon as A is touched'")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_FallsThroughToSmallerPoolWhenLargerExhausted
|
||||
// proves the second half of the stickiness contract: once the
|
||||
// larger-pool policy IS exhausted, the smaller one takes over.
|
||||
// Without this we'd deny on requests the smaller policy is fully
|
||||
// equipped to serve.
|
||||
func TestSelectPolicy_FallsThroughToSmallerPoolWhenLargerExhausted(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policyA := capPolicy("pol-A-200", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400)
|
||||
// B uses a different window length so it has an INDEPENDENT counter — the
|
||||
// realistic shape for fall-through. On the SAME (group, window) tuple the
|
||||
// counter is shared, so A's cap of 200 being reached would also exhaust B's
|
||||
// 150; independent counters are what let A exhaust while B retains headroom.
|
||||
policyB := capPolicy("pol-B-150", "acc-1", []string{"grp-engineers"}, "prov-1", 150, 3_600)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policyA, policyB}, nil)
|
||||
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 200}, // A: 200 >= 200 → exhausted
|
||||
{types.DimensionGroup, "grp-engineers", 3_600}: {TokensInput: 100}, // B: 100 < 150 → headroom
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-B-150", res.SelectedPolicyID,
|
||||
"once the bigger pool is exhausted, the smaller one must take over — denying when capacity remains would strand B's allowance")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_TiebreakByLargerGroupPool covers the user-reported
|
||||
// bug: an admin in two groups (Users + Admins) where Users is bound
|
||||
// by a smaller-group-cap policy (50 group, 100 user) and Admins is
|
||||
// bound by a bigger-group-cap policy (100 group, 20 user) MUST get
|
||||
// attributed to the Admins policy on the first request.
|
||||
//
|
||||
// Without this rule, the fresh-state fraction is 1.0 for both and
|
||||
// the older policy wins by created_at. The first 24-token request
|
||||
// then drains the shared user counter past Admins's tight 20-token
|
||||
// user cap, locking Admins out of selection forever. The 100-token
|
||||
// Admins group pool ends up stranded while requests pile onto the
|
||||
// 50-token Users pool — the opposite of what the operator intended
|
||||
// when they put the bigger pool on the privileged group.
|
||||
func TestSelectPolicy_TiebreakByLargerGroupPool(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
// Policy A: Users group, smaller group pool, looser per-user cap.
|
||||
policyA := &types.Policy{
|
||||
ID: "pol-Users",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-Users"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{
|
||||
Enabled: true, GroupCap: 50, UserCap: 100, WindowSeconds: 86_400,
|
||||
},
|
||||
},
|
||||
// Older — would win the legacy created_at tiebreak.
|
||||
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
// Policy B: Admins group, bigger group pool, tighter per-user cap.
|
||||
policyB := &types.Policy{
|
||||
ID: "pol-Admins",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-Admins"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{
|
||||
Enabled: true, GroupCap: 100, UserCap: 20, WindowSeconds: 86_400,
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policyA, policyB}, nil)
|
||||
// Fresh state: every cap evaluation reads zero usage.
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
UserID: "user-1",
|
||||
GroupIDs: []string{"grp-Users", "grp-Admins"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-Admins", res.SelectedPolicyID,
|
||||
"the bigger group pool wins the fresh-state tiebreak — picking Users first would burn the shared user counter past Admins's tight user cap on the very first request and strand the bigger Admins pool")
|
||||
assert.Equal(t, "grp-Admins", res.AttributionGroupID)
|
||||
}
|
||||
|
||||
// TestSelectPolicy_TiebreakByCreatedAt proves the deterministic
|
||||
// final tiebreak: when two applicable policies have the same
|
||||
// headroom fraction AND the same group cap (so the larger-pool rule
|
||||
// can't differentiate either), the older policy wins so attribution
|
||||
// is stable across replays.
|
||||
func TestSelectPolicy_TiebreakByCreatedAt(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
older := capPolicy("pol-old", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400)
|
||||
older.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
newer := capPolicy("pol-new", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400)
|
||||
newer.CreatedAt = time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{newer, older}, nil)
|
||||
// Both at zero consumption → identical headroom fraction.
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-old", res.SelectedPolicyID,
|
||||
"older policy wins on equal-headroom tiebreak so attribution is stable across replays")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_DeniesWhenAllExhausted proves the deny envelope:
|
||||
// when every applicable policy has at least one cap fully exhausted,
|
||||
// the selector returns Allow=false with the most-recent exhaustion's
|
||||
// deny code + human reason. The proxy's middleware surfaces this as
|
||||
// a 403 with the canonical llm_policy.* code.
|
||||
func TestSelectPolicy_DeniesWhenAllExhausted(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
a := capPolicy("pol-a", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400)
|
||||
b := capPolicy("pol-b", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{a, b}, nil)
|
||||
|
||||
// Shared group counter at 200: A (cap 100) and B (cap 200) both exhausted.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 200},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "every applicable policy exhausted = deny")
|
||||
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode)
|
||||
assert.Contains(t, res.DenyReason, "token cap exhausted",
|
||||
"deny reason must name the exhausted cap kind for operator debugging")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_UncappedPolicyAlwaysWinsAgainstCapped proves the
|
||||
// catch-all-allow contract: a policy with NO enabled caps wins
|
||||
// against any capped policy regardless of how much headroom the
|
||||
// capped one has, because operators who configure unlimited access
|
||||
// expect requests to attribute there until they explicitly add caps.
|
||||
func TestSelectPolicy_UncappedPolicyAlwaysWinsAgainstCapped(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
uncapped := &types.Policy{
|
||||
ID: "pol-uncapped",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-engineers"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
// All Limits.*.Enabled = false (zero-value).
|
||||
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
wide := capPolicy("pol-wide", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000_000, 86_400)
|
||||
wide.CreatedAt = time.Date(2025, 12, 1, 0, 0, 0, 0, time.UTC) // older than uncapped
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{uncapped, wide}, nil)
|
||||
// Only the wide policy reads consumption; uncapped doesn't query
|
||||
// because it has no enabled caps.
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-uncapped", res.SelectedPolicyID,
|
||||
"a no-caps policy must always win selection — that's how operators express 'unlimited access through this path'")
|
||||
assert.Equal(t, int64(0), res.WindowSeconds, "no caps configured = WindowSeconds=0 so RecordLLMUsage skips counter writes")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_DisabledPolicyIgnored proves disabled policies
|
||||
// don't count toward selection — even when they'd otherwise be the
|
||||
// best match. Operators disable a policy to take it offline; the
|
||||
// selector must respect that and route through whatever's left.
|
||||
func TestSelectPolicy_DisabledPolicyIgnored(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
disabled := capPolicy("pol-disabled", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000_000, 86_400)
|
||||
disabled.Enabled = false
|
||||
enabled := capPolicy("pol-enabled", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{disabled, enabled}, nil)
|
||||
expectConsumptionBatch(mockStore, nil)
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pol-enabled", res.SelectedPolicyID,
|
||||
"disabled policies must be ignored at selection time")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_StoreErrorPropagates locks the no-fail-open
|
||||
// contract: a transient store error must surface to the caller, not
|
||||
// be silently treated as "no policies = allow". A false allow on the
|
||||
// hot path would let a request slip past every cap.
|
||||
func TestSelectPolicy_StoreErrorPropagates(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return(nil, errors.New("boom"))
|
||||
|
||||
_, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
})
|
||||
require.Error(t, err, "store errors must surface — never fail open on the hot path")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_RejectsEmptyAccount is the input-validation guard:
|
||||
// empty account_id is a programmer error and must surface as
|
||||
// InvalidArgument, not as a silent zero-result lookup.
|
||||
func TestSelectPolicy_RejectsEmptyAccount(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, _ := newSelectorMgr(t, ctrl)
|
||||
|
||||
_, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{})
|
||||
require.Error(t, err)
|
||||
var sErr *nbstatus.Error
|
||||
require.True(t, errors.As(err, &sErr))
|
||||
assert.Equal(t, nbstatus.InvalidArgument, sErr.Type())
|
||||
}
|
||||
|
||||
// TestSelectPolicy_SharesGroupCounterAcrossPolicies locks the
|
||||
// counter-keying design fork: counters are keyed on (account,
|
||||
// dim_kind, dim_id, window_hours, window_start) — NOT on policy_id.
|
||||
// Two policies that target the same group with the SAME window length
|
||||
// share one bucket: spend booked under policy A is visible to policy
|
||||
// B's headroom calculation and counts toward B's cap.
|
||||
//
|
||||
// This is what makes "operator's per-group enforcement" sane — caps
|
||||
// describe how much a GROUP can use, not how much each policy owes.
|
||||
func TestSelectPolicy_SharesGroupCounterAcrossPolicies(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
// Two policies, both targeting grp-engineers + prov-1, same 24h
|
||||
// window length. Different cap sizes.
|
||||
policyA := capPolicy("pol-A", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400)
|
||||
policyB := capPolicy("pol-B", "acc-1", []string{"grp-engineers"}, "prov-1", 5_000, 86_400)
|
||||
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policyA, policyB}, nil)
|
||||
// Both policies query the SAME consumption row — same dim_id,
|
||||
// same window_hours, same window_start. The mock returns the
|
||||
// same row for both calls, simulating the shared counter.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 800},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// 800 used → policy A has 200 tokens left of 1000 (20% headroom);
|
||||
// policy B has 4200 left of 5000 (84% headroom). B wins.
|
||||
assert.Equal(t, "pol-B", res.SelectedPolicyID,
|
||||
"the SAME 800 tokens count toward both policies — counters share the (group, window) key, caps differ per policy")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_AntiFallThroughOnLowestGroup locks the no-fall-
|
||||
// through behaviour: when a caller is in multiple of a policy's
|
||||
// source_groups and the lowest-by-sort group is exhausted, we DENY
|
||||
// rather than fall through to a less-loaded sibling. Per-group caps
|
||||
// are independent (each group has its own bucket), but attribution
|
||||
// is one-shot — operators wanting fall-through must split into
|
||||
// separate policies.
|
||||
//
|
||||
// This nails down semantics future contributors might "improve" into
|
||||
// fall-through behaviour by accident.
|
||||
func TestSelectPolicy_AntiFallThroughOnLowestGroup(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
// Policy targets two groups; caller is in both.
|
||||
policy := capPolicy("pol-1", "acc-1", []string{"grp-aaa", "grp-bbb"}, "prov-1", 100, 86_400)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policy}, nil)
|
||||
|
||||
// grp-aaa is the lowest by sort → attribution picks it, and the
|
||||
// prefetch only collects the attribution group's key. We exhaust
|
||||
// grp-aaa (100/100); grp-bbb's counter is never requested because the
|
||||
// selector attributes one-shot to the lowest group, so it can't fall
|
||||
// through to a less-loaded sibling.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-aaa", 86_400}: {TokensInput: 100},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-aaa", "grp-bbb"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow,
|
||||
"lowest-group-by-sort attribution does NOT fall through to a less-loaded sibling — operators wanting fall-through must split into separate policies")
|
||||
assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode)
|
||||
assert.Contains(t, res.DenyReason, "pol-1",
|
||||
"deny reason names the exhausted policy id so operators can grep it from the access log")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_BudgetOnlyExhaustionDenies covers the symmetric
|
||||
// path to TestSelectPolicy_DeniesWhenAllExhausted but for the budget
|
||||
// cap: a policy with token_limit DISABLED and budget_limit at-cap
|
||||
// must deny with llm_policy.budget_cap_exceeded (not the token code).
|
||||
//
|
||||
// Without this, the budget evaluation path in evalBudgetCap could
|
||||
// silently regress and we'd still pass DeniesWhenAllExhausted (which
|
||||
// only exercises tokens).
|
||||
func TestSelectPolicy_BudgetOnlyExhaustionDenies(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policy := &types.Policy{
|
||||
ID: "pol-budget",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-engineers"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{Enabled: false},
|
||||
BudgetLimit: types.PolicyBudgetLimit{
|
||||
Enabled: true,
|
||||
GroupCapUsd: 10.00,
|
||||
WindowSeconds: 86_400,
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policy}, nil)
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {CostUSD: 10.50}, // over the $10 cap
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow, "budget cap exhausted must deny independently of any token cap state")
|
||||
assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode,
|
||||
"deny code must be the budget code — token-only deny would silently regress the budget evaluation path")
|
||||
assert.Contains(t, res.DenyReason, "budget", "deny reason names the budget cap kind for operator debugging")
|
||||
}
|
||||
|
||||
// TestSelectPolicy_BudgetTighterThanTokenWins is the dual-cap headroom
|
||||
// fork: when both Token and Budget are enabled on the same policy,
|
||||
// the SMALLER remaining ratio gates the policy. A policy with
|
||||
// abundant token headroom but near-zero budget headroom must deny on
|
||||
// budget, not pass on tokens.
|
||||
func TestSelectPolicy_BudgetTighterThanTokenWins(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr, mockStore := newSelectorMgr(t, ctrl)
|
||||
|
||||
policy := &types.Policy{
|
||||
ID: "pol-dual",
|
||||
AccountID: "acc-1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{"grp-engineers"},
|
||||
DestinationProviderIDs: []string{"prov-1"},
|
||||
Limits: types.PolicyLimits{
|
||||
TokenLimit: types.PolicyTokenLimit{Enabled: true, GroupCap: 10_000_000, WindowSeconds: 86_400},
|
||||
BudgetLimit: types.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 1.00, WindowSeconds: 86_400},
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1").
|
||||
Return([]*types.Policy{policy}, nil)
|
||||
// One shared counter carries both token usage (ample headroom) and cost
|
||||
// (at the $1 budget cap); the tighter budget cap gates the policy.
|
||||
expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{
|
||||
{types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 100, CostUSD: 1.00},
|
||||
})
|
||||
|
||||
res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{
|
||||
AccountID: "acc-1",
|
||||
GroupIDs: []string{"grp-engineers"},
|
||||
ProviderID: "prov-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.Allow,
|
||||
"the tighter of (token, budget) wins — abundant token headroom must NOT mask an exhausted budget")
|
||||
assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode)
|
||||
}
|
||||
131
management/internals/modules/agentnetwork/reconcile.go
Normal file
131
management/internals/modules/agentnetwork/reconcile.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// reconcile recomputes the synthesised reverse-proxy services for an
|
||||
// account, diffs them against the previously-synthesised set in the
|
||||
// in-memory cache, and emits Create / Update / Delete proxy mappings
|
||||
// to the affected clusters. Also triggers a peer-side network-map
|
||||
// recompute via accountManager.UpdateAccountPeers so the
|
||||
// private-service ACL injection picks up the new state immediately.
|
||||
//
|
||||
// Reconcile failures are logged and swallowed — the underlying CRUD
|
||||
// has already completed, and the next mutation (or proxy reconnect)
|
||||
// will re-converge the cluster's view.
|
||||
func (m *managerImpl) reconcile(ctx context.Context, accountID string) {
|
||||
if accountID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if m.accountManager != nil {
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{
|
||||
Resource: types.UpdateResourceService,
|
||||
Operation: types.UpdateOperationUpdate,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
if m.proxyController == nil {
|
||||
return
|
||||
}
|
||||
|
||||
services, err := SynthesizeServices(ctx, m.store, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).WithError(err).Warnf("agent-network reconcile: synthesise services for account %s", accountID)
|
||||
return
|
||||
}
|
||||
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
current := make(map[string]*proto.ProxyMapping, len(services))
|
||||
for _, svc := range services {
|
||||
if svc == nil || svc.ID == "" {
|
||||
continue
|
||||
}
|
||||
current[svc.ID] = svc.ToProtoMapping(rpservice.Update, "", oidcCfg)
|
||||
}
|
||||
|
||||
m.reconcileMu.Lock()
|
||||
previous := m.reconcileCache[accountID]
|
||||
if previous == nil {
|
||||
previous = make(map[string]*proto.ProxyMapping)
|
||||
}
|
||||
|
||||
creates, updates, deletes := diffMappings(previous, current)
|
||||
if len(current) == 0 {
|
||||
delete(m.reconcileCache, accountID)
|
||||
} else {
|
||||
m.reconcileCache[accountID] = current
|
||||
}
|
||||
m.reconcileMu.Unlock()
|
||||
|
||||
for _, mapping := range creates {
|
||||
mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping))
|
||||
}
|
||||
for _, mapping := range updates {
|
||||
mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping))
|
||||
}
|
||||
for _, mapping := range deletes {
|
||||
mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping))
|
||||
}
|
||||
}
|
||||
|
||||
// diffMappings classifies the previous→current transition for a
|
||||
// single account into Create / Update / Delete sets.
|
||||
//
|
||||
// Cluster moves (current.cluster != previous.cluster) are surfaced as
|
||||
// a Delete on the old cluster + Create on the new — handled by
|
||||
// emitting both a delete (on previous mapping) and a create (on the
|
||||
// current mapping) for that service ID.
|
||||
func diffMappings(previous, current map[string]*proto.ProxyMapping) (creates, updates, deletes []*proto.ProxyMapping) {
|
||||
for id, cur := range current {
|
||||
prev, existed := previous[id]
|
||||
switch {
|
||||
case !existed:
|
||||
creates = append(creates, cur)
|
||||
case prev.GetDomain() == "" || cur.GetAccountId() == prev.GetAccountId() && currentClusterChanged(prev, cur):
|
||||
deletes = append(deletes, prev)
|
||||
creates = append(creates, cur)
|
||||
default:
|
||||
updates = append(updates, cur)
|
||||
}
|
||||
}
|
||||
for id, prev := range previous {
|
||||
if _, stillThere := current[id]; !stillThere {
|
||||
deletes = append(deletes, prev)
|
||||
}
|
||||
}
|
||||
return creates, updates, deletes
|
||||
}
|
||||
|
||||
func currentClusterChanged(prev, cur *proto.ProxyMapping) bool {
|
||||
return clusterFromMapping(prev) != clusterFromMapping(cur)
|
||||
}
|
||||
|
||||
// clusterFromMapping returns the cluster the mapping should be sent
|
||||
// to. ProxyMapping doesn't carry the cluster directly, so we rely on
|
||||
// the synthesised service's domain (`<slug>.<cluster>`) and split on
|
||||
// the first '.'.
|
||||
func clusterFromMapping(m *proto.ProxyMapping) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
domain := m.GetDomain()
|
||||
for i := 0; i < len(domain); i++ {
|
||||
if domain[i] == '.' {
|
||||
return domain[i+1:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
232
management/internals/modules/agentnetwork/reconcile_test.go
Normal file
232
management/internals/modules/agentnetwork/reconcile_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func newReconcileMgr(t *testing.T, ctrl *gomock.Controller) (*managerImpl, *store.MockStore, *proxy.MockController) {
|
||||
t.Helper()
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockProxy := proxy.NewMockController(ctrl)
|
||||
return &managerImpl{
|
||||
store: mockStore,
|
||||
proxyController: mockProxy,
|
||||
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
|
||||
}, mockStore, mockProxy
|
||||
}
|
||||
|
||||
func newReconcileTestProvider() *types.Provider {
|
||||
return &types.Provider{
|
||||
ID: "prov-1",
|
||||
AccountID: "acct-1",
|
||||
ProviderID: "openai_api",
|
||||
Name: "OpenAI",
|
||||
UpstreamURL: "https://api.openai.com",
|
||||
APIKey: "sk-test-key",
|
||||
Enabled: true,
|
||||
SessionPrivateKey: "test-priv-key",
|
||||
SessionPublicKey: "test-pub-key",
|
||||
}
|
||||
}
|
||||
|
||||
func newReconcileTestPolicy(providerID, sourceGroupID string) *types.Policy {
|
||||
return &types.Policy{
|
||||
ID: "pol-1",
|
||||
AccountID: "acct-1",
|
||||
Name: "engineers",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{sourceGroupID},
|
||||
DestinationProviderIDs: []string{providerID},
|
||||
}
|
||||
}
|
||||
|
||||
func newReconcileTestSettings() *types.Settings {
|
||||
return &types.Settings{
|
||||
AccountID: "acct-1",
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
Subdomain: "violet",
|
||||
}
|
||||
}
|
||||
|
||||
func expectReconcileSynthInputs(mockStore *store.MockStore, ctx context.Context, providers []*types.Provider, policies []*types.Policy, guardrails []*types.Guardrail) {
|
||||
mockStore.EXPECT().
|
||||
GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(newReconcileTestSettings(), nil)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(providers, nil)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(policies, nil)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(guardrails, nil)
|
||||
}
|
||||
|
||||
func TestReconcile_FirstSynth_EmitsCreate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl)
|
||||
provider := newReconcileTestProvider()
|
||||
policy := newReconcileTestPolicy(provider.ID, "grp-eng")
|
||||
|
||||
expectReconcileSynthInputs(mockStore, ctx, []*types.Provider{provider}, []*types.Policy{policy}, []*types.Guardrail{})
|
||||
mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{})
|
||||
|
||||
var sentMappings []*proto.ProxyMapping
|
||||
mockProxy.EXPECT().
|
||||
SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), "eu.proxy.netbird.io").
|
||||
Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) {
|
||||
sentMappings = append(sentMappings, m)
|
||||
})
|
||||
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
|
||||
require.Len(t, sentMappings, 1, "first synth must emit one mapping")
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, sentMappings[0].Type, "first synth is a Create")
|
||||
assert.Equal(t, "agent-net-svc-acct-1", sentMappings[0].Id, "stable account-scoped virtual service id")
|
||||
assert.Equal(t, "violet.eu.proxy.netbird.io", sentMappings[0].Domain, "domain comes from settings (subdomain.cluster)")
|
||||
|
||||
mgr.reconcileMu.Lock()
|
||||
cached := mgr.reconcileCache["acct-1"]
|
||||
mgr.reconcileMu.Unlock()
|
||||
require.Len(t, cached, 1, "cache must hold the synth result for next diff")
|
||||
}
|
||||
|
||||
func TestReconcile_NoChange_EmitsNothingExtra(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl)
|
||||
provider := newReconcileTestProvider()
|
||||
policy := newReconcileTestPolicy(provider.ID, "grp-eng")
|
||||
|
||||
// Two identical synth runs.
|
||||
mockStore.EXPECT().
|
||||
GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return(newReconcileTestSettings(), nil).Times(2)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return([]*types.Provider{provider}, nil).Times(2)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return([]*types.Policy{policy}, nil).Times(2)
|
||||
mockStore.EXPECT().
|
||||
GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1").
|
||||
Return([]*types.Guardrail{}, nil).Times(2)
|
||||
mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).Times(2)
|
||||
|
||||
createCalls := 0
|
||||
updateCalls := 0
|
||||
mockProxy.EXPECT().
|
||||
SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), gomock.Any()).
|
||||
Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) {
|
||||
switch m.Type {
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
|
||||
createCalls++
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
|
||||
updateCalls++
|
||||
}
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
|
||||
assert.Equal(t, 1, createCalls, "first reconcile creates")
|
||||
assert.Equal(t, 1, updateCalls, "second reconcile re-pushes as Modified (no semantic change but mapping fields refresh)")
|
||||
}
|
||||
|
||||
func TestReconcile_PolicyRemoved_EmitsDelete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl)
|
||||
provider := newReconcileTestProvider()
|
||||
policy := newReconcileTestPolicy(provider.ID, "grp-eng")
|
||||
|
||||
gomock.InOrder(
|
||||
// First reconcile: provider + policy, synthesised.
|
||||
mockStore.EXPECT().GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").Return(newReconcileTestSettings(), nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Provider{provider}, nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Policy{policy}, nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Guardrail{}, nil),
|
||||
// Second reconcile: policy gone, provider stays but no longer referenced.
|
||||
mockStore.EXPECT().GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").Return(newReconcileTestSettings(), nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Provider{provider}, nil),
|
||||
mockStore.EXPECT().GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Policy{}, nil),
|
||||
)
|
||||
mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()
|
||||
|
||||
var seenTypes []proto.ProxyMappingUpdateType
|
||||
mockProxy.EXPECT().
|
||||
SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), "eu.proxy.netbird.io").
|
||||
Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) {
|
||||
seenTypes = append(seenTypes, m.Type)
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
|
||||
require.Len(t, seenTypes, 2, "create then delete")
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, seenTypes[0])
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, seenTypes[1])
|
||||
|
||||
mgr.reconcileMu.Lock()
|
||||
_, present := mgr.reconcileCache["acct-1"]
|
||||
mgr.reconcileMu.Unlock()
|
||||
assert.False(t, present, "cache for the account must be cleared once nothing is synthesised")
|
||||
}
|
||||
|
||||
func TestReconcile_NilProxyController_NoOp(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mgr := &managerImpl{
|
||||
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
|
||||
}
|
||||
// Must not panic; must not query the store.
|
||||
mgr.reconcile(ctx, "acct-1")
|
||||
}
|
||||
|
||||
func TestReconcile_EmptyAccountID_NoOp(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mgr, _, _ := newReconcileMgr(t, ctrl)
|
||||
// Empty accountID short-circuits before any store call.
|
||||
mgr.reconcile(ctx, "")
|
||||
}
|
||||
|
||||
func TestClusterFromMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
want string
|
||||
}{
|
||||
{"simple", "openai.eu.proxy.netbird.io", "eu.proxy.netbird.io"},
|
||||
{"deeply nested", "a.b.c.d", "b.c.d"},
|
||||
{"no dot", "openai", ""},
|
||||
{"empty", "", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := clusterFromMapping(&proto.ProxyMapping{Domain: tt.domain})
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
1059
management/internals/modules/agentnetwork/synthesizer.go
Normal file
1059
management/internals/modules/agentnetwork/synthesizer.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,178 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// decodeServiceGuardrailConfig pulls the llm_guardrail middleware config off the
|
||||
// synthesised service's single target.
|
||||
func decodeServiceGuardrailConfig(t *testing.T, svc *rpservice.Service) guardrailConfig {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, svc.Targets, "synth service must carry a target")
|
||||
for _, mw := range svc.Targets[0].Options.Middlewares {
|
||||
if mw.ID == middlewareIDLLMGuardrail {
|
||||
var cfg guardrailConfig
|
||||
require.NoError(t, json.Unmarshal(mw.ConfigJSON, &cfg), "guardrail config must decode")
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
t.Fatal("llm_guardrail middleware not present on synthesised service")
|
||||
return guardrailConfig{}
|
||||
}
|
||||
|
||||
// decodeMiddlewareRawConfig returns the raw ConfigJSON bytes for the named
|
||||
// middleware on the synth service's target, or fails the test.
|
||||
func decodeMiddlewareRawConfig(t *testing.T, svc *rpservice.Service, id string) []byte {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, svc.Targets, "synth service must carry a target")
|
||||
for _, mw := range svc.Targets[0].Options.Middlewares {
|
||||
if mw.ID == id {
|
||||
return mw.ConfigJSON
|
||||
}
|
||||
}
|
||||
t.Fatalf("middleware %q not present on synthesised service", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveGuardrailAndPolicy persists a guardrail with prompt capture + redact + a
|
||||
// model allowlist, referenced by one enabled policy. Shared by the GC-3 tests.
|
||||
func saveGuardrailAndPolicy(t *testing.T, ctx context.Context, s store.Store, provider *types.Provider) {
|
||||
t.Helper()
|
||||
guardrail := &types.Guardrail{
|
||||
ID: "ainguard-1",
|
||||
AccountID: testAccountID,
|
||||
Name: "strict",
|
||||
Checks: types.GuardrailChecks{
|
||||
ModelAllowlist: types.GuardrailModelAllowlist{Enabled: true, Models: []string{"gpt-5.4"}},
|
||||
PromptCapture: types.GuardrailPromptCapture{Enabled: true, RedactPii: true},
|
||||
},
|
||||
}
|
||||
require.NoError(t, s.SaveAgentNetworkGuardrail(ctx, guardrail))
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", guardrail.ID)))
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_PromptCaptureAccountIsSoleControl is the
|
||||
// GC-3 contract: the account master switch (EnablePromptCollection) is the
|
||||
// SOLE control for capture enablement. Policy-level guardrail prompt_capture is
|
||||
// ignored for enablement — operators don't need to attach a capture guardrail
|
||||
// to a policy just to turn capture on for the account. Off by default.
|
||||
func TestSynthesizeServices_RealStore_PromptCaptureAccountIsSoleControl(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
// Account collection master switch OFF (default).
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
saveGuardrailAndPolicy(t, ctx, s, newSynthTestProvider())
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
cfg := decodeServiceGuardrailConfig(t, services[0])
|
||||
assert.Equal(t, []string{"gpt-5.4"}, cfg.ModelAllowlist,
|
||||
"model allowlist is a pure policy guardrail and must always reach the config")
|
||||
assert.False(t, cfg.PromptCapture.Enabled,
|
||||
"prompt capture must be off when the account toggle is off, even with a capture-enabled guardrail")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_PromptCaptureFlowsWhenAccountOptsIn proves
|
||||
// the account toggle is sufficient on its own — even with NO guardrail
|
||||
// attached to the policy, capture fires when the account opts in. Redact is
|
||||
// the OR of account + guardrail.
|
||||
func TestSynthesizeServices_RealStore_PromptCaptureFlowsWhenAccountOptsIn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.EnablePromptCollection = true
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
|
||||
// Save a provider and a policy with NO guardrails attached — proves the
|
||||
// account toggle is sufficient on its own.
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
cfg := decodeServiceGuardrailConfig(t, services[0])
|
||||
assert.True(t, cfg.PromptCapture.Enabled,
|
||||
"account toggle alone must enable capture; no guardrail attachment required")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_AccountRedactWithoutGuardrailRedact proves
|
||||
// the redact OR-merge from the account side: account RedactPii on, guardrail
|
||||
// redact off, capture on at both levels.
|
||||
func TestSynthesizeServices_RealStore_AccountRedactWithoutGuardrailRedact(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.EnablePromptCollection = true
|
||||
settings.RedactPii = true
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
guardrail := &types.Guardrail{
|
||||
ID: "ainguard-noredact",
|
||||
AccountID: testAccountID,
|
||||
Name: "capture-only",
|
||||
Checks: types.GuardrailChecks{
|
||||
PromptCapture: types.GuardrailPromptCapture{Enabled: true, RedactPii: false},
|
||||
},
|
||||
}
|
||||
require.NoError(t, s.SaveAgentNetworkGuardrail(ctx, guardrail))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", guardrail.ID)))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
cfg := decodeServiceGuardrailConfig(t, services[0])
|
||||
assert.True(t, cfg.PromptCapture.Enabled, "capture on (account + guardrail)")
|
||||
assert.True(t, cfg.PromptCapture.RedactPii, "account RedactPii must apply even when the guardrail leaves it off (OR)")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_NoGuardrail_CaptureOff pins the default:
|
||||
// with no guardrail referenced, the synth service's guardrail config has prompt
|
||||
// capture disabled and an empty allowlist. This is the "off by default" baseline
|
||||
// the account switch must preserve.
|
||||
func TestSynthesizeServices_RealStore_NoGuardrail_CaptureOff(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1, "exactly one synth service expected")
|
||||
|
||||
cfg := decodeServiceGuardrailConfig(t, services[0])
|
||||
assert.Empty(t, cfg.ModelAllowlist, "no guardrail → no allowlist")
|
||||
assert.False(t, cfg.PromptCapture.Enabled, "no guardrail → prompt capture off by default")
|
||||
assert.False(t, cfg.PromptCapture.RedactPii, "no guardrail → redact off by default")
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// TestSynthesizeServices_RealStore_LogCollectionOff_SuppressesAccessLog drives the
|
||||
// happy default: account settings ship with EnableLogCollection=false, so the
|
||||
// synthesised target opts out of access-log emission (DisableAccessLog=true) and
|
||||
// the proto mapping the proxy receives reflects that.
|
||||
func TestSynthesizeServices_RealStore_LogCollectionOff_SuppressesAccessLog(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1, "exactly one synth service expected")
|
||||
require.NotEmpty(t, services[0].Targets, "synth service must carry a target")
|
||||
assert.True(t, services[0].Targets[0].Options.DisableAccessLog,
|
||||
"EnableLogCollection=false (default) must produce DisableAccessLog=true on the synth target")
|
||||
|
||||
mapping := services[0].ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{})
|
||||
require.NotEmpty(t, mapping.GetPath(), "proto mapping must carry a path")
|
||||
assert.True(t, mapping.GetPath()[0].GetOptions().GetDisableAccessLog(),
|
||||
"proto mapping must propagate DisableAccessLog=true so the proxy suppresses access-log emission")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_LogCollectionOn_PermitsAccessLog asserts the
|
||||
// inverse: once the account opts in, the synth target leaves DisableAccessLog
|
||||
// at its default false and the proto wire stays unset.
|
||||
func TestSynthesizeServices_RealStore_LogCollectionOn_PermitsAccessLog(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.EnableLogCollection = true
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1, "exactly one synth service expected")
|
||||
require.NotEmpty(t, services[0].Targets, "synth service must carry a target")
|
||||
assert.False(t, services[0].Targets[0].Options.DisableAccessLog,
|
||||
"EnableLogCollection=true must leave DisableAccessLog=false on the synth target")
|
||||
|
||||
mapping := services[0].ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{})
|
||||
require.NotEmpty(t, mapping.GetPath(), "proto mapping must carry a path")
|
||||
assert.False(t, mapping.GetPath()[0].GetOptions().GetDisableAccessLog(),
|
||||
"proto mapping must propagate DisableAccessLog=false so access-log emission stays on")
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// parserRedactConfig mirrors the on-wire shape of the redact + capture knobs
|
||||
// that both llm_request_parser and llm_response_parser unmarshal. We don't
|
||||
// import the proxy-side packages from a management test (cross-module), so we
|
||||
// decode the JSON directly and assert on the fields that are part of the
|
||||
// synth contract.
|
||||
type parserRedactConfig struct {
|
||||
RedactPii bool `json:"redact_pii,omitempty"`
|
||||
CapturePrompt *bool `json:"capture_prompt,omitempty"` // present only on the request parser
|
||||
CaptureCompletion *bool `json:"capture_completion,omitempty"` // present only on the response parser
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_ParserConfigsCarryRedactPii is the
|
||||
// management-side contract test for the request/response parser redaction
|
||||
// wiring. When settings.RedactPii is true, the synthesised middleware chain
|
||||
// MUST stamp redact_pii=true on both llm_request_parser and llm_response_parser
|
||||
// configs — otherwise the parsers ship raw prompts / completions to the
|
||||
// access log even though the account has opted in. This is exactly the live
|
||||
// leak path that motivated the parser-side redaction in the first place.
|
||||
func TestSynthesizeServices_RealStore_ParserConfigsCarryRedactPii(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.RedactPii = true
|
||||
settings.EnablePromptCollection = true
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1, "exactly one synth service expected")
|
||||
|
||||
for _, parserID := range []string{middlewareIDLLMRequestParser, middlewareIDLLMResponseParser} {
|
||||
raw := decodeMiddlewareRawConfig(t, services[0], parserID)
|
||||
var cfg parserRedactConfig
|
||||
require.NoError(t, json.Unmarshal(raw, &cfg), "%s config must be valid JSON", parserID)
|
||||
assert.True(t, cfg.RedactPii, "%s config must carry redact_pii=true when settings.RedactPii is on (otherwise the parser ships raw prompts/completions to the access log)", parserID)
|
||||
}
|
||||
// The capture flag is set explicitly to enable_prompt_collection on each
|
||||
// parser. With it on here, both must allow emission.
|
||||
reqCfg := decodeParserConfig(t, services[0], middlewareIDLLMRequestParser)
|
||||
require.NotNil(t, reqCfg.CapturePrompt, "request parser must carry an explicit capture_prompt")
|
||||
assert.True(t, *reqCfg.CapturePrompt, "capture_prompt=true when EnablePromptCollection=true")
|
||||
respCfg := decodeParserConfig(t, services[0], middlewareIDLLMResponseParser)
|
||||
require.NotNil(t, respCfg.CaptureCompletion, "response parser must carry an explicit capture_completion")
|
||||
assert.True(t, *respCfg.CaptureCompletion, "capture_completion=true when EnablePromptCollection=true")
|
||||
}
|
||||
|
||||
// decodeParserConfig is a small helper around decodeMiddlewareRawConfig that
|
||||
// also unmarshals into parserRedactConfig.
|
||||
func decodeParserConfig(t *testing.T, svc *rpservice.Service, parserID string) parserRedactConfig {
|
||||
t.Helper()
|
||||
raw := decodeMiddlewareRawConfig(t, svc, parserID)
|
||||
var cfg parserRedactConfig
|
||||
require.NoError(t, json.Unmarshal(raw, &cfg), "%s config must be valid JSON", parserID)
|
||||
return cfg
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_ParserConfigsSuppressCaptureWhenLogCollectionOnly
|
||||
// is the contract test for the bug: enable_log_collection=true with
|
||||
// enable_prompt_collection=false MUST result in capture_prompt=false on the
|
||||
// request parser AND capture_completion=false on the response parser, so the
|
||||
// access-log row stays metadata-only (provider, model, tokens, cost) and
|
||||
// carries NO prompt input nor response output. Without this, operators who
|
||||
// want billing-style logs end up with raw user prompts and model outputs in
|
||||
// every access-log entry.
|
||||
func TestSynthesizeServices_RealStore_ParserConfigsSuppressCaptureWhenLogCollectionOnly(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
settings := newSynthTestSettings()
|
||||
settings.EnableLogCollection = true // operator wants logs ON
|
||||
settings.EnablePromptCollection = false // but NOT content capture
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings))
|
||||
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
reqCfg := decodeParserConfig(t, services[0], middlewareIDLLMRequestParser)
|
||||
require.NotNil(t, reqCfg.CapturePrompt, "request parser must carry an explicit capture_prompt gate")
|
||||
assert.False(t, *reqCfg.CapturePrompt, "capture_prompt MUST be false when EnablePromptCollection is off — otherwise llm.request_prompt_raw leaks user input into the access log")
|
||||
|
||||
respCfg := decodeParserConfig(t, services[0], middlewareIDLLMResponseParser)
|
||||
require.NotNil(t, respCfg.CaptureCompletion, "response parser must carry an explicit capture_completion gate")
|
||||
assert.False(t, *respCfg.CaptureCompletion, "capture_completion MUST be false when EnablePromptCollection is off — otherwise llm.response_completion leaks model output into the access log")
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_ParserConfigsOmitRedactPiiWhenOff proves
|
||||
// the inverse: with the account toggle off, the parser configs stay clean (no
|
||||
// redact_pii field, which the parsers treat as zero / no redaction). This is
|
||||
// the operator-opt-out path — the access log keeps raw prompts/completions
|
||||
// for debugging until the operator opts in.
|
||||
func TestSynthesizeServices_RealStore_ParserConfigsOmitRedactPiiWhenOff(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
|
||||
// Default settings: RedactPii = false.
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
for _, parserID := range []string{middlewareIDLLMRequestParser, middlewareIDLLMResponseParser} {
|
||||
raw := decodeMiddlewareRawConfig(t, services[0], parserID)
|
||||
// Inspect the decoded JSON directly: a struct decode would also pass
|
||||
// if redact_pii were present-but-false. The contract is that the key
|
||||
// is omitted entirely while the account toggle is off.
|
||||
var rawCfg map[string]json.RawMessage
|
||||
require.NoError(t, json.Unmarshal(raw, &rawCfg), "%s config must be valid JSON", parserID)
|
||||
assert.NotContains(t, rawCfg, "redact_pii",
|
||||
"%s config must omit redact_pii entirely while the account toggle is off", parserID)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// decodeServiceRouterConfig finds the llm_router middleware on the synthesised
|
||||
// service's single target and decodes its config — the model→provider routing
|
||||
// table the proxy authorises against.
|
||||
func decodeServiceRouterConfig(t *testing.T, svc *rpservice.Service) routerConfig {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, svc.Targets, "synth service must carry a target")
|
||||
for _, mw := range svc.Targets[0].Options.Middlewares {
|
||||
if mw.ID == middlewareIDLLMRouter {
|
||||
var cfg routerConfig
|
||||
require.NoError(t, json.Unmarshal(mw.ConfigJSON, &cfg), "router config must decode")
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
t.Fatal("llm_router middleware not present on synthesised service")
|
||||
return routerConfig{}
|
||||
}
|
||||
|
||||
// decodeMappingRouterConfig is the proto-wire equivalent: it pulls the
|
||||
// llm_router config off the ProxyMapping the proxy actually receives.
|
||||
func decodeMappingRouterConfig(t *testing.T, m *proto.ProxyMapping) routerConfig {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, m.GetPath(), "mapping must carry a path")
|
||||
for _, mw := range m.GetPath()[0].GetOptions().GetMiddlewares() {
|
||||
if mw.GetId() == middlewareIDLLMRouter {
|
||||
var cfg routerConfig
|
||||
require.NoError(t, json.Unmarshal(mw.GetConfigJson(), &cfg), "wire router config must decode")
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
t.Fatal("llm_router middleware not present on proxy mapping")
|
||||
return routerConfig{}
|
||||
}
|
||||
|
||||
// TestSynthesizeServices_RealStore_SurvivesStatusToggle drives synthesis through
|
||||
// a REAL sqlite store (Save → gorm/JSON serialize → reload → decrypt) instead of
|
||||
// a MockStore, so it exercises the field round-trip that a provider/policy edit
|
||||
// actually hits. Mock-based tests can't catch a field that dies in persistence;
|
||||
// this one can. It then performs the exact operation that reproduced the live
|
||||
// 403 — disable then re-enable the provider — and asserts the re-enabled state
|
||||
// is fully routable again.
|
||||
func TestSynthesizeServices_RealStore_SurvivesStatusToggle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
assertRoutable := func(t *testing.T, stage string) {
|
||||
services, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err, stage)
|
||||
require.Len(t, services, 1, "%s: exactly one synth service expected", stage)
|
||||
svc := services[0]
|
||||
|
||||
assert.True(t, svc.Private, "%s: synth service must be Private after store round-trip", stage)
|
||||
assert.Equal(t, []string{"grp-eng"}, svc.AccessGroups, "%s: AccessGroups must survive the round-trip", stage)
|
||||
|
||||
m := svc.ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{})
|
||||
assert.True(t, m.GetPrivate(), "%s: proto mapping Private must be true (proxy gates tunnel-peer auth on it)", stage)
|
||||
|
||||
cfg := decodeServiceRouterConfig(t, svc)
|
||||
require.Len(t, cfg.Providers, 1, "%s: the enabled+linked provider must appear in the router config", stage)
|
||||
assert.Equal(t, []string{"gpt-5.4"}, cfg.Providers[0].Models, "%s: provider models must reach the route", stage)
|
||||
assert.Equal(t, []string{"grp-eng"}, cfg.Providers[0].AllowedGroupIDs, "%s: policy source groups must reach the route", stage)
|
||||
}
|
||||
|
||||
assertRoutable(t, "initial")
|
||||
|
||||
provider.Enabled = false
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
disabled, err := SynthesizeServices(ctx, s, testAccountID)
|
||||
require.NoError(t, err, "synthesis must not error with a disabled provider")
|
||||
for _, svc := range disabled {
|
||||
assert.Empty(t, decodeServiceRouterConfig(t, svc).Providers,
|
||||
"a disabled provider must not appear in the router config (otherwise it would route while off)")
|
||||
}
|
||||
|
||||
provider.Enabled = true
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
assertRoutable(t, "after disable->enable")
|
||||
}
|
||||
|
||||
// captureController is a proxy.Controller that records the mappings reconcile
|
||||
// pushes, so the test can inspect the exact wire payload — Private flag and
|
||||
// router config included.
|
||||
type captureController struct {
|
||||
rpproxy.Controller
|
||||
pushed []*proto.ProxyMapping
|
||||
}
|
||||
|
||||
func (c *captureController) GetOIDCValidationConfig() rpproxy.OIDCValidationConfig {
|
||||
return rpproxy.OIDCValidationConfig{}
|
||||
}
|
||||
|
||||
func (c *captureController) SendServiceUpdateToCluster(_ context.Context, _ string, update *proto.ProxyMapping, _ string) {
|
||||
c.pushed = append(c.pushed, update)
|
||||
}
|
||||
|
||||
// noopAccountManager satisfies the reconcile path's accountManager dependency.
|
||||
type noopAccountManager struct {
|
||||
account.Manager
|
||||
}
|
||||
|
||||
func (noopAccountManager) UpdateAccountPeers(context.Context, string, nbtypes.UpdateReason) {}
|
||||
|
||||
// TestReconcile_RealStore_PushesPrivateAfterStatusToggle reproduces the live
|
||||
// path end-to-end below the gRPC boundary: a real store + the real
|
||||
// managerImpl.reconcile + a capturing proxy controller. It runs the operation
|
||||
// that broke in production — provider disable then re-enable — and asserts the
|
||||
// mapping reconcile pushes to the cluster after re-enable is Private=true and
|
||||
// carries the routable provider. If reconcile ever pushes private=false (the
|
||||
// symptom that left UserGroups empty → no_authorised_provider), this fails.
|
||||
func TestReconcile_RealStore_PushesPrivateAfterStatusToggle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings()))
|
||||
provider := newSynthTestProvider()
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", "")))
|
||||
|
||||
ctrl := &captureController{}
|
||||
m := &managerImpl{
|
||||
store: s,
|
||||
accountManager: noopAccountManager{},
|
||||
proxyController: ctrl,
|
||||
reconcileCache: make(map[string]map[string]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
m.reconcile(ctx, testAccountID) // initial, provider enabled
|
||||
|
||||
provider.Enabled = false
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
m.reconcile(ctx, testAccountID) // disabled
|
||||
|
||||
provider.Enabled = true
|
||||
require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider))
|
||||
m.reconcile(ctx, testAccountID) // re-enabled — the reproduction step
|
||||
|
||||
require.NotEmpty(t, ctrl.pushed, "reconcile must push at least one mapping")
|
||||
last := ctrl.pushed[len(ctrl.pushed)-1]
|
||||
|
||||
assert.Equal(t, newSynthTestSettings().Endpoint(), last.GetDomain(), "synth domain on the wire")
|
||||
assert.True(t, last.GetPrivate(),
|
||||
"reconcile-pushed mapping after re-enable MUST be Private=true; a false here is the exact bug — the proxy skips ValidateTunnelPeer, UserGroups stays empty, and llm_router denies no_authorised_provider")
|
||||
|
||||
cfg := decodeMappingRouterConfig(t, last)
|
||||
require.Len(t, cfg.Providers, 1, "re-enabled provider must be back in the pushed router config")
|
||||
assert.Equal(t, []string{"gpt-5.4"}, cfg.Providers[0].Models, "model must be routable again after re-enable")
|
||||
assert.Equal(t, []string{"grp-eng"}, cfg.Providers[0].AllowedGroupIDs, "authorised groups must be present after re-enable")
|
||||
}
|
||||
1098
management/internals/modules/agentnetwork/synthesizer_test.go
Normal file
1098
management/internals/modules/agentnetwork/synthesizer_test.go
Normal file
File diff suppressed because it is too large
Load Diff
117
management/internals/modules/agentnetwork/types/accesslog.go
Normal file
117
management/internals/modules/agentnetwork/types/accesslog.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// AgentNetworkAccessLog is the dedicated, flattened agent-network access-log
|
||||
// row. Unlike the shared reverse-proxy AccessLogEntry (which kept LLM data in
|
||||
// an opaque metadata JSON blob), the LLM dimensions live in first-class,
|
||||
// indexed columns so the access-log surface can filter server-side by
|
||||
// user / group / provider / model / decision.
|
||||
type AgentNetworkAccessLog struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
ServiceID string `gorm:"index"`
|
||||
Timestamp time.Time `gorm:"index"`
|
||||
UserID string `gorm:"index"`
|
||||
SourceIP string
|
||||
Method string
|
||||
Host string
|
||||
Path string `gorm:"type:text"`
|
||||
Duration time.Duration
|
||||
StatusCode int `gorm:"index"`
|
||||
AuthMethod string
|
||||
BytesUpload int64
|
||||
BytesDownload int64
|
||||
|
||||
// Flattened LLM dimensions (queryable). Sourced from proxy metadata keys.
|
||||
Provider string `gorm:"index"` // vendor, e.g. "openai" (llm.provider)
|
||||
Model string `gorm:"index"` // llm.model
|
||||
SessionID string `gorm:"index"` // llm.session_id — groups a conversation / coding session
|
||||
ResolvedProviderID string `gorm:"index"` // llm.resolved_provider_id
|
||||
SelectedPolicyID string `gorm:"index"` // llm.selected_policy_id
|
||||
Decision string `gorm:"index"` // llm_policy.decision (allow/deny)
|
||||
DenyReason string // llm_policy.reason (raw code, mapped in the UI)
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
TotalTokens int64
|
||||
CostUSD float64
|
||||
Stream bool
|
||||
|
||||
// Prompt capture. Only populated when prompt collection is enabled
|
||||
// (account master switch AND policy guardrail). Heavy free text.
|
||||
RequestPrompt string `gorm:"type:text"`
|
||||
ResponseCompletion string `gorm:"type:text"`
|
||||
|
||||
CreatedAt time.Time
|
||||
|
||||
// GroupIDs is the authorising group ids for this entry, hydrated from the
|
||||
// group child table on read. Not a column.
|
||||
GroupIDs []string `gorm:"-"`
|
||||
}
|
||||
|
||||
// TableName keeps agent-network access logs in their own table, separate from
|
||||
// the reverse-proxy AccessLogEntry table.
|
||||
func (AgentNetworkAccessLog) TableName() string { return "agent_network_access_log" }
|
||||
|
||||
// ToAPIResponse renders the flattened entry as the API representation.
|
||||
func (a *AgentNetworkAccessLog) ToAPIResponse() api.AgentNetworkAccessLog {
|
||||
out := api.AgentNetworkAccessLog{
|
||||
Id: a.ID,
|
||||
ServiceId: a.ServiceID,
|
||||
Timestamp: a.Timestamp,
|
||||
StatusCode: a.StatusCode,
|
||||
DurationMs: int(a.Duration.Milliseconds()),
|
||||
InputTokens: a.InputTokens,
|
||||
OutputTokens: a.OutputTokens,
|
||||
TotalTokens: a.TotalTokens,
|
||||
CostUsd: a.CostUSD,
|
||||
Stream: &a.Stream,
|
||||
}
|
||||
|
||||
out.UserId = strPtr(a.UserID)
|
||||
out.SourceIp = strPtr(a.SourceIP)
|
||||
out.Method = strPtr(a.Method)
|
||||
out.Host = strPtr(a.Host)
|
||||
out.Path = strPtr(a.Path)
|
||||
out.Provider = strPtr(a.Provider)
|
||||
out.Model = strPtr(a.Model)
|
||||
out.SessionId = strPtr(a.SessionID)
|
||||
out.ResolvedProviderId = strPtr(a.ResolvedProviderID)
|
||||
out.SelectedPolicyId = strPtr(a.SelectedPolicyID)
|
||||
out.Decision = strPtr(a.Decision)
|
||||
out.DenyReason = strPtr(a.DenyReason)
|
||||
out.RequestPrompt = strPtr(a.RequestPrompt)
|
||||
out.ResponseCompletion = strPtr(a.ResponseCompletion)
|
||||
|
||||
if len(a.GroupIDs) > 0 {
|
||||
groups := a.GroupIDs
|
||||
out.GroupIds = &groups
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// strPtr returns a pointer to s, or nil when s is empty — so empty optional
|
||||
// fields are omitted from the JSON rather than serialised as "".
|
||||
func strPtr(s string) *string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
// AgentNetworkAccessLogGroup is the normalised many-to-many row linking a log
|
||||
// entry to one authorising group, so the access-log endpoint can filter by
|
||||
// group with a simple `group_id IN (...)` join instead of substring-matching a
|
||||
// CSV column.
|
||||
type AgentNetworkAccessLogGroup struct {
|
||||
LogID string `gorm:"primaryKey"`
|
||||
GroupID string `gorm:"primaryKey;index"`
|
||||
AccountID string `gorm:"index"`
|
||||
}
|
||||
|
||||
// TableName names the access-log group child table.
|
||||
func (AgentNetworkAccessLogGroup) TableName() string { return "agent_network_access_log_group" }
|
||||
@@ -0,0 +1,213 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
// AccessLogDefaultPageSize is the default number of records per page.
|
||||
AccessLogDefaultPageSize = 50
|
||||
// AccessLogMaxPageSize is the maximum number of records allowed per page.
|
||||
AccessLogMaxPageSize = 100
|
||||
|
||||
accessLogDefaultSortBy = "timestamp"
|
||||
accessLogDefaultSortOrder = "desc"
|
||||
|
||||
// usageOverviewDefaultLookback bounds an unbounded usage-overview query so
|
||||
// it never aggregates an account's entire history into memory.
|
||||
usageOverviewDefaultLookback = 90 * 24 * time.Hour
|
||||
// usageOverviewMaxRange caps how far back an explicit range may reach.
|
||||
usageOverviewMaxRange = 366 * 24 * time.Hour
|
||||
)
|
||||
|
||||
// ApplyUsageOverviewBounds bounds a missing or over-wide date range so the
|
||||
// in-memory usage aggregation can't load an account's full usage history. An
|
||||
// absent range defaults to the last usageOverviewDefaultLookback; a range wider
|
||||
// than usageOverviewMaxRange is clamped from the (possibly defaulted) end.
|
||||
func (f *AgentNetworkAccessLogFilter) ApplyUsageOverviewBounds(now time.Time) {
|
||||
end := now
|
||||
if f.EndDate != nil {
|
||||
end = *f.EndDate
|
||||
}
|
||||
f.EndDate = &end
|
||||
if f.StartDate == nil {
|
||||
start := end.Add(-usageOverviewDefaultLookback)
|
||||
f.StartDate = &start
|
||||
return
|
||||
}
|
||||
if end.Sub(*f.StartDate) > usageOverviewMaxRange {
|
||||
start := end.Add(-usageOverviewMaxRange)
|
||||
f.StartDate = &start
|
||||
}
|
||||
}
|
||||
|
||||
// accessLogSortFields maps the API sort_by values to their database columns.
|
||||
var accessLogSortFields = map[string]string{
|
||||
"timestamp": "timestamp",
|
||||
"model": "model",
|
||||
"provider": "provider",
|
||||
"status_code": "status_code",
|
||||
"duration": "duration",
|
||||
"cost_usd": "cost_usd",
|
||||
"total_tokens": "total_tokens",
|
||||
"user_id": "user_id",
|
||||
"decision": "decision",
|
||||
}
|
||||
|
||||
// AgentNetworkAccessLogFilter holds pagination, filtering and sorting
|
||||
// parameters for the agent-network access-log listing. Group / provider /
|
||||
// model are multi-valued (the UI uses multi-select; an entry matches when it
|
||||
// matches any selected value).
|
||||
type AgentNetworkAccessLogFilter struct {
|
||||
Page int
|
||||
PageSize int
|
||||
|
||||
SortBy string
|
||||
SortOrder string
|
||||
|
||||
Search *string // log id, host, path, model, user email/name
|
||||
UserID *string // exact user id (the dashboard sends the picked user's id)
|
||||
SessionID *string // exact session id — groups one conversation / coding session
|
||||
GroupIDs []string // authorising group ids (match any)
|
||||
ProviderIDs []string // resolved provider ids (match any)
|
||||
Models []string // models (match any)
|
||||
Decision *string // policy decision (allow/deny)
|
||||
PathPrefix *string // request path prefix (path LIKE 'prefix%')
|
||||
StartDate *time.Time // timestamp >= start_date
|
||||
EndDate *time.Time // timestamp <= end_date
|
||||
}
|
||||
|
||||
// ParseFromRequest fills the filter from the request query parameters. It
|
||||
// returns a validation error when a supplied start_date / end_date is present
|
||||
// but not valid RFC3339: silently dropping a malformed date would broaden the
|
||||
// query (and, for the usage overview, fall back to the default window).
|
||||
func (f *AgentNetworkAccessLogFilter) ParseFromRequest(r *http.Request) error {
|
||||
q := r.URL.Query()
|
||||
|
||||
f.Page = parseAccessLogPositiveInt(q.Get("page"), 1)
|
||||
f.PageSize = min(parseAccessLogPositiveInt(q.Get("page_size"), AccessLogDefaultPageSize), AccessLogMaxPageSize)
|
||||
|
||||
f.SortBy = parseAccessLogSortField(q.Get("sort_by"))
|
||||
f.SortOrder = parseAccessLogSortOrder(q.Get("sort_order"))
|
||||
|
||||
f.Search = parseAccessLogOptionalString(q.Get("search"))
|
||||
f.UserID = parseAccessLogOptionalString(q.Get("user_id"))
|
||||
f.SessionID = parseAccessLogOptionalString(q.Get("session_id"))
|
||||
f.Decision = parseAccessLogOptionalString(q.Get("decision"))
|
||||
f.PathPrefix = parseAccessLogOptionalString(q.Get("path"))
|
||||
// Multi-value filters accept either repeated params (?group_id=a&group_id=b)
|
||||
// or a single comma-separated value (?group_id=a,b) so both the OpenAPI
|
||||
// array form and the dashboard's single-value query builder work.
|
||||
f.GroupIDs = splitMultiValue(q["group_id"])
|
||||
f.ProviderIDs = splitMultiValue(q["provider_id"])
|
||||
f.Models = splitMultiValue(q["model"])
|
||||
|
||||
var err error
|
||||
if f.StartDate, err = parseAccessLogOptionalRFC3339(q.Get("start_date")); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "invalid start_date: %v", err)
|
||||
}
|
||||
if f.EndDate, err = parseAccessLogOptionalRFC3339(q.Get("end_date")); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "invalid end_date: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSortColumn returns the database column for the active sort field.
|
||||
func (f *AgentNetworkAccessLogFilter) GetSortColumn() string {
|
||||
if col, ok := accessLogSortFields[f.SortBy]; ok {
|
||||
return col
|
||||
}
|
||||
return accessLogSortFields[accessLogDefaultSortBy]
|
||||
}
|
||||
|
||||
// GetSortOrder returns the normalised sort order ("ASC"/"DESC").
|
||||
func (f *AgentNetworkAccessLogFilter) GetSortOrder() string {
|
||||
if strings.EqualFold(f.SortOrder, "asc") {
|
||||
return "ASC"
|
||||
}
|
||||
return "DESC"
|
||||
}
|
||||
|
||||
// GetLimit returns the page size, defaulting/clamping when unset.
|
||||
func (f *AgentNetworkAccessLogFilter) GetLimit() int {
|
||||
if f.PageSize <= 0 {
|
||||
return AccessLogDefaultPageSize
|
||||
}
|
||||
return min(f.PageSize, AccessLogMaxPageSize)
|
||||
}
|
||||
|
||||
// GetOffset returns the zero-based row offset for the active page. Page is
|
||||
// user-controlled, so the multiplication is guarded against int overflow.
|
||||
func (f *AgentNetworkAccessLogFilter) GetOffset() int {
|
||||
limit := f.GetLimit()
|
||||
if f.Page <= 1 || limit <= 0 {
|
||||
return 0
|
||||
}
|
||||
if f.Page-1 > math.MaxInt/limit {
|
||||
return math.MaxInt - (math.MaxInt % limit)
|
||||
}
|
||||
return (f.Page - 1) * limit
|
||||
}
|
||||
|
||||
func parseAccessLogPositiveInt(s string, def int) int {
|
||||
if v, err := strconv.Atoi(strings.TrimSpace(s)); err == nil && v > 0 {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func parseAccessLogSortField(s string) string {
|
||||
if _, ok := accessLogSortFields[s]; ok {
|
||||
return s
|
||||
}
|
||||
return accessLogDefaultSortBy
|
||||
}
|
||||
|
||||
func parseAccessLogSortOrder(s string) string {
|
||||
if strings.EqualFold(s, "asc") {
|
||||
return "asc"
|
||||
}
|
||||
return accessLogDefaultSortOrder
|
||||
}
|
||||
|
||||
func parseAccessLogOptionalString(s string) *string {
|
||||
if s = strings.TrimSpace(s); s != "" {
|
||||
return &s
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAccessLogOptionalRFC3339(s string) (*time.Time, error) {
|
||||
if s = strings.TrimSpace(s); s == "" {
|
||||
return nil, nil //nolint:nilnil // not provided: no value and no error
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
// splitMultiValue flattens repeated query params and comma-separated values
|
||||
// into a single trimmed, blank-free list. Returns nil when nothing remains so
|
||||
// callers can skip the filter entirely.
|
||||
func splitMultiValue(values []string) []string {
|
||||
out := make([]string, 0, len(values))
|
||||
for _, raw := range values {
|
||||
for _, v := range strings.Split(raw, ",") {
|
||||
if v = strings.TrimSpace(v); v != "" {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
106
management/internals/modules/agentnetwork/types/budgetrule.go
Normal file
106
management/internals/modules/agentnetwork/types/budgetrule.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// AccountBudgetRule is an account-level, limit-only rule bound to groups
|
||||
// and/or users. It mirrors the policy budget experience without any routing:
|
||||
// it carries the same cap shape as a policy (PolicyLimits) but never selects a
|
||||
// provider. Rules apply across policies as an always-on ceiling — every
|
||||
// applicable rule binds (min-wins), so a rule can only tighten a caller's
|
||||
// effective limit, never loosen it.
|
||||
//
|
||||
// TargetGroups matches when it intersects the caller's groups; TargetUsers
|
||||
// binds a specific user directly. Empty TargetGroups and TargetUsers means the
|
||||
// rule applies to every caller (the account-wide default).
|
||||
type AccountBudgetRule struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Enabled bool
|
||||
TargetGroups []string `gorm:"serializer:json;column:target_groups"`
|
||||
TargetUsers []string `gorm:"serializer:json;column:target_users"`
|
||||
Limits PolicyLimits `gorm:"serializer:json;column:limits"`
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// TableName puts budget rules in their own table.
|
||||
func (AccountBudgetRule) TableName() string { return "agent_network_budget_rules" }
|
||||
|
||||
// NewAccountBudgetRule returns a new rule with a freshly minted ID.
|
||||
func NewAccountBudgetRule(accountID string) *AccountBudgetRule {
|
||||
now := time.Now().UTC()
|
||||
return &AccountBudgetRule{
|
||||
ID: "ainbud_" + xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the rule, including its target slices.
|
||||
func (r *AccountBudgetRule) Copy() *AccountBudgetRule {
|
||||
c := *r
|
||||
c.TargetGroups = append([]string(nil), r.TargetGroups...)
|
||||
c.TargetUsers = append([]string(nil), r.TargetUsers...)
|
||||
return &c
|
||||
}
|
||||
|
||||
// EventMeta renders the rule for the activity log.
|
||||
func (r *AccountBudgetRule) EventMeta() map[string]any {
|
||||
return map[string]any{
|
||||
"name": r.Name,
|
||||
"enabled": r.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
// FromAPIRequest applies the request payload onto the receiver.
|
||||
func (r *AccountBudgetRule) FromAPIRequest(req *api.AgentNetworkBudgetRuleRequest) {
|
||||
r.Name = req.Name
|
||||
if req.Enabled != nil {
|
||||
r.Enabled = *req.Enabled
|
||||
}
|
||||
if req.TargetGroups != nil {
|
||||
r.TargetGroups = append([]string(nil), (*req.TargetGroups)...)
|
||||
} else {
|
||||
r.TargetGroups = []string{}
|
||||
}
|
||||
if req.TargetUsers != nil {
|
||||
r.TargetUsers = append([]string(nil), (*req.TargetUsers)...)
|
||||
} else {
|
||||
r.TargetUsers = []string{}
|
||||
}
|
||||
r.Limits = limitsFromAPI(req.Limits)
|
||||
}
|
||||
|
||||
// ToAPIResponse renders the rule as the API representation.
|
||||
func (r *AccountBudgetRule) ToAPIResponse() *api.AgentNetworkBudgetRule {
|
||||
groups := r.TargetGroups
|
||||
if groups == nil {
|
||||
groups = []string{}
|
||||
}
|
||||
users := r.TargetUsers
|
||||
if users == nil {
|
||||
users = []string{}
|
||||
}
|
||||
created := r.CreatedAt
|
||||
updated := r.UpdatedAt
|
||||
return &api.AgentNetworkBudgetRule{
|
||||
Id: r.ID,
|
||||
Name: r.Name,
|
||||
Enabled: r.Enabled,
|
||||
TargetGroups: groups,
|
||||
TargetUsers: users,
|
||||
Limits: limitsToAPI(r.Limits),
|
||||
CreatedAt: &created,
|
||||
UpdatedAt: &updated,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package types
|
||||
|
||||
import "time"
|
||||
|
||||
// ConsumptionDimension classifies which kind of identity a consumption
|
||||
// row counts against. The proxy-side enforcement layer ticks one row
|
||||
// per dimension per request — typically one user row plus one group
|
||||
// row.
|
||||
type ConsumptionDimension string
|
||||
|
||||
const (
|
||||
// DimensionUser counts tokens / spend for a single end user. The
|
||||
// dim_id column carries the netbird user id (or peer.ID when the
|
||||
// caller is a tunnel-peer principal).
|
||||
DimensionUser ConsumptionDimension = "user"
|
||||
// DimensionGroup counts tokens / spend for a single source group
|
||||
// across every member of that group. The dim_id column carries
|
||||
// the netbird group id.
|
||||
DimensionGroup ConsumptionDimension = "group"
|
||||
)
|
||||
|
||||
// Consumption is a per-dimension token + USD counter for a fixed
|
||||
// aligned window. The (account, dim_kind, dim_id, window_seconds,
|
||||
// window_start) tuple is the primary key; rows are rolled forward by
|
||||
// the proxy's post-flight RecordLLMUsage path on every request.
|
||||
//
|
||||
// The same dim_id (e.g. a group id) gets one row per distinct
|
||||
// window_seconds length in scope across the account's policies,
|
||||
// because two policies with different window lengths read independent
|
||||
// counters even though they share the dimension. Two policies with
|
||||
// identical window_seconds on the same dimension share one counter
|
||||
// (correct: their caps are checked against the same shared bucket).
|
||||
type Consumption struct {
|
||||
AccountID string `gorm:"primaryKey;type:varchar(255)"`
|
||||
DimensionKind ConsumptionDimension `gorm:"primaryKey;type:varchar(16);column:dim_kind"`
|
||||
DimensionID string `gorm:"primaryKey;type:varchar(255);column:dim_id"`
|
||||
WindowSeconds int64 `gorm:"primaryKey;column:window_seconds"`
|
||||
WindowStartUTC time.Time `gorm:"primaryKey;column:window_start_utc"`
|
||||
TokensInput int64 `gorm:"column:tokens_input"`
|
||||
TokensOutput int64 `gorm:"column:tokens_output"`
|
||||
CostUSD float64 `gorm:"column:cost_usd"`
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// TableName forces a stable name independent of GORM's pluraliser.
|
||||
func (Consumption) TableName() string { return "agent_network_consumption" }
|
||||
|
||||
// ConsumptionKey identifies a single consumption counter within an account:
|
||||
// the (dim_kind, dim_id, window_seconds, window_start) part of the row's
|
||||
// primary key. Used to batch-read and batch-increment many counters for one
|
||||
// request in a single store round-trip / transaction.
|
||||
type ConsumptionKey struct {
|
||||
Kind ConsumptionDimension
|
||||
DimID string
|
||||
WindowSeconds int64
|
||||
WindowStartUTC time.Time
|
||||
}
|
||||
|
||||
// WindowStart returns the aligned UTC start of the window of length
|
||||
// windowSeconds that contains t. Aligned to the unix epoch so the
|
||||
// same bucket boundary is computed deterministically across processes.
|
||||
func WindowStart(t time.Time, windowSeconds int64) time.Time {
|
||||
if windowSeconds <= 0 {
|
||||
return t.UTC()
|
||||
}
|
||||
step := windowSeconds * int64(time.Second)
|
||||
bucketed := t.UTC().UnixNano() / step * step
|
||||
return time.Unix(0, bucketed).UTC()
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestWindowStart_AlignedToUnixEpoch is the multi-node-convergence
|
||||
// guarantee: any two proxies computing WindowStart(now, s) for the
|
||||
// same s must land on the same boundary. The implementation aligns
|
||||
// to the unix epoch (UTC) rather than local time, calendar weeks, or
|
||||
// process start time — none of which are shared across nodes.
|
||||
//
|
||||
// Table covers the load-bearing window lengths (5m, 1h, 24h, 30d)
|
||||
// plus a few odd values that still need to align cleanly.
|
||||
func TestWindowStart_AlignedToUnixEpoch(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
instant time.Time
|
||||
windowSeconds int64
|
||||
want time.Time
|
||||
}{
|
||||
{
|
||||
name: "5m window — drops seconds inside the bucket",
|
||||
instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.UTC),
|
||||
windowSeconds: 300,
|
||||
want: time.Date(2026, 5, 6, 13, 45, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "1h window — drops minutes / seconds, keeps the hour",
|
||||
instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.UTC),
|
||||
windowSeconds: 3600,
|
||||
want: time.Date(2026, 5, 6, 13, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "24h window aligns to UTC midnight",
|
||||
instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.UTC),
|
||||
windowSeconds: 86_400,
|
||||
want: time.Date(2026, 5, 6, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "30d (2_592_000s) window aligns to the 30d epoch grid, not month boundaries",
|
||||
instant: time.Date(2026, 5, 6, 0, 0, 0, 0, time.UTC),
|
||||
windowSeconds: 2_592_000,
|
||||
// 2026-05-06 UTC = 1778025600s; 1778025600 / 2592000 = 685
|
||||
// 685 * 2592000 = 1775520000s = 2026-04-07 00:00:00 UTC
|
||||
want: time.Date(2026, 4, 7, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "non-UTC input still anchors on UTC epoch boundaries",
|
||||
instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.FixedZone("CEST", 2*3600)),
|
||||
windowSeconds: 86_400,
|
||||
// 2026-05-06 13:47:23 CEST = 11:47:23 UTC → bucket 2026-05-06 00:00:00 UTC
|
||||
want: time.Date(2026, 5, 6, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := WindowStart(tc.instant, tc.windowSeconds)
|
||||
assert.True(t, got.Equal(tc.want),
|
||||
"WindowStart(%v, %ds) = %v, want %v", tc.instant, tc.windowSeconds, got, tc.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWindowStart_WithinWindowConverges proves the determinism
|
||||
// contract: any two timestamps inside the same window land on the
|
||||
// exact same boundary. Two proxy nodes serving requests 7s apart
|
||||
// must agree on which counter row to upsert.
|
||||
func TestWindowStart_WithinWindowConverges(t *testing.T) {
|
||||
t1 := time.Date(2026, 5, 6, 14, 0, 0, 0, time.UTC)
|
||||
t2 := t1.Add(7 * time.Second)
|
||||
t3 := t1.Add(59*time.Minute + 59*time.Second)
|
||||
|
||||
a := WindowStart(t1, 3600)
|
||||
b := WindowStart(t2, 3600)
|
||||
c := WindowStart(t3, 3600)
|
||||
|
||||
assert.True(t, a.Equal(b), "two timestamps 7s apart in the same 1h window must align to the same boundary")
|
||||
assert.True(t, a.Equal(c), "the very last second of a 1h window still lands on the SAME bucket as the first second")
|
||||
}
|
||||
|
||||
// TestWindowStart_AcrossWindowsDiverges is the symmetric guarantee:
|
||||
// two timestamps separated by a window's worth of time MUST land on
|
||||
// different boundaries. Without this, a 24h window's "rollover"
|
||||
// would never reset the counter.
|
||||
func TestWindowStart_AcrossWindowsDiverges(t *testing.T) {
|
||||
t1 := time.Date(2026, 5, 6, 23, 59, 59, 0, time.UTC)
|
||||
t2 := t1.Add(2 * time.Second) // 2026-05-07 00:00:01
|
||||
|
||||
a := WindowStart(t1, 86_400)
|
||||
b := WindowStart(t2, 86_400)
|
||||
assert.False(t, a.Equal(b),
|
||||
"timestamps straddling a 24h-window boundary must land on different buckets — otherwise daily caps never reset")
|
||||
}
|
||||
|
||||
// TestWindowStart_DifferentWindowsHaveDifferentBuckets locks the
|
||||
// design fork "two policies with different window_seconds on the same
|
||||
// group produce independent counters". A 24h boundary at noon is NOT
|
||||
// the same as the 30d boundary that contains it.
|
||||
func TestWindowStart_DifferentWindowsHaveDifferentBuckets(t *testing.T) {
|
||||
now := time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC)
|
||||
short := WindowStart(now, 86_400)
|
||||
long := WindowStart(now, 2_592_000)
|
||||
assert.False(t, short.Equal(long),
|
||||
"the 24h bucket and 30d bucket containing the same instant must differ — independent counters require independent keys")
|
||||
}
|
||||
|
||||
// TestWindowStart_SubMinuteAndMinuteAlignment locks sub-hour windows.
|
||||
// A 5-minute window must align to multiples of 300s from the unix
|
||||
// epoch — minute marks 0/5/10/.../55 within an hour, deterministic
|
||||
// across nodes regardless of clock drift.
|
||||
func TestWindowStart_SubMinuteAndMinuteAlignment(t *testing.T) {
|
||||
t1 := time.Date(2026, 5, 6, 14, 12, 30, 0, time.UTC)
|
||||
t2 := time.Date(2026, 5, 6, 14, 14, 59, 0, time.UTC)
|
||||
t3 := time.Date(2026, 5, 6, 14, 15, 0, 0, time.UTC)
|
||||
|
||||
a := WindowStart(t1, 300)
|
||||
b := WindowStart(t2, 300)
|
||||
c := WindowStart(t3, 300)
|
||||
|
||||
assert.True(t, a.Equal(b),
|
||||
"14:12:30 and 14:14:59 fall in the same 5m bucket starting at 14:10:00")
|
||||
assert.True(t, a.Equal(time.Date(2026, 5, 6, 14, 10, 0, 0, time.UTC)),
|
||||
"5m bucket containing 14:12 starts at 14:10 — aligned to multiples of 300s from unix epoch")
|
||||
assert.False(t, a.Equal(c),
|
||||
"14:15:00 is the start of the next 5m bucket — must not fold into the previous one")
|
||||
}
|
||||
|
||||
// TestWindowStart_ZeroWindowReturnsInputUTC covers the defensive
|
||||
// path: caller hands a zero / negative window (shouldn't happen, but
|
||||
// might mid-refactor). The function returns the input as UTC rather
|
||||
// than dividing by zero.
|
||||
func TestWindowStart_ZeroWindowReturnsInputUTC(t *testing.T) {
|
||||
now := time.Date(2026, 5, 6, 12, 30, 45, 0, time.FixedZone("CEST", 2*3600))
|
||||
got := WindowStart(now, 0)
|
||||
assert.True(t, got.Equal(now.UTC()), "zero window must not panic — return input as UTC")
|
||||
}
|
||||
120
management/internals/modules/agentnetwork/types/guardrail.go
Normal file
120
management/internals/modules/agentnetwork/types/guardrail.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// GuardrailChecks is the configurable parameter set persisted with each
|
||||
// guardrail. Stored as a JSON blob to keep the table flat.
|
||||
type GuardrailChecks struct {
|
||||
ModelAllowlist GuardrailModelAllowlist `json:"model_allowlist"`
|
||||
PromptCapture GuardrailPromptCapture `json:"prompt_capture"`
|
||||
}
|
||||
|
||||
type GuardrailModelAllowlist struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Models []string `json:"models"`
|
||||
}
|
||||
|
||||
type GuardrailPromptCapture struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
RedactPii bool `json:"redact_pii"`
|
||||
}
|
||||
|
||||
// Guardrail is an Agent Network reusable guardrail set persisted per account.
|
||||
type Guardrail struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Description string
|
||||
Checks GuardrailChecks `gorm:"serializer:json"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// TableName uses an explicit name so guardrail rows live in their own
|
||||
// table.
|
||||
func (Guardrail) TableName() string { return "agent_network_guardrails" }
|
||||
|
||||
// NewGuardrail returns a new Guardrail with a freshly minted ID.
|
||||
func NewGuardrail(accountID string) *Guardrail {
|
||||
now := time.Now().UTC()
|
||||
return &Guardrail{
|
||||
ID: "ainguard_" + xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Checks: GuardrailChecks{ModelAllowlist: GuardrailModelAllowlist{Models: []string{}}},
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
// FromAPIRequest applies the request payload onto the receiver.
|
||||
func (g *Guardrail) FromAPIRequest(req *api.AgentNetworkGuardrailRequest) {
|
||||
g.Name = req.Name
|
||||
if req.Description != nil {
|
||||
g.Description = *req.Description
|
||||
}
|
||||
g.Checks = checksFromAPI(req.Checks)
|
||||
}
|
||||
|
||||
// ToAPIResponse renders the guardrail as the API representation.
|
||||
func (g *Guardrail) ToAPIResponse() *api.AgentNetworkGuardrail {
|
||||
created := g.CreatedAt
|
||||
updated := g.UpdatedAt
|
||||
return &api.AgentNetworkGuardrail{
|
||||
Id: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Checks: checksToAPI(g.Checks),
|
||||
CreatedAt: &created,
|
||||
UpdatedAt: &updated,
|
||||
}
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the guardrail.
|
||||
func (g *Guardrail) Copy() *Guardrail {
|
||||
clone := *g
|
||||
if g.Checks.ModelAllowlist.Models != nil {
|
||||
clone.Checks.ModelAllowlist.Models = append([]string(nil), g.Checks.ModelAllowlist.Models...)
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
|
||||
// EventMeta is the audit-log payload for activity events.
|
||||
func (g *Guardrail) EventMeta() map[string]any {
|
||||
return map[string]any{"name": g.Name}
|
||||
}
|
||||
|
||||
func checksFromAPI(c api.AgentNetworkGuardrailChecks) GuardrailChecks {
|
||||
models := append([]string(nil), c.ModelAllowlist.Models...)
|
||||
if models == nil {
|
||||
models = []string{}
|
||||
}
|
||||
return GuardrailChecks{
|
||||
ModelAllowlist: GuardrailModelAllowlist{
|
||||
Enabled: c.ModelAllowlist.Enabled,
|
||||
Models: models,
|
||||
},
|
||||
PromptCapture: GuardrailPromptCapture{
|
||||
Enabled: c.PromptCapture.Enabled,
|
||||
RedactPii: c.PromptCapture.RedactPii,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func checksToAPI(c GuardrailChecks) api.AgentNetworkGuardrailChecks {
|
||||
models := c.ModelAllowlist.Models
|
||||
if models == nil {
|
||||
models = []string{}
|
||||
}
|
||||
out := api.AgentNetworkGuardrailChecks{}
|
||||
out.ModelAllowlist.Enabled = c.ModelAllowlist.Enabled
|
||||
out.ModelAllowlist.Models = models
|
||||
out.PromptCapture.Enabled = c.PromptCapture.Enabled
|
||||
out.PromptCapture.RedactPii = c.PromptCapture.RedactPii
|
||||
return out
|
||||
}
|
||||
192
management/internals/modules/agentnetwork/types/policy.go
Normal file
192
management/internals/modules/agentnetwork/types/policy.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// Policy is an Agent Network policy persisted per account. A policy
|
||||
// authorises members of SourceGroups to reach the listed
|
||||
// DestinationProviderIDs under the attached GuardrailIDs and Limits.
|
||||
//
|
||||
// Token and budget limits live on the Policy itself (Limits field);
|
||||
// guardrails carry only model allowlist and prompt capture.
|
||||
type Policy struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Description string
|
||||
Enabled bool
|
||||
SourceGroups []string `gorm:"serializer:json;column:source_groups"`
|
||||
DestinationProviderIDs []string `gorm:"serializer:json;column:destination_provider_ids"`
|
||||
GuardrailIDs []string `gorm:"serializer:json;column:guardrail_ids"`
|
||||
Limits PolicyLimits `gorm:"serializer:json;column:limits"`
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// PolicyLimits aggregates the token and budget caps attached directly
|
||||
// to a policy. Both halves are always present; their Enabled flags
|
||||
// control whether the proxy enforces them.
|
||||
type PolicyLimits struct {
|
||||
TokenLimit PolicyTokenLimit `json:"token_limit"`
|
||||
BudgetLimit PolicyBudgetLimit `json:"budget_limit"`
|
||||
}
|
||||
|
||||
// PolicyTokenLimit is a token-count cap evaluated over an aligned
|
||||
// window of WindowSeconds seconds. GroupCap is applied to each
|
||||
// source group independently — every group in the policy's
|
||||
// SourceGroups gets its own bucket of GroupCap tokens. UserCap
|
||||
// applies independently to each individual user. A zero cap means
|
||||
// uncapped. WindowSeconds must be at least 60 (one minute) when the
|
||||
// limit is enabled.
|
||||
type PolicyTokenLimit struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
GroupCap int64 `json:"group_cap"`
|
||||
UserCap int64 `json:"user_cap"`
|
||||
WindowSeconds int64 `json:"window_seconds"`
|
||||
}
|
||||
|
||||
// PolicyBudgetLimit is a USD spend cap evaluated over an aligned
|
||||
// window of WindowSeconds seconds. GroupCapUsd is applied to each
|
||||
// source group independently — every group in the policy's
|
||||
// SourceGroups gets its own bucket of GroupCapUsd USD. UserCapUsd
|
||||
// applies independently to each individual user. A zero cap means
|
||||
// uncapped. WindowSeconds must be at least 60 (one minute) when the
|
||||
// limit is enabled.
|
||||
type PolicyBudgetLimit struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
GroupCapUsd float64 `json:"group_cap_usd"`
|
||||
UserCapUsd float64 `json:"user_cap_usd"`
|
||||
WindowSeconds int64 `json:"window_seconds"`
|
||||
}
|
||||
|
||||
// TableName forces a unique GORM table to avoid collision with the access
|
||||
// control Policy type, which also resolves to "policies" by default.
|
||||
func (Policy) TableName() string { return "agent_network_policies" }
|
||||
|
||||
// NewPolicy returns a new Policy with a freshly minted ID.
|
||||
func NewPolicy(accountID string) *Policy {
|
||||
now := time.Now().UTC()
|
||||
return &Policy{
|
||||
ID: "ainpol_" + xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
// FromAPIRequest applies the request payload onto the receiver.
|
||||
func (p *Policy) FromAPIRequest(req *api.AgentNetworkPolicyRequest) {
|
||||
p.Name = req.Name
|
||||
if req.Description != nil {
|
||||
p.Description = *req.Description
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
p.Enabled = *req.Enabled
|
||||
}
|
||||
p.SourceGroups = append([]string(nil), req.SourceGroups...)
|
||||
p.DestinationProviderIDs = append([]string(nil), req.DestinationProviderIds...)
|
||||
if req.GuardrailIds != nil {
|
||||
p.GuardrailIDs = append([]string(nil), (*req.GuardrailIds)...)
|
||||
} else {
|
||||
p.GuardrailIDs = []string{}
|
||||
}
|
||||
if req.Limits != nil {
|
||||
p.Limits = limitsFromAPI(*req.Limits)
|
||||
} else {
|
||||
p.Limits = PolicyLimits{}
|
||||
}
|
||||
}
|
||||
|
||||
// ToAPIResponse renders the policy as the API representation.
|
||||
func (p *Policy) ToAPIResponse() *api.AgentNetworkPolicy {
|
||||
src := p.SourceGroups
|
||||
if src == nil {
|
||||
src = []string{}
|
||||
}
|
||||
dst := p.DestinationProviderIDs
|
||||
if dst == nil {
|
||||
dst = []string{}
|
||||
}
|
||||
guardrails := p.GuardrailIDs
|
||||
if guardrails == nil {
|
||||
guardrails = []string{}
|
||||
}
|
||||
created := p.CreatedAt
|
||||
updated := p.UpdatedAt
|
||||
return &api.AgentNetworkPolicy{
|
||||
Id: p.ID,
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
Enabled: p.Enabled,
|
||||
SourceGroups: src,
|
||||
DestinationProviderIds: dst,
|
||||
GuardrailIds: guardrails,
|
||||
Limits: limitsToAPI(p.Limits),
|
||||
CreatedAt: &created,
|
||||
UpdatedAt: &updated,
|
||||
}
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the policy.
|
||||
func (p *Policy) Copy() *Policy {
|
||||
clone := *p
|
||||
if p.SourceGroups != nil {
|
||||
clone.SourceGroups = append([]string(nil), p.SourceGroups...)
|
||||
}
|
||||
if p.DestinationProviderIDs != nil {
|
||||
clone.DestinationProviderIDs = append([]string(nil), p.DestinationProviderIDs...)
|
||||
}
|
||||
if p.GuardrailIDs != nil {
|
||||
clone.GuardrailIDs = append([]string(nil), p.GuardrailIDs...)
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
|
||||
// EventMeta is the audit-log payload for activity events.
|
||||
func (p *Policy) EventMeta() map[string]any {
|
||||
return map[string]any{
|
||||
"name": p.Name,
|
||||
"enabled": p.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
func limitsFromAPI(in api.AgentNetworkPolicyLimits) PolicyLimits {
|
||||
return PolicyLimits{
|
||||
TokenLimit: PolicyTokenLimit{
|
||||
Enabled: in.TokenLimit.Enabled,
|
||||
GroupCap: in.TokenLimit.GroupCap,
|
||||
UserCap: in.TokenLimit.UserCap,
|
||||
WindowSeconds: in.TokenLimit.WindowSeconds,
|
||||
},
|
||||
BudgetLimit: PolicyBudgetLimit{
|
||||
Enabled: in.BudgetLimit.Enabled,
|
||||
GroupCapUsd: in.BudgetLimit.GroupCapUsd,
|
||||
UserCapUsd: in.BudgetLimit.UserCapUsd,
|
||||
WindowSeconds: in.BudgetLimit.WindowSeconds,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func limitsToAPI(in PolicyLimits) api.AgentNetworkPolicyLimits {
|
||||
return api.AgentNetworkPolicyLimits{
|
||||
TokenLimit: api.AgentNetworkPolicyTokenLimit{
|
||||
Enabled: in.TokenLimit.Enabled,
|
||||
GroupCap: in.TokenLimit.GroupCap,
|
||||
UserCap: in.TokenLimit.UserCap,
|
||||
WindowSeconds: in.TokenLimit.WindowSeconds,
|
||||
},
|
||||
BudgetLimit: api.AgentNetworkPolicyBudgetLimit{
|
||||
Enabled: in.BudgetLimit.Enabled,
|
||||
GroupCapUsd: in.BudgetLimit.GroupCapUsd,
|
||||
UserCapUsd: in.BudgetLimit.UserCapUsd,
|
||||
WindowSeconds: in.BudgetLimit.WindowSeconds,
|
||||
},
|
||||
}
|
||||
}
|
||||
252
management/internals/modules/agentnetwork/types/provider.go
Normal file
252
management/internals/modules/agentnetwork/types/provider.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
// ProviderModel is one row in the provider's models list. The operator
|
||||
// pins the per-1k input/output price for cost tracking; ID is the
|
||||
// model identifier the upstream provider expects on the wire.
|
||||
type ProviderModel struct {
|
||||
ID string `json:"id"`
|
||||
InputPer1k float64 `json:"input_per_1k"`
|
||||
OutputPer1k float64 `json:"output_per_1k"`
|
||||
}
|
||||
|
||||
// Provider is an Agent Network AI provider record persisted per account.
|
||||
// The proxy cluster fronting the account lives on the per-account
|
||||
// agent-network Settings row, not on the Provider — every provider in
|
||||
// an account routes through the same cluster.
|
||||
type Provider struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
ProviderID string `gorm:"index:idx_agent_network_provider"`
|
||||
Name string
|
||||
// UpstreamURL is the full upstream URL (e.g. https://api.openai.com)
|
||||
// the operator selected.
|
||||
UpstreamURL string `gorm:"column:upstream_url"`
|
||||
APIKey string `gorm:"column:api_key"`
|
||||
// ExtraValues holds operator-typed values for catalog-declared
|
||||
// ExtraHeaders (see catalog.Provider.ExtraHeaders). Keyed by
|
||||
// header name (e.g. "x-portkey-config"); a non-empty value is
|
||||
// stamped on every upstream request to this provider via the
|
||||
// proxy's identity-inject middleware (anti-spoof Remove + Add).
|
||||
// Empty / missing keys = no header stamped. Stored as a JSON
|
||||
// blob so the schema doesn't grow per-catalog-entry.
|
||||
ExtraValues map[string]string `gorm:"serializer:json;column:extra_values"`
|
||||
// Models is the operator's curated list of models exposed by this
|
||||
// provider together with their per-1k input/output prices (USD).
|
||||
// Empty means all catalog models are allowed at catalog prices.
|
||||
Models []ProviderModel `gorm:"serializer:json"`
|
||||
Enabled bool
|
||||
// SessionPrivateKey + SessionPublicKey are the ed25519 keypair the
|
||||
// synthesised reverse-proxy service uses to sign / verify session
|
||||
// JWTs after a successful OIDC handshake. Generated once on
|
||||
// provider create and never rotated by the manager so existing
|
||||
// session cookies survive provider edits. SessionPrivateKey is
|
||||
// encrypted at rest via EncryptSensitiveData /
|
||||
// DecryptSensitiveData; SessionPublicKey is plain.
|
||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
// IdentityHeaderUserID + IdentityHeaderGroups are the operator-
|
||||
// chosen wire header names for HeaderPair-style identity
|
||||
// injection on catalog entries that flag the shape as
|
||||
// Customizable (e.g. Bifrost, where the operator picks between
|
||||
// the always-on x-bf-lh- log-metadata family and the
|
||||
// label-declared x-bf-dim- telemetry family). Empty value
|
||||
// disables stamping for that dimension; the inject middleware
|
||||
// already no-ops on empty header names. Catalog entries with
|
||||
// Customizable=false ignore these fields and use the static
|
||||
// header names defined in their HeaderPairInjection block.
|
||||
IdentityHeaderUserID string `gorm:"column:identity_header_user_id"`
|
||||
IdentityHeaderGroups string `gorm:"column:identity_header_groups"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// TableName uses an explicit name so the Agent Network provider rows live
|
||||
// in their own table, separate from any future "providers"-named entity.
|
||||
func (Provider) TableName() string { return "agent_network_providers" }
|
||||
|
||||
// NewProvider returns a new Provider with a freshly minted ID.
|
||||
func NewProvider(accountID string) *Provider {
|
||||
now := time.Now().UTC()
|
||||
return &Provider{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
// FromAPIRequest applies the request payload onto the receiver. The api_key
|
||||
// is only overwritten when the caller provided one — empty/nil leaves the
|
||||
// existing key intact, so updates can omit it.
|
||||
func (p *Provider) FromAPIRequest(req *api.AgentNetworkProviderRequest) {
|
||||
p.ProviderID = req.ProviderId
|
||||
p.Name = req.Name
|
||||
p.UpstreamURL = req.UpstreamUrl
|
||||
if req.ApiKey != nil && strings.TrimSpace(*req.ApiKey) != "" {
|
||||
p.APIKey = *req.ApiKey
|
||||
}
|
||||
if req.ExtraValues != nil {
|
||||
// Replace the whole map (rather than merge) so unsetting a
|
||||
// value on the dashboard actually clears it. Empty strings
|
||||
// are dropped so we don't waste a row on no-op values.
|
||||
next := make(map[string]string, len(*req.ExtraValues))
|
||||
for k, v := range *req.ExtraValues {
|
||||
v = strings.TrimSpace(v)
|
||||
if v != "" {
|
||||
next[k] = v
|
||||
}
|
||||
}
|
||||
if len(next) == 0 {
|
||||
p.ExtraValues = nil
|
||||
} else {
|
||||
p.ExtraValues = next
|
||||
}
|
||||
}
|
||||
p.Models = p.Models[:0]
|
||||
if req.Models != nil {
|
||||
for _, m := range *req.Models {
|
||||
p.Models = append(p.Models, ProviderModel{
|
||||
ID: m.Id,
|
||||
InputPer1k: m.InputPer1k,
|
||||
OutputPer1k: m.OutputPer1k,
|
||||
})
|
||||
}
|
||||
}
|
||||
if p.Models == nil {
|
||||
p.Models = []ProviderModel{}
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
p.Enabled = *req.Enabled
|
||||
}
|
||||
// Identity-header overrides for catalogs flagged Customizable.
|
||||
// nil pointer = "field omitted on the wire" → leave the stored
|
||||
// value untouched (per the openapi description). Empty string is
|
||||
// an explicit clear that disables stamping for this dimension.
|
||||
if req.IdentityHeaderUserId != nil {
|
||||
p.IdentityHeaderUserID = strings.TrimSpace(*req.IdentityHeaderUserId)
|
||||
}
|
||||
if req.IdentityHeaderGroups != nil {
|
||||
p.IdentityHeaderGroups = strings.TrimSpace(*req.IdentityHeaderGroups)
|
||||
}
|
||||
}
|
||||
|
||||
// ToAPIResponse renders the provider as the API representation. The API
|
||||
// key is intentionally never surfaced.
|
||||
func (p *Provider) ToAPIResponse() *api.AgentNetworkProvider {
|
||||
models := make([]api.AgentNetworkProviderModel, 0, len(p.Models))
|
||||
for _, m := range p.Models {
|
||||
models = append(models, api.AgentNetworkProviderModel{
|
||||
Id: m.ID,
|
||||
InputPer1k: m.InputPer1k,
|
||||
OutputPer1k: m.OutputPer1k,
|
||||
})
|
||||
}
|
||||
created := p.CreatedAt
|
||||
updated := p.UpdatedAt
|
||||
resp := &api.AgentNetworkProvider{
|
||||
Id: p.ID,
|
||||
ProviderId: p.ProviderID,
|
||||
Name: p.Name,
|
||||
UpstreamUrl: p.UpstreamURL,
|
||||
Models: models,
|
||||
Enabled: p.Enabled,
|
||||
CreatedAt: &created,
|
||||
UpdatedAt: &updated,
|
||||
}
|
||||
if len(p.ExtraValues) > 0 {
|
||||
out := make(map[string]string, len(p.ExtraValues))
|
||||
for k, v := range p.ExtraValues {
|
||||
out[k] = v
|
||||
}
|
||||
resp.ExtraValues = &out
|
||||
}
|
||||
if p.IdentityHeaderUserID != "" {
|
||||
v := p.IdentityHeaderUserID
|
||||
resp.IdentityHeaderUserId = &v
|
||||
}
|
||||
if p.IdentityHeaderGroups != "" {
|
||||
v := p.IdentityHeaderGroups
|
||||
resp.IdentityHeaderGroups = &v
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the provider.
|
||||
func (p *Provider) Copy() *Provider {
|
||||
clone := *p
|
||||
if p.Models != nil {
|
||||
clone.Models = append([]ProviderModel(nil), p.Models...)
|
||||
}
|
||||
if p.ExtraValues != nil {
|
||||
clone.ExtraValues = make(map[string]string, len(p.ExtraValues))
|
||||
for k, v := range p.ExtraValues {
|
||||
clone.ExtraValues[k] = v
|
||||
}
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
|
||||
// EventMeta is the audit-log payload for activity events.
|
||||
func (p *Provider) EventMeta() map[string]any {
|
||||
return map[string]any{
|
||||
"name": p.Name,
|
||||
"provider_id": p.ProviderID,
|
||||
}
|
||||
}
|
||||
|
||||
// EncryptSensitiveData encrypts the upstream API key and the session
|
||||
// signing key in place.
|
||||
func (p *Provider) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
if p.APIKey != "" {
|
||||
encrypted, err := enc.Encrypt(p.APIKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt agent network provider api key: %w", err)
|
||||
}
|
||||
p.APIKey = encrypted
|
||||
}
|
||||
if p.SessionPrivateKey != "" {
|
||||
encrypted, err := enc.Encrypt(p.SessionPrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt agent network provider session key: %w", err)
|
||||
}
|
||||
p.SessionPrivateKey = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecryptSensitiveData decrypts the upstream API key and the session
|
||||
// signing key in place.
|
||||
func (p *Provider) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
if p.APIKey != "" {
|
||||
decrypted, err := enc.Decrypt(p.APIKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt agent network provider api key: %w", err)
|
||||
}
|
||||
p.APIKey = decrypted
|
||||
}
|
||||
if p.SessionPrivateKey != "" {
|
||||
decrypted, err := enc.Decrypt(p.SessionPrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt agent network provider session key: %w", err)
|
||||
}
|
||||
p.SessionPrivateKey = decrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
78
management/internals/modules/agentnetwork/types/settings.go
Normal file
78
management/internals/modules/agentnetwork/types/settings.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// DefaultAccessLogRetentionDays is the retention applied to new accounts'
|
||||
// agent-network access logs. Usage records are not subject to this — they are
|
||||
// the long-term aggregate and are retained independently.
|
||||
const DefaultAccessLogRetentionDays = 30
|
||||
|
||||
// Settings is the per-account agent-network configuration row. One
|
||||
// row per account. Cluster + Subdomain are immutable once written and
|
||||
// produce the public endpoint agents call (`<subdomain>.<cluster>`).
|
||||
type Settings struct {
|
||||
AccountID string `gorm:"primaryKey"`
|
||||
Cluster string
|
||||
Subdomain string `gorm:"index:idx_agent_network_settings_cluster_subdomain"`
|
||||
|
||||
// Account-level collection controls sourced by the synthesizer.
|
||||
// EnableLogCollection gates the per-request access-log trail and defaults
|
||||
// ON for new accounts. EnablePromptCollection is the master gate for
|
||||
// request/response prompt capture (AND-gated with the policy-level
|
||||
// guardrail). RedactPii enables PII redaction on captured prompts;
|
||||
// effective redaction is account OR policy.
|
||||
EnableLogCollection bool
|
||||
EnablePromptCollection bool
|
||||
RedactPii bool
|
||||
|
||||
// AccessLogRetentionDays bounds how long full access-log rows are kept; a
|
||||
// periodic sweep deletes older rows. <= 0 means keep indefinitely. Usage
|
||||
// records are unaffected.
|
||||
AccessLogRetentionDays int
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// TableName puts the rows in their own table to keep the agent-network
|
||||
// schema cohesive.
|
||||
func (Settings) TableName() string { return "agent_network_settings" }
|
||||
|
||||
// Endpoint returns the bare hostname agents reach this account at:
|
||||
// `<subdomain>.<cluster>`.
|
||||
func (s *Settings) Endpoint() string {
|
||||
return s.Subdomain + "." + s.Cluster
|
||||
}
|
||||
|
||||
// ToAPIResponse renders the settings as the API representation.
|
||||
func (s *Settings) ToAPIResponse() *api.AgentNetworkSettings {
|
||||
created := s.CreatedAt
|
||||
updated := s.UpdatedAt
|
||||
retention := s.AccessLogRetentionDays
|
||||
return &api.AgentNetworkSettings{
|
||||
Cluster: s.Cluster,
|
||||
Subdomain: s.Subdomain,
|
||||
Endpoint: s.Endpoint(),
|
||||
EnableLogCollection: s.EnableLogCollection,
|
||||
EnablePromptCollection: s.EnablePromptCollection,
|
||||
RedactPii: s.RedactPii,
|
||||
AccessLogRetentionDays: &retention,
|
||||
CreatedAt: &created,
|
||||
UpdatedAt: &updated,
|
||||
}
|
||||
}
|
||||
|
||||
// FromAPIRequest applies the mutable settings fields from the request. Cluster
|
||||
// and Subdomain are immutable and intentionally not touched here.
|
||||
func (s *Settings) FromAPIRequest(req *api.AgentNetworkSettingsRequest) {
|
||||
s.EnableLogCollection = req.EnableLogCollection
|
||||
s.EnablePromptCollection = req.EnablePromptCollection
|
||||
s.RedactPii = req.RedactPii
|
||||
if req.AccessLogRetentionDays != nil {
|
||||
s.AccessLogRetentionDays = *req.AccessLogRetentionDays
|
||||
}
|
||||
}
|
||||
47
management/internals/modules/agentnetwork/types/usage.go
Normal file
47
management/internals/modules/agentnetwork/types/usage.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// AgentNetworkUsage is the stripped, always-collected per-request usage record
|
||||
// powering the Usage overview. Unlike AgentNetworkAccessLog it carries no
|
||||
// request detail (host/path/source IP/prompt) — only the dimensions needed to
|
||||
// aggregate and filter spend by user / group / provider / model over time.
|
||||
//
|
||||
// It is written unconditionally on every served agent-network request,
|
||||
// independent of the account's EnableLogCollection toggle: when log collection
|
||||
// is off the proxy ships a stripped, usage-only entry and management still
|
||||
// records the usage row (but skips the full AgentNetworkAccessLog row).
|
||||
type AgentNetworkUsage struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Timestamp time.Time `gorm:"index"`
|
||||
UserID string `gorm:"index"`
|
||||
ResolvedProviderID string `gorm:"index"`
|
||||
Provider string // vendor, e.g. "openai"
|
||||
Model string `gorm:"index"`
|
||||
SessionID string `gorm:"index"` // llm.session_id — groups a conversation / coding session
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
TotalTokens int64
|
||||
CostUSD float64
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// TableName keeps usage records in their own stripped table. Named
|
||||
// distinctly (…_request_usage) to avoid colliding with any pre-existing
|
||||
// agent_network_usage table in a shared database.
|
||||
func (AgentNetworkUsage) TableName() string { return "agent_network_request_usage" }
|
||||
|
||||
// AgentNetworkUsageGroup is the normalised many-to-many row linking a usage
|
||||
// record to one authorising group, mirroring AgentNetworkAccessLogGroup so the
|
||||
// usage overview can filter by group with a `group_id IN (...)` join.
|
||||
type AgentNetworkUsageGroup struct {
|
||||
UsageID string `gorm:"primaryKey"`
|
||||
GroupID string `gorm:"primaryKey;index"`
|
||||
AccountID string `gorm:"index"`
|
||||
}
|
||||
|
||||
// TableName names the usage group child table.
|
||||
func (AgentNetworkUsageGroup) TableName() string { return "agent_network_request_usage_group" }
|
||||
@@ -0,0 +1,96 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// UsageGranularity is the time-bucket width for the usage overview. New values
|
||||
// can be added here and handled in bucketStart without touching the store.
|
||||
type UsageGranularity string
|
||||
|
||||
const (
|
||||
UsageGranularityDay UsageGranularity = "day"
|
||||
UsageGranularityWeek UsageGranularity = "week"
|
||||
UsageGranularityMonth UsageGranularity = "month"
|
||||
)
|
||||
|
||||
// ParseUsageGranularity maps the API query value to a granularity, defaulting
|
||||
// to day for empty/unknown input.
|
||||
func ParseUsageGranularity(s string) UsageGranularity {
|
||||
switch UsageGranularity(s) {
|
||||
case UsageGranularityWeek:
|
||||
return UsageGranularityWeek
|
||||
case UsageGranularityMonth:
|
||||
return UsageGranularityMonth
|
||||
default:
|
||||
return UsageGranularityDay
|
||||
}
|
||||
}
|
||||
|
||||
// AgentNetworkUsageBucket is one aggregated usage time bucket. PeriodStart is
|
||||
// the UTC start of the bucket as YYYY-MM-DD.
|
||||
type AgentNetworkUsageBucket struct {
|
||||
PeriodStart string
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
TotalTokens int64
|
||||
CostUSD float64
|
||||
}
|
||||
|
||||
// ToAPIResponse renders the bucket as the API representation.
|
||||
func (b *AgentNetworkUsageBucket) ToAPIResponse() api.AgentNetworkUsageBucket {
|
||||
return api.AgentNetworkUsageBucket{
|
||||
PeriodStart: b.PeriodStart,
|
||||
InputTokens: b.InputTokens,
|
||||
OutputTokens: b.OutputTokens,
|
||||
TotalTokens: b.TotalTokens,
|
||||
CostUsd: b.CostUSD,
|
||||
}
|
||||
}
|
||||
|
||||
// bucketStart truncates t (in UTC) to the start of its bucket for the given
|
||||
// granularity. Week buckets start on Monday (ISO week).
|
||||
func bucketStart(t time.Time, g UsageGranularity) time.Time {
|
||||
t = t.UTC()
|
||||
switch g {
|
||||
case UsageGranularityWeek:
|
||||
// Monday-start week. time.Weekday: Sunday=0..Saturday=6.
|
||||
offset := (int(t.Weekday()) + 6) % 7
|
||||
day := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||
return day.AddDate(0, 0, -offset)
|
||||
case UsageGranularityMonth:
|
||||
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
default: // day
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
}
|
||||
|
||||
// AggregateUsageByGranularity buckets the usage rows by the requested
|
||||
// granularity and returns the buckets ordered oldest-first. Aggregation is done
|
||||
// in Go (rather than per-engine SQL date_trunc) so granularities stay portable
|
||||
// across SQLite/Postgres/MySQL and easy to extend.
|
||||
func AggregateUsageByGranularity(rows []*AgentNetworkUsage, g UsageGranularity) []*AgentNetworkUsageBucket {
|
||||
byPeriod := make(map[string]*AgentNetworkUsageBucket)
|
||||
for _, r := range rows {
|
||||
key := bucketStart(r.Timestamp, g).Format("2006-01-02")
|
||||
b := byPeriod[key]
|
||||
if b == nil {
|
||||
b = &AgentNetworkUsageBucket{PeriodStart: key}
|
||||
byPeriod[key] = b
|
||||
}
|
||||
b.InputTokens += r.InputTokens
|
||||
b.OutputTokens += r.OutputTokens
|
||||
b.TotalTokens += r.TotalTokens
|
||||
b.CostUSD += r.CostUSD
|
||||
}
|
||||
|
||||
out := make([]*AgentNetworkUsageBucket, 0, len(byPeriod))
|
||||
for _, b := range byPeriod {
|
||||
out = append(out, b)
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool { return out[i].PeriodStart < out[j].PeriodStart })
|
||||
return out
|
||||
}
|
||||
109
management/internals/modules/agentnetwork/wire_shape_test.go
Normal file
109
management/internals/modules/agentnetwork/wire_shape_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// TestSynthesizedService_WireShape locks down the proto shape that
|
||||
// flows from the synthesizer through ToProtoMapping to the proxy.
|
||||
// Drift between this test and what the proxy expects manifests as
|
||||
// "service not matching" — the proxy receives a mapping but can't
|
||||
// register an SNI/HTTP route from it.
|
||||
func TestSynthesizedService_WireShape(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
provider := newSynthTestProvider()
|
||||
policy := newSynthTestPolicy(provider.ID, "grp-eng", "")
|
||||
|
||||
expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(),
|
||||
[]*types.Provider{provider},
|
||||
[]*types.Policy{policy},
|
||||
[]*types.Guardrail{})
|
||||
|
||||
services, err := SynthesizeServices(ctx, mockStore, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
svc := services[0]
|
||||
mapping := svc.ToProtoMapping(rpservice.Create, "test-token", proxy.OIDCValidationConfig{})
|
||||
|
||||
// Identifiers — account-scoped service ID, settings-derived domain.
|
||||
assert.Equal(t, "agent-net-svc-acct-1", mapping.GetId(), "stable account-scoped virtual service ID")
|
||||
assert.Equal(t, testAccountID, mapping.GetAccountId(), "account id round-trips")
|
||||
assert.Equal(t, testEndpoint, mapping.GetDomain(), "domain matches settings.Endpoint() output")
|
||||
|
||||
// Mode + listen port — addMapping at proxy/server.go switches on Mode.
|
||||
assert.Equal(t, "http", mapping.GetMode(), "synthesised services are HTTP mode")
|
||||
assert.Equal(t, int32(0), mapping.GetListenPort(), "no custom listen port for HTTP services")
|
||||
|
||||
// Auth token + private/tunnel shape: agent-network endpoints authenticate
|
||||
// inbound agents via ValidateTunnelPeer against AccessGroups, not OIDC.
|
||||
assert.Equal(t, "test-token", mapping.GetAuthToken(), "auth token round-trips for proxy CreateProxyPeer")
|
||||
assert.True(t, mapping.GetPrivate(), "synthesised services are private (tunnel-peer auth via AccessGroups)")
|
||||
require.NotNil(t, mapping.GetAuth(), "auth payload carries the session key")
|
||||
assert.False(t, mapping.GetAuth().GetOidc(), "OIDC is off for tunnel-auth agent-network services")
|
||||
|
||||
// Path mappings — proxy/server.go::setupHTTPMapping early-returns when
|
||||
// len(mapping.GetPath()) == 0, so this is a critical assertion.
|
||||
require.Len(t, mapping.GetPath(), 1, "exactly one path mapping for the cluster target")
|
||||
pm := mapping.GetPath()[0]
|
||||
assert.Equal(t, "/", pm.GetPath(), "default path is '/'")
|
||||
assert.Equal(t, "https://noop.invalid/", pm.GetTarget(),
|
||||
"target URL is the placeholder; the router middleware rewrites it per request")
|
||||
require.NotNil(t, pm.GetOptions(), "target options must be populated so direct_upstream + middleware chain reach the proxy")
|
||||
assert.True(t, pm.GetOptions().GetDirectUpstream(), "synth targets imply direct_upstream so the proxy dials via the host stack")
|
||||
assert.True(t, pm.GetOptions().GetAgentNetwork(), "agent_network flag must travel on the wire so the proxy can tag access logs")
|
||||
|
||||
mws := pm.GetOptions().GetMiddlewares()
|
||||
require.Len(t, mws, 8, "eight middlewares reach the proxy: request_parser, router, limit_check, identity_inject, guardrail, limit_record, cost_meter, response_parser")
|
||||
|
||||
assert.Equal(t, middlewareIDLLMRequestParser, mws[0].GetId(), "first middleware id")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[0].GetSlot(), "request parser slot")
|
||||
|
||||
assert.Equal(t, middlewareIDLLMRouter, mws[1].GetId(), "second middleware id")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[1].GetSlot(), "router slot")
|
||||
require.NotEmpty(t, mws[1].GetConfigJson(), "router config must travel on the wire")
|
||||
var routerCfg routerConfig
|
||||
require.NoError(t, json.Unmarshal(mws[1].GetConfigJson(), &routerCfg), "router config decodes")
|
||||
require.Len(t, routerCfg.Providers, 1, "the only enabled provider reaches the router")
|
||||
assert.Equal(t, provider.ID, routerCfg.Providers[0].ID, "router provider id matches synth provider")
|
||||
assert.Equal(t, "Bearer sk-test-key", routerCfg.Providers[0].AuthHeaderValue,
|
||||
"openai catalog template substitutes the API key on the wire")
|
||||
|
||||
assert.Equal(t, middlewareIDLLMLimitCheck, mws[2].GetId(),
|
||||
"limit_check runs after the router so the resolved provider id is available, before identity_inject so a deny doesn't pay the header-stamp cost")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[2].GetSlot())
|
||||
|
||||
assert.Equal(t, middlewareIDLLMIdentityInject, mws[3].GetId(), "fourth middleware id")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[3].GetSlot(), "identity inject slot")
|
||||
require.NotEmpty(t, mws[3].GetConfigJson(), "identity inject config JSON must travel on the wire")
|
||||
|
||||
assert.Equal(t, middlewareIDLLMGuardrail, mws[4].GetId(), "fifth middleware id")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[4].GetSlot(), "guardrail slot")
|
||||
require.NotEmpty(t, mws[4].GetConfigJson(), "guardrail middleware config JSON must travel on the wire")
|
||||
|
||||
assert.Equal(t, middlewareIDLLMLimitRecord, mws[5].GetId(),
|
||||
"limit_record sits FIRST in the response section so it RUNS LAST at runtime — slot order on the response leg is reverse-of-slice")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE, mws[5].GetSlot())
|
||||
|
||||
assert.Equal(t, middlewareIDCostMeter, mws[6].GetId(), "seventh middleware id")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE, mws[6].GetSlot(), "cost meter slot")
|
||||
|
||||
assert.Equal(t, middlewareIDLLMResponseParser, mws[7].GetId(), "eighth middleware id")
|
||||
assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE, mws[7].GetSlot(), "response parser slot")
|
||||
}
|
||||
@@ -220,12 +220,36 @@ func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, er
|
||||
func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
||||
existingPeerID, err := m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
||||
if err == nil && existingPeerID != "" {
|
||||
// Peer already exists
|
||||
// Same pubkey already registered — idempotent.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dedupe stale embedded peer records for the same (account, cluster).
|
||||
// The proxy generates a fresh WireGuard keypair on every startup
|
||||
// (proxy/internal/roundtrip/netbird.go), so without this sweep the
|
||||
// prior embedded peer would linger forever — holding its CGNAT IP
|
||||
// allocation, polluting other peers' rosters, and (most visibly)
|
||||
// leaving the synth DNS pointing at the dead address. The
|
||||
// (account, cluster) tuple identifies "the embedded peer for this
|
||||
// proxy instance at this cluster"; any record matching that tuple
|
||||
// with a different pubkey is by definition stale and must go.
|
||||
staleIDs, err := m.findStaleEmbeddedProxyPeers(ctx, accountID, cluster, peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan for stale embedded proxy peers: %w", err)
|
||||
}
|
||||
if len(staleIDs) > 0 {
|
||||
// userID="" + checkConnected=false: the deletion is initiated
|
||||
// by management itself on behalf of the freshly-registering
|
||||
// proxy, not by an end user; the stale peer may still be
|
||||
// marked Connected from its prior session, but its session is
|
||||
// dead by definition (its key no longer exists).
|
||||
if err := m.DeletePeers(ctx, accountID, staleIDs, "", false); err != nil {
|
||||
return fmt.Errorf("delete stale embedded proxy peers %v: %w", staleIDs, err)
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("proxy-%s", xid.New().String())
|
||||
peer := &peer.Peer{
|
||||
newPeer := &peer.Peer{
|
||||
Ephemeral: true,
|
||||
ProxyMeta: peer.ProxyMeta{
|
||||
Cluster: cluster,
|
||||
@@ -242,10 +266,36 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee
|
||||
},
|
||||
}
|
||||
|
||||
_, _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
|
||||
_, _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", newPeer, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create proxy peer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findStaleEmbeddedProxyPeers returns the peer IDs of embedded proxy peer
|
||||
// records in accountID that target the same cluster but carry a different
|
||||
// WireGuard pubkey than the freshly-registering one. Used by CreateProxyPeer
|
||||
// to garbage-collect stale records left behind when the proxy restarts with a
|
||||
// regenerated keypair.
|
||||
func (m *managerImpl) findStaleEmbeddedProxyPeers(ctx context.Context, accountID, cluster, newKey string) ([]string, error) {
|
||||
account, err := m.store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var stale []string
|
||||
for _, p := range account.Peers {
|
||||
if p == nil || !p.ProxyMeta.Embedded {
|
||||
continue
|
||||
}
|
||||
if p.ProxyMeta.Cluster != cluster {
|
||||
continue
|
||||
}
|
||||
if p.Key == newKey {
|
||||
continue
|
||||
}
|
||||
stale = append(stale, p.ID)
|
||||
}
|
||||
return stale, nil
|
||||
}
|
||||
|
||||
@@ -39,6 +39,10 @@ type AccessLogEntry struct {
|
||||
BytesDownload int64 `gorm:"index"`
|
||||
Protocol AccessLogProtocol `gorm:"index"`
|
||||
Metadata map[string]string `gorm:"serializer:json"`
|
||||
// AgentNetwork marks the entry as emitted by a synthesised agent-network
|
||||
// service. Sourced from proto.AccessLog.AgentNetwork the proxy stamps
|
||||
// before shipping. Indexed so the agent-network log surface filters cheaply.
|
||||
AgentNetwork bool `gorm:"index"`
|
||||
}
|
||||
|
||||
// FromProto creates an AccessLogEntry from a proto.AccessLog
|
||||
@@ -58,6 +62,7 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
|
||||
a.BytesDownload = serviceLog.GetBytesDownload()
|
||||
a.Protocol = AccessLogProtocol(serviceLog.GetProtocol())
|
||||
a.Metadata = maps.Clone(serviceLog.GetMetadata())
|
||||
a.AgentNetwork = serviceLog.GetAgentNetwork()
|
||||
|
||||
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
|
||||
if addr, err := netip.ParseAddr(sourceIP); err == nil {
|
||||
|
||||
@@ -2,12 +2,15 @@ package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
@@ -16,6 +19,28 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// Metadata keys the proxy stamps on agent-network access-log entries. These
|
||||
// mirror the constants in proxy/internal/middleware/keys.go and form the wire
|
||||
// contract between the proxy and management; management flattens them into
|
||||
// queryable columns. Keep in sync with the proxy side.
|
||||
const (
|
||||
metaKeyProvider = "llm.provider"
|
||||
metaKeyModel = "llm.model"
|
||||
metaKeyResolvedProviderID = "llm.resolved_provider_id"
|
||||
metaKeySelectedPolicyID = "llm.selected_policy_id"
|
||||
metaKeyPolicyDecision = "llm_policy.decision"
|
||||
metaKeyPolicyReason = "llm_policy.reason"
|
||||
metaKeyInputTokens = "llm.input_tokens" //nolint:gosec // metadata key name, not a credential
|
||||
metaKeyOutputTokens = "llm.output_tokens" //nolint:gosec // metadata key name, not a credential
|
||||
metaKeyTotalTokens = "llm.total_tokens" //nolint:gosec // metadata key name, not a credential
|
||||
metaKeyCostUSDTotal = "cost.usd_total"
|
||||
metaKeyStream = "llm.stream"
|
||||
metaKeySessionID = "llm.session_id"
|
||||
metaKeyAuthorisingGroups = "llm.authorising_groups"
|
||||
metaKeyRequestPrompt = "llm.request_prompt"
|
||||
metaKeyResponseCompletion = "llm.response_completion"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
@@ -31,8 +56,14 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, geo g
|
||||
}
|
||||
}
|
||||
|
||||
// SaveAccessLog saves an access log entry to the database after enriching it
|
||||
// SaveAccessLog saves an access log entry to the database after enriching it.
|
||||
// Agent-network entries are flattened into their own dedicated table (queryable
|
||||
// LLM columns + group child rows) instead of the shared reverse-proxy table.
|
||||
func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.AccessLogEntry) error {
|
||||
if logEntry.AgentNetwork {
|
||||
return m.saveAgentNetworkAccessLog(ctx, logEntry)
|
||||
}
|
||||
|
||||
if m.geo != nil && logEntry.GeoLocation.ConnectionIP != nil {
|
||||
location, err := m.geo.Lookup(logEntry.GeoLocation.ConnectionIP)
|
||||
if err != nil {
|
||||
@@ -61,6 +92,184 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveAgentNetworkAccessLog flattens the metadata-bearing access-log entry and
|
||||
// persists it in two parts:
|
||||
//
|
||||
// - The stripped usage record is written unconditionally — usage/cost is
|
||||
// collected on every request regardless of the account's log-collection
|
||||
// toggle (the proxy ships a usage-only entry when logging is disabled).
|
||||
// - The full access-log row (with request detail + prompt) is written only
|
||||
// when the account's EnableLogCollection setting is on. This setting read
|
||||
// is the authoritative gate; the proxy-side strip is defense in depth.
|
||||
func (m *managerImpl) saveAgentNetworkAccessLog(ctx context.Context, logEntry *accesslogs.AccessLogEntry) error {
|
||||
entry, groups := flattenAgentNetworkLog(logEntry)
|
||||
|
||||
usage, usageGroups := usageFromFlattenedLog(entry, groups)
|
||||
if err := m.store.CreateAgentNetworkUsage(ctx, usage, usageGroups); err != nil {
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"account_id": entry.AccountID,
|
||||
"model": entry.Model,
|
||||
}).Errorf("failed to save agent-network usage: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
settings, err := m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, entry.AccountID)
|
||||
if err != nil {
|
||||
// No settings row (or a transient read error) means we can't confirm
|
||||
// log collection is enabled — usage is already saved, so skip the full
|
||||
// row rather than fail the whole ingest.
|
||||
log.WithContext(ctx).Debugf("skipping full agent-network access-log row for account %s: %v", entry.AccountID, err)
|
||||
return nil
|
||||
}
|
||||
if !settings.EnableLogCollection {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.store.CreateAgentNetworkAccessLog(ctx, entry, groups); err != nil {
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"account_id": entry.AccountID,
|
||||
"service_id": entry.ServiceID,
|
||||
"model": entry.Model,
|
||||
"status": entry.StatusCode,
|
||||
}).Errorf("failed to save agent-network access log: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// flattenAgentNetworkLog converts a reverse-proxy AccessLogEntry (whose LLM
|
||||
// dimensions live in the opaque Metadata map) into the flattened
|
||||
// agent-network row + authorising-group child rows.
|
||||
func flattenAgentNetworkLog(e *accesslogs.AccessLogEntry) (*agentNetworkTypes.AgentNetworkAccessLog, []agentNetworkTypes.AgentNetworkAccessLogGroup) {
|
||||
meta := e.Metadata
|
||||
|
||||
var sourceIP string
|
||||
if e.GeoLocation.ConnectionIP != nil {
|
||||
sourceIP = e.GeoLocation.ConnectionIP.String()
|
||||
}
|
||||
|
||||
entry := &agentNetworkTypes.AgentNetworkAccessLog{
|
||||
ID: e.ID,
|
||||
AccountID: e.AccountID,
|
||||
ServiceID: e.ServiceID,
|
||||
Timestamp: e.Timestamp,
|
||||
UserID: e.UserId,
|
||||
SourceIP: sourceIP,
|
||||
Method: e.Method,
|
||||
Host: e.Host,
|
||||
Path: e.Path,
|
||||
Duration: e.Duration,
|
||||
StatusCode: e.StatusCode,
|
||||
AuthMethod: e.AuthMethodUsed,
|
||||
BytesUpload: e.BytesUpload,
|
||||
BytesDownload: e.BytesDownload,
|
||||
|
||||
Provider: meta[metaKeyProvider],
|
||||
Model: meta[metaKeyModel],
|
||||
SessionID: meta[metaKeySessionID],
|
||||
ResolvedProviderID: meta[metaKeyResolvedProviderID],
|
||||
SelectedPolicyID: meta[metaKeySelectedPolicyID],
|
||||
Decision: meta[metaKeyPolicyDecision],
|
||||
DenyReason: meta[metaKeyPolicyReason],
|
||||
InputTokens: parseMetaInt(meta, metaKeyInputTokens),
|
||||
OutputTokens: parseMetaInt(meta, metaKeyOutputTokens),
|
||||
TotalTokens: parseMetaInt(meta, metaKeyTotalTokens),
|
||||
CostUSD: parseMetaFloat(meta, metaKeyCostUSDTotal),
|
||||
Stream: parseMetaBool(meta, metaKeyStream),
|
||||
RequestPrompt: meta[metaKeyRequestPrompt],
|
||||
ResponseCompletion: meta[metaKeyResponseCompletion],
|
||||
}
|
||||
|
||||
var groups []agentNetworkTypes.AgentNetworkAccessLogGroup
|
||||
for _, gid := range parseGroupCSV(meta[metaKeyAuthorisingGroups]) {
|
||||
groups = append(groups, agentNetworkTypes.AgentNetworkAccessLogGroup{
|
||||
LogID: entry.ID,
|
||||
GroupID: gid,
|
||||
AccountID: entry.AccountID,
|
||||
})
|
||||
}
|
||||
return entry, groups
|
||||
}
|
||||
|
||||
// usageFromFlattenedLog derives the stripped usage record (and its group child
|
||||
// rows) from an already-flattened access-log entry. The usage row shares the
|
||||
// log's ID so the two correlate.
|
||||
func usageFromFlattenedLog(e *agentNetworkTypes.AgentNetworkAccessLog, groups []agentNetworkTypes.AgentNetworkAccessLogGroup) (*agentNetworkTypes.AgentNetworkUsage, []agentNetworkTypes.AgentNetworkUsageGroup) {
|
||||
usage := &agentNetworkTypes.AgentNetworkUsage{
|
||||
ID: e.ID,
|
||||
AccountID: e.AccountID,
|
||||
Timestamp: e.Timestamp,
|
||||
UserID: e.UserID,
|
||||
ResolvedProviderID: e.ResolvedProviderID,
|
||||
Provider: e.Provider,
|
||||
Model: e.Model,
|
||||
SessionID: e.SessionID,
|
||||
InputTokens: e.InputTokens,
|
||||
OutputTokens: e.OutputTokens,
|
||||
TotalTokens: e.TotalTokens,
|
||||
CostUSD: e.CostUSD,
|
||||
}
|
||||
|
||||
usageGroups := make([]agentNetworkTypes.AgentNetworkUsageGroup, 0, len(groups))
|
||||
for _, g := range groups {
|
||||
usageGroups = append(usageGroups, agentNetworkTypes.AgentNetworkUsageGroup{
|
||||
UsageID: usage.ID,
|
||||
GroupID: g.GroupID,
|
||||
AccountID: g.AccountID,
|
||||
})
|
||||
}
|
||||
return usage, usageGroups
|
||||
}
|
||||
|
||||
// parseMetaInt parses a non-negative token count. Negative or unparseable
|
||||
// values are clamped to 0 so a malformed metric can't persist a negative
|
||||
// counter.
|
||||
func parseMetaInt(meta map[string]string, key string) int64 {
|
||||
if v, err := strconv.ParseInt(strings.TrimSpace(meta[key]), 10, 64); err == nil && v >= 0 {
|
||||
return v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// parseMetaFloat parses a non-negative, finite cost. Negative, NaN, Inf, or
|
||||
// unparseable values are clamped to 0 so a malformed metric can't poison the
|
||||
// stored cost.
|
||||
func parseMetaFloat(meta map[string]string, key string) float64 {
|
||||
if v, err := strconv.ParseFloat(strings.TrimSpace(meta[key]), 64); err == nil && v >= 0 && !math.IsInf(v, 0) {
|
||||
return v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func parseMetaBool(meta map[string]string, key string) bool {
|
||||
v, _ := strconv.ParseBool(strings.TrimSpace(meta[key]))
|
||||
return v
|
||||
}
|
||||
|
||||
// parseGroupCSV splits the comma-separated authorising-group id list the proxy
|
||||
// emits, trimming blanks and de-duplicating. Dedup matters because the group
|
||||
// rows are keyed by (log_id, group_id) / (usage_id, group_id): a repeated id
|
||||
// in the CSV would otherwise produce a duplicate primary key and fail the
|
||||
// insert transaction.
|
||||
func parseGroupCSV(raw string) []string {
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(raw, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
seen := make(map[string]struct{}, len(parts))
|
||||
for _, p := range parts {
|
||||
if p = strings.TrimSpace(p); p != "" {
|
||||
if _, dup := seen[p]; dup {
|
||||
continue
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
|
||||
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
|
||||
@@ -66,6 +66,51 @@ type TargetOptions struct {
|
||||
// reachable without WireGuard (public APIs, LAN services, localhost
|
||||
// sidecars). Default false.
|
||||
DirectUpstream bool `json:"direct_upstream,omitempty"`
|
||||
// Middlewares carries per-target agent-network middleware configs. Empty
|
||||
// for private and operator-defined services; populated only by the
|
||||
// agent-network synthesizer.
|
||||
Middlewares []MiddlewareConfig `gorm:"serializer:json" json:"middlewares,omitempty"`
|
||||
CaptureMaxRequestBytes int64 `json:"capture_max_request_bytes,omitempty"`
|
||||
CaptureMaxResponseBytes int64 `json:"capture_max_response_bytes,omitempty"`
|
||||
CaptureContentTypes []string `gorm:"serializer:json" json:"capture_content_types,omitempty"`
|
||||
// AgentNetwork marks targets synthesised from Agent Network state. The
|
||||
// proxy uses it to gate agent-network-specific behaviour (access log
|
||||
// tagging, observability, etc.).
|
||||
AgentNetwork bool `json:"agent_network,omitempty"`
|
||||
// DisableAccessLog suppresses the per-request access-log emission for this
|
||||
// target. Defaults false to preserve access-log behaviour for every
|
||||
// non-agent-network target. The agent-network synthesizer sets this true
|
||||
// only when the account's EnableLogCollection toggle is off.
|
||||
DisableAccessLog bool `json:"disable_access_log,omitempty"`
|
||||
}
|
||||
|
||||
// MiddlewareSlot mirrors proto.MiddlewareSlot / middleware.Slot.
|
||||
type MiddlewareSlot string
|
||||
|
||||
const (
|
||||
MiddlewareSlotOnRequest MiddlewareSlot = "on_request"
|
||||
MiddlewareSlotOnResponse MiddlewareSlot = "on_response"
|
||||
MiddlewareSlotTerminal MiddlewareSlot = "terminal"
|
||||
)
|
||||
|
||||
// MiddlewareFailMode mirrors proto.MiddlewareConfig_FailMode.
|
||||
type MiddlewareFailMode string
|
||||
|
||||
const (
|
||||
MiddlewareFailOpen MiddlewareFailMode = "fail_open"
|
||||
MiddlewareFailClosed MiddlewareFailMode = "fail_closed"
|
||||
)
|
||||
|
||||
// MiddlewareConfig is the per-target configuration for a single
|
||||
// middleware instance. Mirrors proto.MiddlewareConfig.
|
||||
type MiddlewareConfig struct {
|
||||
ID string `json:"id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Slot MiddlewareSlot `json:"slot"`
|
||||
ConfigJSON []byte `json:"config_json,omitempty"`
|
||||
FailMode MiddlewareFailMode `json:"fail_mode,omitempty"`
|
||||
TimeoutMs int32 `json:"timeout_ms,omitempty"`
|
||||
CanMutate bool `json:"can_mutate"`
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
@@ -504,21 +549,75 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
|
||||
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 &&
|
||||
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream &&
|
||||
len(opts.Middlewares) == 0 && opts.CaptureMaxRequestBytes == 0 &&
|
||||
opts.CaptureMaxResponseBytes == 0 && len(opts.CaptureContentTypes) == 0 &&
|
||||
!opts.AgentNetwork && !opts.DisableAccessLog {
|
||||
return nil
|
||||
}
|
||||
popts := &proto.PathTargetOptions{
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
DirectUpstream: opts.DirectUpstream,
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
DirectUpstream: opts.DirectUpstream,
|
||||
AgentNetwork: opts.AgentNetwork,
|
||||
DisableAccessLog: opts.DisableAccessLog,
|
||||
}
|
||||
if opts.RequestTimeout != 0 {
|
||||
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
||||
}
|
||||
if len(opts.Middlewares) > 0 {
|
||||
popts.Middlewares = middlewaresToProto(opts.Middlewares)
|
||||
}
|
||||
popts.CaptureMaxRequestBytes = opts.CaptureMaxRequestBytes
|
||||
popts.CaptureMaxResponseBytes = opts.CaptureMaxResponseBytes
|
||||
if len(opts.CaptureContentTypes) > 0 {
|
||||
popts.CaptureContentTypes = append([]string(nil), opts.CaptureContentTypes...)
|
||||
}
|
||||
return popts
|
||||
}
|
||||
|
||||
// middlewaresToProto converts the internal middleware slice to the proto
|
||||
// representation sent to the proxy via the mapping stream.
|
||||
func middlewaresToProto(in []MiddlewareConfig) []*proto.MiddlewareConfig {
|
||||
out := make([]*proto.MiddlewareConfig, 0, len(in))
|
||||
for _, m := range in {
|
||||
pm := &proto.MiddlewareConfig{
|
||||
Id: m.ID,
|
||||
Enabled: m.Enabled,
|
||||
Slot: middlewareSlotToProto(m.Slot),
|
||||
ConfigJson: append([]byte(nil), m.ConfigJSON...),
|
||||
CanMutate: m.CanMutate,
|
||||
FailMode: middlewareFailModeToProto(m.FailMode),
|
||||
}
|
||||
if m.TimeoutMs > 0 {
|
||||
pm.Timeout = durationpb.New(time.Duration(m.TimeoutMs) * time.Millisecond)
|
||||
}
|
||||
out = append(out, pm)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func middlewareSlotToProto(s MiddlewareSlot) proto.MiddlewareSlot {
|
||||
switch s {
|
||||
case MiddlewareSlotOnRequest:
|
||||
return proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST
|
||||
case MiddlewareSlotOnResponse:
|
||||
return proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE
|
||||
case MiddlewareSlotTerminal:
|
||||
return proto.MiddlewareSlot_MIDDLEWARE_SLOT_TERMINAL
|
||||
default:
|
||||
return proto.MiddlewareSlot_MIDDLEWARE_SLOT_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func middlewareFailModeToProto(m MiddlewareFailMode) proto.MiddlewareConfig_FailMode {
|
||||
if m == MiddlewareFailClosed {
|
||||
return proto.MiddlewareConfig_FAIL_CLOSED
|
||||
}
|
||||
return proto.MiddlewareConfig_FAIL_OPEN
|
||||
}
|
||||
|
||||
// l4TargetOptionsToProto converts L4-relevant target options to proto.
|
||||
func l4TargetOptionsToProto(target *Target) *proto.PathTargetOptions {
|
||||
if !target.ProxyProtocol && target.Options.RequestTimeout == 0 && target.Options.SessionIdleTimeout == 0 {
|
||||
|
||||
@@ -26,9 +26,11 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
@@ -120,7 +122,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount, s.AgentNetworkManager())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -223,11 +225,35 @@ func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetServiceManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
proxyService.SetAgentNetworkSynthesizer(newAgentNetworkSynthesizer(s.Store()))
|
||||
proxyService.SetAgentNetworkLimitsService(s.AgentNetworkManager())
|
||||
})
|
||||
return proxyService
|
||||
})
|
||||
}
|
||||
|
||||
// agentNetworkSynthesizerAdapter implements nbgrpc.AgentNetworkSynthesizer by
|
||||
// delegating to the agentnetwork package's store-backed synthesiser.
|
||||
type agentNetworkSynthesizerAdapter struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
func newAgentNetworkSynthesizer(s store.Store) *agentNetworkSynthesizerAdapter {
|
||||
return &agentNetworkSynthesizerAdapter{store: s}
|
||||
}
|
||||
|
||||
func (a *agentNetworkSynthesizerAdapter) SynthesizeServicesForCluster(ctx context.Context, clusterAddr string) ([]*rpservice.Service, error) {
|
||||
return agentnetwork.SynthesizeServicesForCluster(ctx, a.store, clusterAddr)
|
||||
}
|
||||
|
||||
func (a *agentNetworkSynthesizerAdapter) SynthesizeServicesForAccount(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
|
||||
return agentnetwork.SynthesizeServices(ctx, a.store, accountID)
|
||||
}
|
||||
|
||||
func (a *agentNetworkSynthesizerAdapter) SynthesizeServiceForDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||
return agentnetwork.SynthesizeServiceForDomain(ctx, a.store, domain)
|
||||
}
|
||||
|
||||
func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
|
||||
return Create(s, func() nbgrpc.ProxyOIDCConfig {
|
||||
return nbgrpc.ProxyOIDCConfig{
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
@@ -194,6 +195,24 @@ func (s *BaseServer) NetworksManager() networks.Manager {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AgentNetworkManager() agentnetwork.Manager {
|
||||
return Create(s, func() agentnetwork.Manager {
|
||||
mgr := agentnetwork.NewManager(
|
||||
s.Store(),
|
||||
s.PermissionsManager(),
|
||||
s.AccountManager(),
|
||||
s.ServiceProxyController(),
|
||||
)
|
||||
// Sweep expired agent-network access logs per account retention,
|
||||
// reusing the reverse-proxy cleanup interval config.
|
||||
mgr.StartAccessLogCleanup(
|
||||
context.Background(),
|
||||
s.Config.ReverseProxy.AccessLogCleanupIntervalHours,
|
||||
)
|
||||
return mgr
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ZonesManager() zones.Manager {
|
||||
return Create(s, func() zones.Manager {
|
||||
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
|
||||
const (
|
||||
reconnThreshold = 5 * time.Minute
|
||||
baseBlockDuration = 30 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
|
||||
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||
metaChangeLimit = 5 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||
)
|
||||
|
||||
type lfConfig struct {
|
||||
@@ -142,6 +142,7 @@ func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
|
||||
func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
|
||||
h := fnv.New64a()
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -35,6 +36,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
proxyauth "github.com/netbirdio/netbird/proxy/auth"
|
||||
@@ -60,6 +62,23 @@ type ProxyTokenChecker interface {
|
||||
}
|
||||
|
||||
// ProxyServiceServer implements the ProxyService gRPC server
|
||||
// AgentNetworkSynthesizer produces in-memory reverse-proxy services from
|
||||
// Agent Network provider/policy state for the proxy snapshot path; synthesised
|
||||
// services never appear in the reverseproxy_services table.
|
||||
type AgentNetworkSynthesizer interface {
|
||||
SynthesizeServicesForCluster(ctx context.Context, clusterAddr string) ([]*rpservice.Service, error)
|
||||
SynthesizeServicesForAccount(ctx context.Context, accountID string) ([]*rpservice.Service, error)
|
||||
SynthesizeServiceForDomain(ctx context.Context, domain string) (*rpservice.Service, error)
|
||||
}
|
||||
|
||||
// AgentNetworkLimitsService is the minimal slice of agentnetwork.Manager the
|
||||
// gRPC layer needs for CheckLLMPolicyLimits + RecordLLMUsage — kept narrow so
|
||||
// the grpc package doesn't take a hard import on the full manager.
|
||||
type AgentNetworkLimitsService interface {
|
||||
SelectPolicyForRequest(ctx context.Context, in agentnetwork.PolicySelectionInput) (*agentnetwork.PolicySelectionResult, error)
|
||||
RecordUsage(ctx context.Context, in agentnetwork.RecordUsageInput) error
|
||||
}
|
||||
|
||||
type ProxyServiceServer struct {
|
||||
proto.UnimplementedProxyServiceServer
|
||||
|
||||
@@ -72,6 +91,14 @@ type ProxyServiceServer struct {
|
||||
mu sync.RWMutex
|
||||
// Manager for reverse proxy operations
|
||||
serviceManager rpservice.Manager
|
||||
// agentNetworkSynth produces synthesised reverse-proxy services from
|
||||
// Agent Network state. Optional — when nil the snapshot path only ships
|
||||
// persisted services.
|
||||
agentNetworkSynth AgentNetworkSynthesizer
|
||||
// agentNetworkLimits handles the pre-flight selection (CheckLLMPolicyLimits)
|
||||
// and the post-flight consumption write (RecordLLMUsage). Optional — when
|
||||
// nil both RPCs return Unimplemented.
|
||||
agentNetworkLimits AgentNetworkLimitsService
|
||||
// ProxyController for service updates and cluster management
|
||||
proxyController proxy.Controller
|
||||
|
||||
@@ -209,6 +236,127 @@ func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
|
||||
s.serviceManager = manager
|
||||
}
|
||||
|
||||
// SetAgentNetworkSynthesizer wires the agent-network service synthesiser.
|
||||
// Optional — when nil the snapshot path skips agent-network synthesis. The
|
||||
// modules layer injects this after both the proxy server and the agent-network
|
||||
// manager are constructed.
|
||||
func (s *ProxyServiceServer) SetAgentNetworkSynthesizer(synth AgentNetworkSynthesizer) {
|
||||
s.mu.Lock()
|
||||
s.agentNetworkSynth = synth
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetAgentNetworkLimitsService wires the policy-selection + post-flight
|
||||
// consumption sink. Pass nil to disable; both RPCs return Unimplemented while
|
||||
// unset so partial wiring surfaces during integration.
|
||||
func (s *ProxyServiceServer) SetAgentNetworkLimitsService(svc AgentNetworkLimitsService) {
|
||||
s.mu.Lock()
|
||||
s.agentNetworkLimits = svc
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// agentNetworkSynthesizer returns the synthesiser under read lock.
|
||||
func (s *ProxyServiceServer) agentNetworkSynthesizer() AgentNetworkSynthesizer {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.agentNetworkSynth
|
||||
}
|
||||
|
||||
// CheckLLMPolicyLimits is the pre-flight policy gate the proxy calls before
|
||||
// forwarding an LLM request upstream. Delegates to the agent-network selector,
|
||||
// which scores applicable policies by remaining headroom and returns the
|
||||
// policy that pays for this request (or a deny when all are exhausted).
|
||||
func (s *ProxyServiceServer) CheckLLMPolicyLimits(ctx context.Context, req *proto.CheckLLMPolicyLimitsRequest) (*proto.CheckLLMPolicyLimitsResponse, error) {
|
||||
s.mu.RLock()
|
||||
svc := s.agentNetworkLimits
|
||||
s.mu.RUnlock()
|
||||
if svc == nil {
|
||||
return nil, status.Errorf(codes.Unimplemented, "agent-network limits service not configured on management")
|
||||
}
|
||||
if req.GetAccountId() == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "account_id is required")
|
||||
}
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := svc.SelectPolicyForRequest(ctx, agentnetwork.PolicySelectionInput{
|
||||
AccountID: req.GetAccountId(),
|
||||
UserID: req.GetUserId(),
|
||||
GroupIDs: req.GetGroupIds(),
|
||||
ProviderID: req.GetProviderId(),
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("select policy for request: %v", err)
|
||||
return nil, status.Error(codes.Internal, "select policy failed")
|
||||
}
|
||||
|
||||
if !res.Allow {
|
||||
return &proto.CheckLLMPolicyLimitsResponse{
|
||||
Decision: "deny",
|
||||
SelectedPolicyId: res.SelectedPolicyID,
|
||||
AttributionGroupId: res.AttributionGroupID,
|
||||
WindowSeconds: res.WindowSeconds,
|
||||
DenyCode: res.DenyCode,
|
||||
DenyReason: res.DenyReason,
|
||||
}, nil
|
||||
}
|
||||
return &proto.CheckLLMPolicyLimitsResponse{
|
||||
Decision: "allow",
|
||||
SelectedPolicyId: res.SelectedPolicyID,
|
||||
AttributionGroupId: res.AttributionGroupID,
|
||||
WindowSeconds: res.WindowSeconds,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RecordLLMUsage increments the per-(dimension, window) consumption counter for
|
||||
// the user and optional attribution group after a served request. Returns
|
||||
// Unimplemented when the agent-network limits service hasn't been wired.
|
||||
func (s *ProxyServiceServer) RecordLLMUsage(ctx context.Context, req *proto.RecordLLMUsageRequest) (*proto.RecordLLMUsageResponse, error) {
|
||||
s.mu.RLock()
|
||||
svc := s.agentNetworkLimits
|
||||
s.mu.RUnlock()
|
||||
if svc == nil {
|
||||
return nil, status.Errorf(codes.Unimplemented, "agent-network limits service not configured on management")
|
||||
}
|
||||
|
||||
accountID := req.GetAccountId()
|
||||
if accountID == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "account_id is required")
|
||||
}
|
||||
if err := enforceAccountScope(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokensIn := req.GetTokensInput()
|
||||
tokensOut := req.GetTokensOutput()
|
||||
costUSD := req.GetCostUsd()
|
||||
|
||||
// Reject impossible counters at the boundary instead of recording them:
|
||||
// a negative window, negative tokens, or a negative / non-finite cost
|
||||
// would otherwise decrement or poison the persisted consumption totals.
|
||||
if req.GetWindowSeconds() < 0 || tokensIn < 0 || tokensOut < 0 || costUSD < 0 || math.IsNaN(costUSD) || math.IsInf(costUSD, 0) {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "usage counters must be non-negative and finite")
|
||||
}
|
||||
|
||||
// Book the policy-window dimensions (when a policy cap bound this request)
|
||||
// and every applicable account budget rule's window in a single batched
|
||||
// transaction.
|
||||
if err := svc.RecordUsage(ctx, agentnetwork.RecordUsageInput{
|
||||
AccountID: accountID,
|
||||
UserID: req.GetUserId(),
|
||||
AttributionGroupID: req.GetGroupId(),
|
||||
GroupIDs: req.GetGroupIds(),
|
||||
WindowSeconds: req.GetWindowSeconds(),
|
||||
TokensIn: tokensIn,
|
||||
TokensOut: tokensOut,
|
||||
CostUSD: costUSD,
|
||||
}); err != nil {
|
||||
log.WithContext(ctx).Errorf("record usage: %v", err)
|
||||
return nil, status.Error(codes.Internal, "record usage failed")
|
||||
}
|
||||
return &proto.RecordLLMUsageResponse{}, nil
|
||||
}
|
||||
|
||||
// SetProxyController sets the proxy controller. Must be called before serving.
|
||||
func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) {
|
||||
s.mu.Lock()
|
||||
@@ -623,12 +771,40 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
return nil, fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
|
||||
if synth := s.agentNetworkSynthesizer(); synth != nil {
|
||||
var synthesised []*rpservice.Service
|
||||
var serr error
|
||||
// Account-scoped connections synthesise only their own account, so the
|
||||
// snapshot can never carry another tenant's mappings (which embed the
|
||||
// upstream auth header derived from that tenant's provider API key).
|
||||
// Global connections still see the whole cluster.
|
||||
if conn.accountID != nil {
|
||||
synthesised, serr = synth.SynthesizeServicesForAccount(ctx, *conn.accountID)
|
||||
} else {
|
||||
synthesised, serr = synth.SynthesizeServicesForCluster(ctx, conn.address)
|
||||
}
|
||||
if serr != nil {
|
||||
// Surface a real synthesis failure instead of silently shipping an
|
||||
// incomplete snapshot (which would drop the account's agent-network
|
||||
// routes). Consistent with the persisted-services error above; the
|
||||
// proxy retries the snapshot on connection error.
|
||||
return nil, fmt.Errorf("synthesise agent-network services: %w", serr)
|
||||
}
|
||||
services = append(services, synthesised...)
|
||||
}
|
||||
|
||||
oidcCfg := s.GetOIDCValidationConfig()
|
||||
var mappings []*proto.ProxyMapping
|
||||
for _, service := range services {
|
||||
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
||||
continue
|
||||
}
|
||||
// Defense in depth: an account-scoped proxy must never receive another
|
||||
// account's mapping, matching the per-account filtering the incremental
|
||||
// update path already applies.
|
||||
if conn.accountID != nil && service.AccountID != *conn.accountID {
|
||||
continue
|
||||
}
|
||||
|
||||
m := service.ToProtoMapping(rpservice.Create, "", oidcCfg)
|
||||
if !proxyAcceptsMapping(conn, m) {
|
||||
@@ -1617,7 +1793,29 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||
return s.serviceManager.GetServiceByDomain(ctx, domain)
|
||||
service, err := s.serviceManager.GetServiceByDomain(ctx, domain)
|
||||
if err == nil {
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// Fall back to the Agent Network synthesiser scoped directly to the domain's
|
||||
// account. Synthesised services are never persisted, so they must resolve
|
||||
// here for OIDC / session / tunnel-peer flows against agent-network
|
||||
// endpoints. Resolving by domain synthesises only the owning account rather
|
||||
// than every tenant on the cluster.
|
||||
if synth := s.agentNetworkSynthesizer(); synth != nil {
|
||||
svc, serr := synth.SynthesizeServiceForDomain(ctx, domain)
|
||||
if serr != nil {
|
||||
// A real synthesis failure must surface, not be masked by the
|
||||
// original store miss — otherwise a transient DB error looks like
|
||||
// "no such service".
|
||||
return nil, fmt.Errorf("synthesize agent-network service for %s: %w", domain, serr)
|
||||
}
|
||||
if svc != nil {
|
||||
return svc, nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
|
||||
|
||||
@@ -245,6 +245,37 @@ const (
|
||||
// tunnel. Distinct from UserLoggedInPeer (full interactive login).
|
||||
UserExtendedPeerSession Activity = 125
|
||||
|
||||
// AgentNetworkProviderCreated indicates that a user created an Agent Network provider
|
||||
AgentNetworkProviderCreated Activity = 126
|
||||
// AgentNetworkProviderUpdated indicates that a user updated an Agent Network provider
|
||||
AgentNetworkProviderUpdated Activity = 127
|
||||
// AgentNetworkProviderDeleted indicates that a user deleted an Agent Network provider
|
||||
AgentNetworkProviderDeleted Activity = 128
|
||||
|
||||
// AgentNetworkPolicyCreated indicates that a user created an Agent Network policy
|
||||
AgentNetworkPolicyCreated Activity = 129
|
||||
// AgentNetworkPolicyUpdated indicates that a user updated an Agent Network policy
|
||||
AgentNetworkPolicyUpdated Activity = 130
|
||||
// AgentNetworkPolicyDeleted indicates that a user deleted an Agent Network policy
|
||||
AgentNetworkPolicyDeleted Activity = 131
|
||||
|
||||
// AgentNetworkGuardrailCreated indicates that a user created an Agent Network guardrail
|
||||
AgentNetworkGuardrailCreated Activity = 132
|
||||
// AgentNetworkGuardrailUpdated indicates that a user updated an Agent Network guardrail
|
||||
AgentNetworkGuardrailUpdated Activity = 133
|
||||
// AgentNetworkGuardrailDeleted indicates that a user deleted an Agent Network guardrail
|
||||
AgentNetworkGuardrailDeleted Activity = 134
|
||||
|
||||
// AgentNetworkBudgetRuleCreated indicates that a user created an Agent Network budget rule
|
||||
AgentNetworkBudgetRuleCreated Activity = 135
|
||||
// AgentNetworkBudgetRuleUpdated indicates that a user updated an Agent Network budget rule
|
||||
AgentNetworkBudgetRuleUpdated Activity = 136
|
||||
// AgentNetworkBudgetRuleDeleted indicates that a user deleted an Agent Network budget rule
|
||||
AgentNetworkBudgetRuleDeleted Activity = 137
|
||||
|
||||
// AgentNetworkSettingsUpdated indicates that a user updated Agent Network account settings
|
||||
AgentNetworkSettingsUpdated Activity = 139
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -400,6 +431,24 @@ var activityMap = map[Activity]Code{
|
||||
|
||||
UserExtendedPeerSession: {"User extended peer session", "user.peer.session.extend"},
|
||||
|
||||
AgentNetworkProviderCreated: {"Agent Network provider created", "agent_network.provider.create"},
|
||||
AgentNetworkProviderUpdated: {"Agent Network provider updated", "agent_network.provider.update"},
|
||||
AgentNetworkProviderDeleted: {"Agent Network provider deleted", "agent_network.provider.delete"},
|
||||
|
||||
AgentNetworkPolicyCreated: {"Agent Network policy created", "agent_network.policy.create"},
|
||||
AgentNetworkPolicyUpdated: {"Agent Network policy updated", "agent_network.policy.update"},
|
||||
AgentNetworkPolicyDeleted: {"Agent Network policy deleted", "agent_network.policy.delete"},
|
||||
|
||||
AgentNetworkGuardrailCreated: {"Agent Network guardrail created", "agent_network.guardrail.create"},
|
||||
AgentNetworkGuardrailUpdated: {"Agent Network guardrail updated", "agent_network.guardrail.update"},
|
||||
AgentNetworkGuardrailDeleted: {"Agent Network guardrail deleted", "agent_network.guardrail.delete"},
|
||||
|
||||
AgentNetworkBudgetRuleCreated: {"Agent Network budget rule created", "agent_network.budget_rule.create"},
|
||||
AgentNetworkBudgetRuleUpdated: {"Agent Network budget rule updated", "agent_network.budget_rule.update"},
|
||||
AgentNetworkBudgetRuleDeleted: {"Agent Network budget rule deleted", "agent_network.budget_rule.delete"},
|
||||
|
||||
AgentNetworkSettingsUpdated: {"Agent Network settings updated", "agent_network.settings.update"},
|
||||
|
||||
DomainAdded: {"Domain added", "domain.add"},
|
||||
DomainDeleted: {"Domain deleted", "domain.delete"},
|
||||
DomainValidated: {"Domain validated", "domain.validate"},
|
||||
|
||||
95
management/server/affectedpeers/proxy_synth_test.go
Normal file
95
management/server/affectedpeers/proxy_synth_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package affectedpeers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// fakeProxyStore implements only the two store methods loadProxyServices calls;
|
||||
// the embedded nil store.Store panics if anything else is invoked, which keeps
|
||||
// the test honest about the surface under test.
|
||||
type fakeProxyStore struct {
|
||||
store.Store
|
||||
proxyByCluster map[string][]string
|
||||
persisted []*rpservice.Service
|
||||
}
|
||||
|
||||
func (f *fakeProxyStore) GetEmbeddedProxyPeerIDsByCluster(_ context.Context, _ string) (map[string][]string, error) {
|
||||
return f.proxyByCluster, nil
|
||||
}
|
||||
|
||||
func (f *fakeProxyStore) GetAccountServices(_ context.Context, _ store.LockingStrength, _ string) ([]*rpservice.Service, error) {
|
||||
return f.persisted, nil
|
||||
}
|
||||
|
||||
func serviceIDs(svcs []*rpservice.Service) []string {
|
||||
ids := make([]string, 0, len(svcs))
|
||||
for _, s := range svcs {
|
||||
ids = append(ids, s.ID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// loadProxyServices must merge the synthesised agent-network services (which are
|
||||
// never persisted) with the persisted ones, so the proxy-affected expansion can
|
||||
// see agent-network AccessGroups. Without this the embedded proxy peer is never
|
||||
// flagged on a client group change and only a full resync (restart) recovers.
|
||||
func TestLoadProxyServices_MergesSynthesizedAgentNetworkServices(t *testing.T) {
|
||||
prev := agentNetworkSynthesizer
|
||||
t.Cleanup(func() { agentNetworkSynthesizer = prev })
|
||||
SetAgentNetworkSynthesizer(func(_ context.Context, _ store.Store, _ string) ([]*rpservice.Service, error) {
|
||||
return []*rpservice.Service{
|
||||
{ID: "agent-net-svc-acc", ProxyCluster: "proxy.netbird.local", Private: true, AccessGroups: []string{"gB"}},
|
||||
}, nil
|
||||
})
|
||||
|
||||
s := &fakeProxyStore{
|
||||
proxyByCluster: map[string][]string{"proxy.netbird.local": {"proxy-peer-1"}},
|
||||
persisted: []*rpservice.Service{{ID: "persisted-rp-svc", ProxyCluster: "proxy.netbird.local"}},
|
||||
}
|
||||
snap := &Snapshot{}
|
||||
require.NoError(t, snap.loadProxyServices(context.Background(), s, "acc"))
|
||||
|
||||
ids := serviceIDs(snap.services)
|
||||
assert.Contains(t, ids, "persisted-rp-svc", "persisted services must be kept")
|
||||
assert.Contains(t, ids, "agent-net-svc-acc", "synthesised agent-network service must be merged in")
|
||||
}
|
||||
|
||||
// With no synthesiser registered, loadProxyServices falls back to persisted
|
||||
// services only (no panic, no behaviour change for non-agent-network builds).
|
||||
func TestLoadProxyServices_NoSynthesizerRegistered(t *testing.T) {
|
||||
prev := agentNetworkSynthesizer
|
||||
t.Cleanup(func() { agentNetworkSynthesizer = prev })
|
||||
agentNetworkSynthesizer = nil
|
||||
|
||||
s := &fakeProxyStore{
|
||||
proxyByCluster: map[string][]string{"c": {"proxy-1"}},
|
||||
persisted: []*rpservice.Service{{ID: "persisted"}},
|
||||
}
|
||||
snap := &Snapshot{}
|
||||
require.NoError(t, snap.loadProxyServices(context.Background(), s, "acc"))
|
||||
assert.Equal(t, []string{"persisted"}, serviceIDs(snap.services))
|
||||
}
|
||||
|
||||
// No embedded proxy peers → skip entirely (don't even call the synthesiser).
|
||||
func TestLoadProxyServices_NoEmbeddedProxyPeersSkips(t *testing.T) {
|
||||
prev := agentNetworkSynthesizer
|
||||
t.Cleanup(func() { agentNetworkSynthesizer = prev })
|
||||
called := false
|
||||
SetAgentNetworkSynthesizer(func(_ context.Context, _ store.Store, _ string) ([]*rpservice.Service, error) {
|
||||
called = true
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
s := &fakeProxyStore{proxyByCluster: map[string][]string{}}
|
||||
snap := &Snapshot{}
|
||||
require.NoError(t, snap.loadProxyServices(context.Background(), s, "acc"))
|
||||
assert.False(t, called, "synthesiser must not run for accounts without embedded proxy peers")
|
||||
assert.Empty(t, snap.services)
|
||||
}
|
||||
@@ -29,6 +29,19 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// agentNetworkSynthesizer returns the account's synthesised (never-persisted)
|
||||
// agent-network reverse-proxy services. It is registered at boot via
|
||||
// SetAgentNetworkSynthesizer to avoid an import cycle (agentnetwork → account →
|
||||
// affectedpeers). nil when agent-network is not wired, in which case only
|
||||
// persisted services are considered.
|
||||
var agentNetworkSynthesizer func(ctx context.Context, s store.Store, accountID string) ([]*rpservice.Service, error)
|
||||
|
||||
// SetAgentNetworkSynthesizer registers the agent-network service synthesiser.
|
||||
// Called once during boot, before any request is served.
|
||||
func SetAgentNetworkSynthesizer(fn func(ctx context.Context, s store.Store, accountID string) ([]*rpservice.Service, error)) {
|
||||
agentNetworkSynthesizer = fn
|
||||
}
|
||||
|
||||
// Snapshot is an in-memory view of the collections needed to expand a Change.
|
||||
// Loaded in-tx, walked by Expand after commit. Only the collections the Change
|
||||
// can touch are loaded; the rest stay nil (see Load).
|
||||
@@ -124,7 +137,12 @@ func (snap *Snapshot) loadDNS(ctx context.Context, s store.Store, accountID stri
|
||||
}
|
||||
|
||||
// loadProxyServices loads the embedded-proxy cluster index, and the services only
|
||||
// when the account actually has embedded proxy peers.
|
||||
// when the account actually has embedded proxy peers. Both the persisted
|
||||
// reverse-proxy services and the synthesised agent-network services are loaded:
|
||||
// agent-network services are never persisted, so without synthesising them here
|
||||
// collectFromProxyServices can't fold the embedded proxy peer into the affected
|
||||
// set when a client's group changes, and the proxy never learns a newly
|
||||
// authorised client until it reconnects (full network-map resync).
|
||||
func (snap *Snapshot) loadProxyServices(ctx context.Context, s store.Store, accountID string) error {
|
||||
var err error
|
||||
if snap.proxyByCluster, err = s.GetEmbeddedProxyPeerIDsByCluster(ctx, accountID); err != nil {
|
||||
@@ -133,8 +151,21 @@ func (snap *Snapshot) loadProxyServices(ctx context.Context, s store.Store, acco
|
||||
if len(snap.proxyByCluster) == 0 {
|
||||
return nil
|
||||
}
|
||||
snap.services, err = s.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
return err
|
||||
if snap.services, err = s.GetAccountServices(ctx, store.LockingStrengthNone, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
if agentNetworkSynthesizer == nil {
|
||||
return nil
|
||||
}
|
||||
synth, serr := agentNetworkSynthesizer(ctx, s, accountID)
|
||||
if serr != nil {
|
||||
// Non-fatal: fall back to persisted services. The next full
|
||||
// network-map resync still converges the proxy.
|
||||
log.WithContext(ctx).Warnf("affectedpeers: synthesise agent-network services for account %s: %v", accountID, serr)
|
||||
return nil
|
||||
}
|
||||
snap.services = append(snap.services, synth...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadGroupIndex loads all groups (for group.Resources) and builds the
|
||||
|
||||
126
management/server/agentnetwork_budgetrule_realstack_test.go
Normal file
126
management/server/agentnetwork_budgetrule_realstack_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
agenttypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// TestAgentNetwork_BudgetRuleCRUD_RealManager is the GC-1 no-mock guard for the
|
||||
// account budget-rule manager surface: real DefaultAccountManager, real store,
|
||||
// real permissions. It exercises create/get/list/update/delete through the
|
||||
// permission-gated manager (not the store directly) and asserts the reused
|
||||
// PolicyLimits cap shape and targets survive each step.
|
||||
func TestAgentNetwork_BudgetRuleCRUD_RealManager(t *testing.T) {
|
||||
am, _, err := createManager(t)
|
||||
require.NoError(t, err, "createManager must succeed")
|
||||
ctx := context.Background()
|
||||
|
||||
const (
|
||||
accountID = "agent-net-budget-acct"
|
||||
adminUserID = "agent-net-budget-admin"
|
||||
)
|
||||
account := newAccountWithId(ctx, accountID, adminUserID, "agent-net.test", "", "", false)
|
||||
require.NoError(t, am.Store.SaveAccount(ctx, account), "SaveAccount must succeed")
|
||||
|
||||
mgr := agentnetwork.NewManager(am.Store, permissions.NewManager(am.Store), am, nil)
|
||||
|
||||
created, err := mgr.CreateBudgetRule(ctx, adminUserID, &agenttypes.AccountBudgetRule{
|
||||
AccountID: accountID,
|
||||
Name: "eng-monthly",
|
||||
Enabled: true,
|
||||
TargetGroups: []string{"grp-eng"},
|
||||
TargetUsers: []string{"user-alice"},
|
||||
Limits: agenttypes.PolicyLimits{
|
||||
TokenLimit: agenttypes.PolicyTokenLimit{Enabled: true, GroupCap: 100_000, UserCap: 10_000, WindowSeconds: 2_592_000},
|
||||
BudgetLimit: agenttypes.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 500, WindowSeconds: 2_592_000},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err, "CreateBudgetRule must succeed")
|
||||
require.NotEmpty(t, created.ID, "create must mint an ID")
|
||||
|
||||
got, err := mgr.GetBudgetRule(ctx, accountID, adminUserID, created.ID)
|
||||
require.NoError(t, err, "GetBudgetRule must succeed")
|
||||
assert.Equal(t, "eng-monthly", got.Name, "name round-trips through the manager")
|
||||
assert.Equal(t, []string{"grp-eng"}, got.TargetGroups, "target groups round-trip")
|
||||
assert.Equal(t, int64(100_000), got.Limits.TokenLimit.GroupCap, "token group cap round-trips")
|
||||
|
||||
list, err := mgr.GetAllBudgetRules(ctx, accountID, adminUserID)
|
||||
require.NoError(t, err, "GetAllBudgetRules must succeed")
|
||||
require.Len(t, list, 1, "exactly the one created rule must be listed")
|
||||
|
||||
created.Limits.TokenLimit.GroupCap = 200_000
|
||||
updated, err := mgr.UpdateBudgetRule(ctx, adminUserID, created)
|
||||
require.NoError(t, err, "UpdateBudgetRule must succeed")
|
||||
assert.Equal(t, int64(200_000), updated.Limits.TokenLimit.GroupCap, "updated cap must persist")
|
||||
|
||||
require.NoError(t, mgr.DeleteBudgetRule(ctx, accountID, adminUserID, created.ID), "DeleteBudgetRule must succeed")
|
||||
_, err = mgr.GetBudgetRule(ctx, accountID, adminUserID, created.ID)
|
||||
assert.Error(t, err, "get after delete must fail")
|
||||
}
|
||||
|
||||
// TestAgentNetwork_UpdateSettings_PreservesImmutableAndTogglesCollection is the
|
||||
// GC-1 guard for UpdateSettings: it must apply the collection toggles while
|
||||
// preserving the immutable Cluster/Subdomain pinned at bootstrap.
|
||||
func TestAgentNetwork_UpdateSettings_PreservesImmutableAndTogglesCollection(t *testing.T) {
|
||||
am, _, err := createManager(t)
|
||||
require.NoError(t, err, "createManager must succeed")
|
||||
ctx := context.Background()
|
||||
|
||||
const (
|
||||
accountID = "agent-net-settings-acct"
|
||||
adminUserID = "agent-net-settings-admin"
|
||||
clusterAddr = "eu.proxy.netbird.io"
|
||||
)
|
||||
account := newAccountWithId(ctx, accountID, adminUserID, "agent-net.test", "", "", false)
|
||||
require.NoError(t, am.Store.SaveAccount(ctx, account), "SaveAccount must succeed")
|
||||
|
||||
mgr := agentnetwork.NewManager(am.Store, permissions.NewManager(am.Store), am, nil)
|
||||
|
||||
// Creating a provider bootstraps the settings row (cluster + subdomain).
|
||||
_, err = mgr.CreateProvider(ctx, adminUserID, &agenttypes.Provider{
|
||||
AccountID: accountID,
|
||||
ProviderID: "openai_api",
|
||||
Name: "openai",
|
||||
UpstreamURL: "https://api.openai.com",
|
||||
APIKey: "sk-test",
|
||||
Enabled: true,
|
||||
Models: []agenttypes.ProviderModel{{ID: "gpt-5.4"}},
|
||||
}, clusterAddr)
|
||||
require.NoError(t, err, "CreateProvider must bootstrap settings")
|
||||
|
||||
before, err := mgr.GetSettings(ctx, accountID, adminUserID)
|
||||
require.NoError(t, err, "GetSettings must succeed after bootstrap")
|
||||
require.Equal(t, clusterAddr, before.Cluster, "cluster pinned at bootstrap")
|
||||
require.NotEmpty(t, before.Subdomain, "subdomain pinned at bootstrap")
|
||||
assert.False(t, before.EnablePromptCollection, "prompt collection defaults off")
|
||||
|
||||
// Attempt to flip toggles AND smuggle a different cluster/subdomain — the
|
||||
// immutable fields must be ignored.
|
||||
updated, err := mgr.UpdateSettings(ctx, adminUserID, &agenttypes.Settings{
|
||||
AccountID: accountID,
|
||||
Cluster: "attacker.cluster",
|
||||
Subdomain: "evil",
|
||||
EnableLogCollection: true,
|
||||
EnablePromptCollection: true,
|
||||
RedactPii: true,
|
||||
})
|
||||
require.NoError(t, err, "UpdateSettings must succeed")
|
||||
assert.Equal(t, before.Cluster, updated.Cluster, "cluster is immutable and must be preserved")
|
||||
assert.Equal(t, before.Subdomain, updated.Subdomain, "subdomain is immutable and must be preserved")
|
||||
assert.True(t, updated.EnableLogCollection, "log collection toggle must apply")
|
||||
assert.True(t, updated.EnablePromptCollection, "prompt collection toggle must apply")
|
||||
assert.True(t, updated.RedactPii, "redact toggle must apply")
|
||||
|
||||
reloaded, err := am.Store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, before.Cluster, reloaded.Cluster, "persisted cluster unchanged")
|
||||
assert.True(t, reloaded.EnablePromptCollection, "persisted prompt collection toggled on")
|
||||
}
|
||||
199
management/server/agentnetwork_proxypeer_restart_test.go
Normal file
199
management/server/agentnetwork_proxypeer_restart_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
agenttypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// TestAgentNetwork_ProxyRestart_PropagatesNewPeerAndDropsStale is the no-mock
|
||||
// regression guard for the bug the user reported: restarting the proxy creates
|
||||
// a fresh embedded peer with a NEW WireGuard public key (the proxy generates
|
||||
// the keypair on every startup at proxy/internal/roundtrip/netbird.go:312).
|
||||
// The PRIOR embedded peer record is never deleted on management, so the
|
||||
// account accumulates a stale peer holding a stale CGNAT IP. Other peers
|
||||
// in the account either keep routing to the dead IP, or — if synth DNS
|
||||
// picks the wrong record — never see the new IP at all.
|
||||
//
|
||||
// What this test exercises (no mocks):
|
||||
// - real SQLite test store
|
||||
// - real DefaultAccountManager, network-map controller, peer-update channels
|
||||
// - real peers.Manager.CreateProxyPeer path (the very method the proxy
|
||||
// invokes over gRPC on every startup)
|
||||
// - real agentnetwork.Manager + synth chain so the client receives a
|
||||
// concrete DNS record that must point at the LATEST proxy peer.
|
||||
//
|
||||
// Pre-fix expected behavior (red): two embedded peers exist after the
|
||||
// "restart"; the synth DNS record points at the stale one; the client
|
||||
// receives an update reflecting the new peer but the old one lingers.
|
||||
// Post-fix expected behavior (green): exactly one embedded peer exists
|
||||
// after restart (with the new key) AND the client's network map carries
|
||||
// the synth DNS pointing at that new peer's CGNAT IP.
|
||||
func TestAgentNetwork_ProxyRestart_PropagatesNewPeerAndDropsStale(t *testing.T) {
|
||||
am, updateManager, err := createManager(t)
|
||||
require.NoError(t, err, "createManager must succeed")
|
||||
ctx := context.Background()
|
||||
|
||||
const (
|
||||
accountID = "an-restart-acct"
|
||||
adminUserID = "an-restart-admin"
|
||||
groupAID = "an-restart-grp-A"
|
||||
clusterAddr = "eu.proxy.netbird.io"
|
||||
clientKey = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8="
|
||||
// Two different proxy pubkeys — the "before" and "after" of a
|
||||
// proxy-process restart with fresh-keypair generation.
|
||||
proxyKey1 = "Aaaaa1aaaaYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8="
|
||||
proxyKey2 = "Bbbbb2bbbbYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8="
|
||||
)
|
||||
|
||||
// --- Account scaffold ---
|
||||
account := newAccountWithId(ctx, accountID, adminUserID, "an-restart.test", "", "", false)
|
||||
require.NoError(t, am.Store.SaveAccount(ctx, account))
|
||||
|
||||
clientPeer := &nbpeer.Peer{
|
||||
Key: clientKey,
|
||||
Name: "an-restart-client",
|
||||
DNSLabel: "an-restart-client",
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "an-restart-client", GoOS: "linux", WtVersion: "development"},
|
||||
}
|
||||
addedClient, _, _, _, err := am.AddPeer(ctx, "", "", adminUserID, clientPeer, false)
|
||||
require.NoError(t, err, "AddPeer for client must succeed")
|
||||
require.NoError(t, am.MarkPeerConnected(ctx, clientKey, accountID, time.Now().UnixNano(), &types.NetworkMap{}),
|
||||
"MarkPeerConnected for the client peer must succeed (affected-peer fan-out skips disconnected peers)")
|
||||
|
||||
// Place the client in group A so the synth policy reaches it.
|
||||
account, err = am.Store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
account.Groups[groupAID] = &types.Group{ID: groupAID, Name: "groupA", Peers: []string{addedClient.ID}}
|
||||
require.NoError(t, am.Store.SaveAccount(ctx, account), "SaveAccount must persist group A")
|
||||
|
||||
// --- Real peers + agent-network managers ---
|
||||
permMgr := permissions.NewManager(am.Store)
|
||||
peersMgr := peers.NewManager(am.Store, permMgr)
|
||||
peersMgr.SetAccountManager(am)
|
||||
peersMgr.SetNetworkMapController(am.networkMapController)
|
||||
agentMgr := agentnetwork.NewManager(am.Store, permMgr, am, nil)
|
||||
|
||||
// Subscribe BEFORE any state-mutating call so we don't lose the update
|
||||
// that contains the synth DNS record.
|
||||
clientCh := updateManager.CreateChannel(ctx, addedClient.ID)
|
||||
t.Cleanup(func() { updateManager.CloseChannel(ctx, addedClient.ID) })
|
||||
drain(clientCh)
|
||||
|
||||
// --- First proxy startup: register peer key K1, then mark it
|
||||
// connected. In production the proxy follows CreateProxyPeer with the
|
||||
// regular sync stream which lands on MarkPeerConnected; the synth DNS
|
||||
// path filters out peers that aren't Connected (types/account.go:323),
|
||||
// so without this step no DNS record would be emitted.
|
||||
require.NoError(t, peersMgr.CreateProxyPeer(ctx, accountID, proxyKey1, clusterAddr),
|
||||
"first CreateProxyPeer (proxy startup) must succeed")
|
||||
|
||||
peer1ID, err := am.Store.GetPeerIDByKey(ctx, store.LockingStrengthNone, proxyKey1)
|
||||
require.NoError(t, err, "proxy peer for K1 must be persisted after CreateProxyPeer")
|
||||
require.NotEmpty(t, peer1ID)
|
||||
|
||||
require.NoError(t, am.MarkPeerConnected(ctx, proxyKey1, accountID, time.Now().UnixNano(), &types.NetworkMap{}),
|
||||
"MarkPeerConnected for K1 must succeed")
|
||||
|
||||
account, err = am.Store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
proxyIP1 := account.Peers[peer1ID].IP.String()
|
||||
require.NotEmpty(t, proxyIP1, "K1 must have an assigned overlay IP")
|
||||
|
||||
// --- Provider + policy. CreateProvider / CreatePolicy trigger the
|
||||
// agentnetwork reconcile which runs UpdateAccountPeers; the resulting
|
||||
// NetworkMap delivered to the client carries the synth DNS record
|
||||
// pointing at K1's IP. ---
|
||||
provider, err := agentMgr.CreateProvider(ctx, adminUserID, &agenttypes.Provider{
|
||||
AccountID: accountID,
|
||||
ProviderID: "openai_api",
|
||||
Name: "openai-test",
|
||||
UpstreamURL: "https://api.openai.com",
|
||||
APIKey: "sk-test-key",
|
||||
Enabled: true,
|
||||
Models: []agenttypes.ProviderModel{{ID: "gpt-5.4"}},
|
||||
}, clusterAddr)
|
||||
require.NoError(t, err, "CreateProvider must succeed")
|
||||
|
||||
_, err = agentMgr.CreatePolicy(ctx, adminUserID, &agenttypes.Policy{
|
||||
AccountID: accountID,
|
||||
Name: "p1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{groupAID},
|
||||
DestinationProviderIDs: []string{provider.ID},
|
||||
})
|
||||
require.NoError(t, err, "CreatePolicy must succeed")
|
||||
|
||||
settings, err := am.Store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
fqdn := settings.Endpoint()
|
||||
|
||||
rdata1 := awaitZoneRData(clientCh, clusterAddr, fqdn, true)
|
||||
require.Equal(t, proxyIP1, rdata1,
|
||||
"client must receive a synth DNS record pointing at K1's overlay IP after the synth path runs")
|
||||
drain(clientCh)
|
||||
|
||||
// --- Proxy restart: NEW keypair K2, same account, same cluster ---
|
||||
require.NoError(t, peersMgr.CreateProxyPeer(ctx, accountID, proxyKey2, clusterAddr),
|
||||
"second CreateProxyPeer (proxy restart with fresh keypair) must succeed")
|
||||
|
||||
peer2ID, err := am.Store.GetPeerIDByKey(ctx, store.LockingStrengthNone, proxyKey2)
|
||||
require.NoError(t, err, "proxy peer for K2 must be persisted after restart")
|
||||
require.NotEmpty(t, peer2ID)
|
||||
|
||||
require.NoError(t, am.MarkPeerConnected(ctx, proxyKey2, accountID, time.Now().UnixNano(), &types.NetworkMap{}),
|
||||
"MarkPeerConnected for K2 must succeed")
|
||||
|
||||
// In production the agent's sync stream pulls a fresh NetworkMap as
|
||||
// part of its normal reconcile cadence; in this isolated test
|
||||
// MarkPeerConnected's affected-peer fan-out can race the channel-side
|
||||
// buffer in a way that swallows the synth-DNS-bearing update before
|
||||
// our await reads it. Trigger an explicit account-wide fan-out so the
|
||||
// assertion below tests what production actually delivers, not the
|
||||
// in-test buffer race.
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
|
||||
|
||||
account, err = am.Store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
proxyIP2 := account.Peers[peer2ID].IP.String()
|
||||
require.NotEmpty(t, proxyIP2, "K2 must have an assigned overlay IP")
|
||||
require.NotEqual(t, proxyIP1, proxyIP2, "K2 must get a different overlay IP than K1 (sanity)")
|
||||
|
||||
// CRITICAL ASSERTION 1: K1 must no longer be in the store. The SqlStore
|
||||
// returns ("", nil) for a missing key rather than NotFound, so assert
|
||||
// on the returned ID being empty.
|
||||
staleID, err := am.Store.GetPeerIDByKey(ctx, store.LockingStrengthNone, proxyKey1)
|
||||
require.NoError(t, err, "GetPeerIDByKey for a missing peer must not error")
|
||||
assert.Empty(t, staleID,
|
||||
"stale embedded proxy peer K1 must be removed when a new embedded peer registers for the same (account, cluster); pre-fix this assertion fails because management never cleans up the prior peer record")
|
||||
|
||||
// CRITICAL ASSERTION 2: exactly one embedded proxy peer remains, and it
|
||||
// is K2.
|
||||
account, err = am.Store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
embeddedKeys := []string{}
|
||||
for _, p := range account.Peers {
|
||||
if p.ProxyMeta.Embedded {
|
||||
embeddedKeys = append(embeddedKeys, p.Key)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, []string{proxyKey2}, embeddedKeys,
|
||||
"after a proxy restart exactly one embedded proxy peer should remain — the one with the new key K2")
|
||||
|
||||
// CRITICAL ASSERTION 3: the synth DNS record the client receives now
|
||||
// points at K2's IP, not K1's.
|
||||
rdata2 := awaitZoneRData(clientCh, clusterAddr, fqdn, true)
|
||||
assert.Equal(t, proxyIP2, rdata2,
|
||||
"after proxy restart, the client's synth DNS record must point at the NEW embedded peer's IP, not the stale K1 IP")
|
||||
}
|
||||
212
management/server/agentnetwork_realstack_test.go
Normal file
212
management/server/agentnetwork_realstack_test.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
networkmap "github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
agenttypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbproto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// TestAgentNetwork_ProviderCRUD_FansOutToProxyAndClientPeers is the no-mock
|
||||
// integration test for the live propagation path: a provider/policy mutation
|
||||
// through the real agentnetwork.Manager triggers the real
|
||||
// DefaultAccountManager.UpdateAccountPeers, which runs the real network-map
|
||||
// controller (including AN-2b's injectAllProxyPolicies), and a network map is
|
||||
// computed and fanned out to BOTH the embedded proxy peer and the client peer.
|
||||
//
|
||||
// Unlike the synthesizer/reconcile unit tests, nothing here is mocked: real
|
||||
// SQLite store, real account manager + network-map controller, real
|
||||
// agentnetwork manager, real peer update channels. The client peer's delivered
|
||||
// map is asserted to actually carry the synth DNS surface, and provider
|
||||
// create/delete are exercised end to end.
|
||||
func TestAgentNetwork_ProviderCRUD_FansOutToProxyAndClientPeers(t *testing.T) {
|
||||
am, updateManager, err := createManager(t)
|
||||
require.NoError(t, err, "createManager must succeed")
|
||||
ctx := context.Background()
|
||||
|
||||
const (
|
||||
accountID = "agent-net-acct-1"
|
||||
adminUserID = "agent-net-admin-1"
|
||||
groupAID = "agent-net-grp-A"
|
||||
clusterAddr = "eu.proxy.netbird.io"
|
||||
clientKey = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8="
|
||||
proxyPeerID = "agent-net-proxy-peer-1"
|
||||
proxyPeerKey = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI="
|
||||
proxyIP = "100.64.0.99"
|
||||
)
|
||||
|
||||
account := newAccountWithId(ctx, accountID, adminUserID, "agent-net.test", "", "", false)
|
||||
require.NoError(t, am.Store.SaveAccount(ctx, account), "SaveAccount must succeed")
|
||||
|
||||
// Real client peer through the production AddPeer path.
|
||||
clientPeer := &nbpeer.Peer{
|
||||
Key: clientKey,
|
||||
Name: "agent-net-client",
|
||||
DNSLabel: "agent-net-client",
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "agent-net-client", GoOS: "linux", WtVersion: "development"},
|
||||
}
|
||||
addedClient, _, _, _, err := am.AddPeer(ctx, "", "", adminUserID, clientPeer, false)
|
||||
require.NoError(t, err, "AddPeer must add the client peer")
|
||||
|
||||
// Inject a connected embedded proxy peer + put the client in the source group.
|
||||
account, err = am.Store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
account.Peers[proxyPeerID] = &nbpeer.Peer{
|
||||
ID: proxyPeerID,
|
||||
AccountID: accountID,
|
||||
Key: proxyPeerKey,
|
||||
IP: netip.MustParseAddr(proxyIP),
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: clusterAddr},
|
||||
DNSLabel: "agent-net-proxy",
|
||||
}
|
||||
account.Groups[groupAID] = &types.Group{ID: groupAID, Name: "groupA", Peers: []string{addedClient.ID}}
|
||||
require.NoError(t, am.Store.SaveAccount(ctx, account), "SaveAccount must persist proxy peer + group")
|
||||
|
||||
// Subscribe to BOTH peers' update channels — this is how we observe the
|
||||
// real fan-out.
|
||||
clientCh := updateManager.CreateChannel(ctx, addedClient.ID)
|
||||
proxyCh := updateManager.CreateChannel(ctx, proxyPeerID)
|
||||
t.Cleanup(func() {
|
||||
updateManager.CloseChannel(ctx, addedClient.ID)
|
||||
updateManager.CloseChannel(ctx, proxyPeerID)
|
||||
})
|
||||
drain(clientCh)
|
||||
drain(proxyCh)
|
||||
|
||||
// Real agentnetwork manager wired to the real account manager. proxyController
|
||||
// is nil (no gRPC cluster fan-out here) — the reconcile still fires
|
||||
// UpdateAccountPeers, which is the path under test.
|
||||
agentMgr := agentnetwork.NewManager(am.Store, permissions.NewManager(am.Store), am, nil)
|
||||
|
||||
provider, err := agentMgr.CreateProvider(ctx, adminUserID, &agenttypes.Provider{
|
||||
AccountID: accountID,
|
||||
ProviderID: "openai_api",
|
||||
Name: "openai-test",
|
||||
UpstreamURL: "https://api.openai.com",
|
||||
APIKey: "sk-test-key",
|
||||
Enabled: true,
|
||||
Models: []agenttypes.ProviderModel{{ID: "gpt-5.4"}},
|
||||
}, clusterAddr)
|
||||
require.NoError(t, err, "CreateProvider must succeed")
|
||||
|
||||
policy, err := agentMgr.CreatePolicy(ctx, adminUserID, &agenttypes.Policy{
|
||||
AccountID: accountID,
|
||||
Name: "p1",
|
||||
Enabled: true,
|
||||
SourceGroups: []string{groupAID},
|
||||
DestinationProviderIDs: []string{provider.ID},
|
||||
})
|
||||
require.NoError(t, err, "CreatePolicy must succeed")
|
||||
|
||||
settings, err := am.Store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
fqdn := settings.Endpoint()
|
||||
|
||||
// Both peers must receive a fan-out. The provider-create reconcile fires
|
||||
// before the policy exists (synth service then has no AccessGroups, so no
|
||||
// zone), and the async update buffer can collapse/reorder updates — so we
|
||||
// poll until the client's delivered map actually carries the synth record.
|
||||
rdata := awaitZoneRData(clientCh, clusterAddr, fqdn, true)
|
||||
assert.Equal(t, proxyIP, rdata,
|
||||
"client peer's delivered network map must contain the synth DNS record pointing at the embedded proxy peer")
|
||||
require.True(t, awaitUpdate(proxyCh), "embedded proxy peer must also receive a netmap update after create")
|
||||
|
||||
// UPDATE the provider — a new model on the existing service must still
|
||||
// reconcile and keep the private surface routable (the live MODIFIED path).
|
||||
provider.Models = append(provider.Models, agenttypes.ProviderModel{ID: "gpt-5.4-mini"})
|
||||
_, err = agentMgr.UpdateProvider(ctx, adminUserID, provider)
|
||||
require.NoError(t, err, "UpdateProvider must succeed")
|
||||
assert.Equal(t, proxyIP, awaitZoneRData(clientCh, clusterAddr, fqdn, true),
|
||||
"client peer must still resolve the synth record after the provider is updated")
|
||||
require.True(t, awaitUpdate(proxyCh), "embedded proxy peer must also receive a netmap update after update")
|
||||
|
||||
// DELETE: detach the policy first (provider is in use), then drop the
|
||||
// provider. Both peers update again and the synth surface disappears.
|
||||
require.NoError(t, agentMgr.DeletePolicy(ctx, accountID, adminUserID, policy.ID), "DeletePolicy must succeed")
|
||||
require.NoError(t, agentMgr.DeleteProvider(ctx, accountID, adminUserID, provider.ID), "DeleteProvider must succeed")
|
||||
|
||||
require.True(t, awaitUpdate(proxyCh), "embedded proxy peer must also receive a netmap update after delete")
|
||||
assert.Empty(t, awaitZoneRData(clientCh, clusterAddr, fqdn, false),
|
||||
"synth DNS record must be gone from the client's map after the provider is deleted")
|
||||
}
|
||||
|
||||
// awaitZoneRData drains the channel for up to 8s. When wantPresent is true it
|
||||
// returns as soon as the synth record appears (its RData). When false it drains
|
||||
// to quiescence and returns the RData of the last delivered map (expected empty
|
||||
// once the provider is gone), tolerating stale buffered updates that still
|
||||
// carry the zone.
|
||||
func awaitZoneRData(ch <-chan *networkmap.UpdateMessage, clusterAddr, fqdn string, wantPresent bool) string {
|
||||
deadline := time.After(8 * time.Second)
|
||||
last := ""
|
||||
for {
|
||||
select {
|
||||
case m := <-ch:
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
last = synthZoneRData(m.Update, clusterAddr, fqdn)
|
||||
if wantPresent && last != "" {
|
||||
return last
|
||||
}
|
||||
case <-time.After(750 * time.Millisecond):
|
||||
return last
|
||||
case <-deadline:
|
||||
return last
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// awaitUpdate reports whether at least one update arrives within the window.
|
||||
func awaitUpdate(ch <-chan *networkmap.UpdateMessage) bool {
|
||||
select {
|
||||
case m := <-ch:
|
||||
return m != nil
|
||||
case <-time.After(5 * time.Second):
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// drain empties any buffered updates (e.g. from AddPeer/SaveAccount) so the
|
||||
// next observation reflects the operation under test.
|
||||
func drain(ch <-chan *networkmap.UpdateMessage) {
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// synthZoneRData returns the RData of the synth A record (record name == fqdn)
|
||||
// inside the cluster's custom zone, or "" when absent.
|
||||
func synthZoneRData(sync *nbproto.SyncResponse, clusterAddr, fqdn string) string {
|
||||
if sync == nil {
|
||||
return ""
|
||||
}
|
||||
for _, zone := range sync.GetNetworkMap().GetDNSConfig().GetCustomZones() {
|
||||
if zone.GetDomain() != dns.Fqdn(clusterAddr) {
|
||||
continue
|
||||
}
|
||||
for _, rec := range zone.GetRecords() {
|
||||
if rec.GetName() == dns.Fqdn(fqdn) {
|
||||
return rec.GetRData()
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -23,6 +23,8 @@ import (
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/agentnetwork"
|
||||
agentnetworkhandlers "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/handlers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
@@ -59,7 +61,7 @@ import (
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc, agentNetworkManager agentnetwork.Manager) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -124,6 +126,9 @@ func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager accou
|
||||
zonesManager.RegisterEndpoints(router, zManager)
|
||||
recordsManager.RegisterEndpoints(router, rManager)
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
if agentNetworkManager != nil {
|
||||
agentnetworkhandlers.RegisterEndpoints(agentNetworkManager, router)
|
||||
}
|
||||
instance.AddEndpoints(instanceManager, accountManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router)
|
||||
if serviceManager != nil && reverseProxyDomainManager != nil {
|
||||
|
||||
@@ -137,7 +137,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -267,7 +267,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ const (
|
||||
Pats Module = "pats"
|
||||
IdentityProviders Module = "identity_providers"
|
||||
Services Module = "services"
|
||||
AgentNetwork Module = "agent_network"
|
||||
)
|
||||
|
||||
var All = map[Module]struct{}{
|
||||
@@ -38,4 +39,5 @@ var All = map[Module]struct{}{
|
||||
Pats: {},
|
||||
IdentityProviders: {},
|
||||
Services: {},
|
||||
AgentNetwork: {},
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -137,6 +138,10 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
|
||||
&agentNetworkTypes.Provider{}, &agentNetworkTypes.Policy{}, &agentNetworkTypes.Guardrail{}, &agentNetworkTypes.Settings{},
|
||||
&agentNetworkTypes.Consumption{}, &agentNetworkTypes.AccountBudgetRule{},
|
||||
&agentNetworkTypes.AgentNetworkAccessLog{}, &agentNetworkTypes.AgentNetworkAccessLogGroup{},
|
||||
&agentNetworkTypes.AgentNetworkUsage{}, &agentNetworkTypes.AgentNetworkUsageGroup{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -5573,6 +5578,255 @@ func (s *SqlStore) CreateAccessLog(ctx context.Context, logEntry *accesslogs.Acc
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateAgentNetworkAccessLog persists a flattened agent-network access-log
|
||||
// entry together with its authorising-group child rows in a single
|
||||
// transaction.
|
||||
func (s *SqlStore) CreateAgentNetworkAccessLog(ctx context.Context, entry *agentNetworkTypes.AgentNetworkAccessLog, groups []agentNetworkTypes.AgentNetworkAccessLogGroup) error {
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Idempotent on the log id / (log_id, group_id) so a proxy resend of the
|
||||
// same entry can't fail the request.
|
||||
if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(entry).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if len(groups) > 0 {
|
||||
if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&groups).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"account_id": entry.AccountID,
|
||||
"service_id": entry.ServiceID,
|
||||
"model": entry.Model,
|
||||
}).Errorf("failed to create agent-network access log entry in store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to create agent-network access log entry in store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateAgentNetworkUsage persists a stripped agent-network usage record
|
||||
// together with its authorising-group child rows in a single transaction.
|
||||
func (s *SqlStore) CreateAgentNetworkUsage(ctx context.Context, usage *agentNetworkTypes.AgentNetworkUsage, groups []agentNetworkTypes.AgentNetworkUsageGroup) error {
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Idempotent on the usage id / (usage_id, group_id) so a proxy resend of
|
||||
// the same entry can't fail the request.
|
||||
if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(usage).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if len(groups) > 0 {
|
||||
if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&groups).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"account_id": usage.AccountID,
|
||||
"model": usage.Model,
|
||||
}).Errorf("failed to create agent-network usage record in store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to create agent-network usage record in store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteOldAgentNetworkAccessLogs deletes an account's access-log rows (and
|
||||
// their authorising-group child rows) older than the cutoff. Usage records are
|
||||
// untouched — they are the long-term aggregate. Returns the number of log rows
|
||||
// deleted.
|
||||
func (s *SqlStore) DeleteOldAgentNetworkAccessLogs(ctx context.Context, accountID string, olderThan time.Time) (int64, error) {
|
||||
var deleted int64
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Remove group child rows for the soon-to-be-deleted logs first.
|
||||
if err := tx.Exec(
|
||||
"DELETE FROM agent_network_access_log_group WHERE account_id = ? AND log_id IN (SELECT id FROM agent_network_access_log WHERE account_id = ? AND timestamp < ?)",
|
||||
accountID, accountID, olderThan,
|
||||
).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
res := tx.Where("account_id = ? AND timestamp < ?", accountID, olderThan).
|
||||
Delete(&agentNetworkTypes.AgentNetworkAccessLog{})
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
}
|
||||
deleted = res.RowsAffected
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete old agent-network access logs for account %s: %v", accountID, err)
|
||||
return 0, status.Errorf(status.Internal, "failed to delete old agent-network access logs")
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// GetAgentNetworkUsageRows returns the stripped usage rows for an account that
|
||||
// match the filter (date / user / group / provider / model). Aggregation into
|
||||
// time buckets happens in the manager so granularities stay engine-portable.
|
||||
func (s *SqlStore) GetAgentNetworkUsageRows(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkUsage, error) {
|
||||
var rows []*agentNetworkTypes.AgentNetworkUsage
|
||||
|
||||
query := s.applyAgentNetworkUsageFilters(
|
||||
s.db.Where(accountIDCondition, accountID),
|
||||
filter,
|
||||
).Order("timestamp ASC")
|
||||
|
||||
if lockStrength != LockingStrengthNone {
|
||||
query = query.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
if err := query.Find(&rows).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get agent-network usage rows from store: %v", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get agent-network usage rows from store")
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// applyAgentNetworkUsageFilters applies the shared access-log filter's
|
||||
// date/user/group/provider/model conditions to a usage-table query. Pagination,
|
||||
// sort and free-text search are ignored — the overview is an aggregate.
|
||||
func (s *SqlStore) applyAgentNetworkUsageFilters(query *gorm.DB, filter agentNetworkTypes.AgentNetworkAccessLogFilter) *gorm.DB {
|
||||
if filter.UserID != nil {
|
||||
query = query.Where("user_id = ?", *filter.UserID)
|
||||
}
|
||||
if filter.SessionID != nil {
|
||||
query = query.Where("session_id = ?", *filter.SessionID)
|
||||
}
|
||||
if len(filter.ProviderIDs) > 0 {
|
||||
query = query.Where("resolved_provider_id IN ?", filter.ProviderIDs)
|
||||
}
|
||||
if len(filter.Models) > 0 {
|
||||
query = query.Where("model IN ?", filter.Models)
|
||||
}
|
||||
if len(filter.GroupIDs) > 0 {
|
||||
query = query.Where(
|
||||
"id IN (SELECT usage_id FROM agent_network_request_usage_group WHERE group_id IN ?)",
|
||||
filter.GroupIDs,
|
||||
)
|
||||
}
|
||||
if filter.StartDate != nil {
|
||||
query = query.Where("timestamp >= ?", *filter.StartDate)
|
||||
}
|
||||
if filter.EndDate != nil {
|
||||
query = query.Where("timestamp <= ?", *filter.EndDate)
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// GetAgentNetworkAccessLogs retrieves flattened agent-network access logs for
|
||||
// an account with server-side pagination, filtering and sorting. Authorising
|
||||
// group ids are hydrated from the group child table for the returned page.
|
||||
func (s *SqlStore) GetAgentNetworkAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkAccessLog, int64, error) {
|
||||
var logs []*agentNetworkTypes.AgentNetworkAccessLog
|
||||
var totalCount int64
|
||||
|
||||
countQuery := s.applyAgentNetworkAccessLogFilters(
|
||||
s.db.Model(&agentNetworkTypes.AgentNetworkAccessLog{}).Where(accountIDCondition, accountID),
|
||||
filter,
|
||||
)
|
||||
if err := countQuery.Count(&totalCount).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to count agent-network access logs: %v", err)
|
||||
return nil, 0, status.Errorf(status.Internal, "failed to count agent-network access logs")
|
||||
}
|
||||
|
||||
query := s.applyAgentNetworkAccessLogFilters(
|
||||
s.db.Where(accountIDCondition, accountID),
|
||||
filter,
|
||||
).
|
||||
Order(filter.GetSortColumn() + " " + filter.GetSortOrder()).
|
||||
Limit(filter.GetLimit()).
|
||||
Offset(filter.GetOffset())
|
||||
|
||||
if lockStrength != LockingStrengthNone {
|
||||
query = query.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
if err := query.Find(&logs).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get agent-network access logs from store: %v", err)
|
||||
return nil, 0, status.Errorf(status.Internal, "failed to get agent-network access logs from store")
|
||||
}
|
||||
|
||||
if err := s.hydrateAgentNetworkAccessLogGroups(ctx, accountID, logs); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return logs, totalCount, nil
|
||||
}
|
||||
|
||||
// applyAgentNetworkAccessLogFilters applies the filter conditions to a query.
|
||||
func (s *SqlStore) applyAgentNetworkAccessLogFilters(query *gorm.DB, filter agentNetworkTypes.AgentNetworkAccessLogFilter) *gorm.DB {
|
||||
if filter.Search != nil {
|
||||
p := "%" + *filter.Search + "%"
|
||||
query = query.Where(
|
||||
"id LIKE ? OR host LIKE ? OR path LIKE ? OR model LIKE ? OR user_id IN (SELECT id FROM users WHERE email LIKE ? OR name LIKE ?)",
|
||||
p, p, p, p, p, p,
|
||||
)
|
||||
}
|
||||
if filter.UserID != nil {
|
||||
query = query.Where("user_id = ?", *filter.UserID)
|
||||
}
|
||||
if filter.SessionID != nil {
|
||||
query = query.Where("session_id = ?", *filter.SessionID)
|
||||
}
|
||||
if filter.Decision != nil {
|
||||
query = query.Where("decision = ?", *filter.Decision)
|
||||
}
|
||||
if filter.PathPrefix != nil {
|
||||
query = query.Where("path LIKE ?", *filter.PathPrefix+"%")
|
||||
}
|
||||
if len(filter.ProviderIDs) > 0 {
|
||||
query = query.Where("resolved_provider_id IN ?", filter.ProviderIDs)
|
||||
}
|
||||
if len(filter.Models) > 0 {
|
||||
query = query.Where("model IN ?", filter.Models)
|
||||
}
|
||||
if len(filter.GroupIDs) > 0 {
|
||||
query = query.Where(
|
||||
"id IN (SELECT log_id FROM agent_network_access_log_group WHERE group_id IN ?)",
|
||||
filter.GroupIDs,
|
||||
)
|
||||
}
|
||||
if filter.StartDate != nil {
|
||||
query = query.Where("timestamp >= ?", *filter.StartDate)
|
||||
}
|
||||
if filter.EndDate != nil {
|
||||
query = query.Where("timestamp <= ?", *filter.EndDate)
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// hydrateAgentNetworkAccessLogGroups loads the authorising group ids for the
|
||||
// given page of entries and assigns them onto each entry's GroupIDs field.
|
||||
func (s *SqlStore) hydrateAgentNetworkAccessLogGroups(ctx context.Context, accountID string, logs []*agentNetworkTypes.AgentNetworkAccessLog) error {
|
||||
if len(logs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(logs))
|
||||
for _, l := range logs {
|
||||
ids = append(ids, l.ID)
|
||||
}
|
||||
|
||||
var rows []agentNetworkTypes.AgentNetworkAccessLogGroup
|
||||
if err := s.db.
|
||||
Where(accountIDCondition, accountID).
|
||||
Where("log_id IN ?", ids).
|
||||
Find(&rows).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to hydrate agent-network access log groups: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to hydrate agent-network access log groups")
|
||||
}
|
||||
|
||||
byLog := make(map[string][]string, len(logs))
|
||||
for _, r := range rows {
|
||||
byLog[r.LogID] = append(byLog[r.LogID], r.GroupID)
|
||||
}
|
||||
for _, l := range logs {
|
||||
l.GroupIDs = byLog[l.ID]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountAccessLogs retrieves access logs for a given account with pagination and filtering
|
||||
func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
||||
var logs []*accesslogs.AccessLogEntry
|
||||
|
||||
623
management/server/store/sql_store_agentnetwork.go
Normal file
623
management/server/store/sql_store_agentnetwork.go
Normal file
@@ -0,0 +1,623 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// GetAllAgentNetworkProviders returns Agent Network providers across
|
||||
// every account. Used by the synthesizer to build the global service map.
|
||||
func (s *SqlStore) GetAllAgentNetworkProviders(ctx context.Context, lockStrength LockingStrength) ([]*agentNetworkTypes.Provider, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var providers []*agentNetworkTypes.Provider
|
||||
if result := tx.Find(&providers); result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get all agent network providers from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get all agent network providers from store")
|
||||
}
|
||||
|
||||
for _, provider := range providers {
|
||||
if err := provider.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to decrypt agent network provider %s: %v", provider.ID, err)
|
||||
return nil, status.Errorf(status.Internal, "failed to decrypt agent network provider")
|
||||
}
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountAgentNetworkProviders(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Provider, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var providers []*agentNetworkTypes.Provider
|
||||
result := tx.Find(&providers, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get agent network providers from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get agent network providers from store")
|
||||
}
|
||||
|
||||
for _, provider := range providers {
|
||||
if err := provider.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to decrypt agent network provider %s: %v", provider.ID, err)
|
||||
return nil, status.Errorf(status.Internal, "failed to decrypt agent network provider")
|
||||
}
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAgentNetworkProviderByID(ctx context.Context, lockStrength LockingStrength, accountID, providerID string) (*agentNetworkTypes.Provider, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var provider *agentNetworkTypes.Provider
|
||||
result := tx.Take(&provider, accountAndIDQueryCondition, accountID, providerID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAgentNetworkProviderNotFoundError(providerID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed to get agent network provider from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get agent network provider from store")
|
||||
}
|
||||
|
||||
if err := provider.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to decrypt agent network provider %s: %v", provider.ID, err)
|
||||
return nil, status.Errorf(status.Internal, "failed to decrypt agent network provider")
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SaveAgentNetworkProvider(ctx context.Context, provider *agentNetworkTypes.Provider) error {
|
||||
providerCopy := provider.Copy()
|
||||
if err := providerCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to encrypt agent network provider %s: %v", provider.ID, err)
|
||||
return status.Errorf(status.Internal, "failed to encrypt agent network provider")
|
||||
}
|
||||
|
||||
result := s.db.Save(providerCopy)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save agent network provider to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save agent network provider to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteAgentNetworkProvider(ctx context.Context, accountID, providerID string) error {
|
||||
result := s.db.Delete(&agentNetworkTypes.Provider{}, accountAndIDQueryCondition, accountID, providerID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete agent network provider from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete agent network provider from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewAgentNetworkProviderNotFoundError(providerID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountAgentNetworkPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Policy, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var policies []*agentNetworkTypes.Policy
|
||||
result := tx.Find(&policies, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get agent network policies from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get agent network policies from store")
|
||||
}
|
||||
|
||||
return policies, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAgentNetworkPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*agentNetworkTypes.Policy, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var policy *agentNetworkTypes.Policy
|
||||
result := tx.Take(&policy, accountAndIDQueryCondition, accountID, policyID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAgentNetworkPolicyNotFoundError(policyID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed to get agent network policy from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get agent network policy from store")
|
||||
}
|
||||
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SaveAgentNetworkPolicy(ctx context.Context, policy *agentNetworkTypes.Policy) error {
|
||||
result := s.db.Save(policy)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save agent network policy to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save agent network policy to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteAgentNetworkPolicy(ctx context.Context, accountID, policyID string) error {
|
||||
result := s.db.Delete(&agentNetworkTypes.Policy{}, accountAndIDQueryCondition, accountID, policyID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete agent network policy from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete agent network policy from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewAgentNetworkPolicyNotFoundError(policyID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountAgentNetworkGuardrails(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Guardrail, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var guardrails []*agentNetworkTypes.Guardrail
|
||||
result := tx.Find(&guardrails, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get agent network guardrails from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get agent network guardrails from store")
|
||||
}
|
||||
|
||||
return guardrails, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAgentNetworkGuardrailByID(ctx context.Context, lockStrength LockingStrength, accountID, guardrailID string) (*agentNetworkTypes.Guardrail, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var guardrail *agentNetworkTypes.Guardrail
|
||||
result := tx.Take(&guardrail, accountAndIDQueryCondition, accountID, guardrailID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAgentNetworkGuardrailNotFoundError(guardrailID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed to get agent network guardrail from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get agent network guardrail from store")
|
||||
}
|
||||
|
||||
return guardrail, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SaveAgentNetworkGuardrail(ctx context.Context, guardrail *agentNetworkTypes.Guardrail) error {
|
||||
result := s.db.Save(guardrail)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save agent network guardrail to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save agent network guardrail to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteAgentNetworkGuardrail(ctx context.Context, accountID, guardrailID string) error {
|
||||
result := s.db.Delete(&agentNetworkTypes.Guardrail{}, accountAndIDQueryCondition, accountID, guardrailID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete agent network guardrail from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete agent network guardrail from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewAgentNetworkGuardrailNotFoundError(guardrailID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAgentNetworkSettings returns the per-account Agent Network
|
||||
// settings row. Returns status.NotFound when no row exists.
|
||||
func (s *SqlStore) GetAgentNetworkSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*agentNetworkTypes.Settings, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var settings agentNetworkTypes.Settings
|
||||
result := tx.Take(&settings, "account_id = ?", accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "agent network settings for account %s not found", accountID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed to get agent network settings from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get agent network settings from store")
|
||||
}
|
||||
|
||||
return &settings, nil
|
||||
}
|
||||
|
||||
// GetAllAgentNetworkSettings returns every account's settings row. Used by the
|
||||
// access-log retention sweep to learn each account's retention window.
|
||||
func (s *SqlStore) GetAllAgentNetworkSettings(ctx context.Context, lockStrength LockingStrength) ([]*agentNetworkTypes.Settings, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var settings []*agentNetworkTypes.Settings
|
||||
if err := tx.Find(&settings).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list agent network settings: %v", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to list agent network settings")
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// GetAgentNetworkSettingsByCluster returns every Settings row pinned to
|
||||
// the given proxy cluster. Used by the bootstrap label generator to
|
||||
// build the set of subdomains already taken on a cluster.
|
||||
func (s *SqlStore) GetAgentNetworkSettingsByCluster(ctx context.Context, lockStrength LockingStrength, cluster string) ([]*agentNetworkTypes.Settings, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var settings []*agentNetworkTypes.Settings
|
||||
result := tx.Find(&settings, "cluster = ?", cluster)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get agent network settings by cluster from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get agent network settings by cluster from store")
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// SaveAgentNetworkSettings upserts the per-account Agent Network
|
||||
// settings row.
|
||||
func (s *SqlStore) SaveAgentNetworkSettings(ctx context.Context, settings *agentNetworkTypes.Settings) error {
|
||||
result := s.db.Save(settings)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save agent network settings to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save agent network settings to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IncrementAgentNetworkConsumption atomically upserts the consumption
|
||||
// row keyed on (account, dim_kind, dim_id, window_seconds, window_start)
|
||||
// and adds the supplied deltas. Concurrent calls from multiple proxy
|
||||
// nodes converge — the database performs the increment server-side via
|
||||
// ON CONFLICT DO UPDATE so no read-modify-write race exists.
|
||||
func (s *SqlStore) IncrementAgentNetworkConsumption(
|
||||
ctx context.Context,
|
||||
accountID string,
|
||||
kind agentNetworkTypes.ConsumptionDimension,
|
||||
dimID string,
|
||||
windowSeconds int64,
|
||||
windowStart time.Time,
|
||||
tokensIn, tokensOut int64,
|
||||
costUSD float64,
|
||||
) error {
|
||||
if accountID == "" || dimID == "" || windowSeconds <= 0 {
|
||||
return status.Errorf(status.InvalidArgument, "account_id, dim_id and window_seconds must be set")
|
||||
}
|
||||
// Deltas are added server-side via ON CONFLICT; a negative or non-finite
|
||||
// value would silently decrement / poison the persisted totals.
|
||||
if tokensIn < 0 || tokensOut < 0 || costUSD < 0 || math.IsNaN(costUSD) || math.IsInf(costUSD, 0) {
|
||||
return status.Errorf(status.InvalidArgument, "consumption deltas must be non-negative and finite")
|
||||
}
|
||||
row := agentNetworkTypes.Consumption{
|
||||
AccountID: accountID,
|
||||
DimensionKind: kind,
|
||||
DimensionID: dimID,
|
||||
WindowSeconds: windowSeconds,
|
||||
WindowStartUTC: windowStart.UTC(),
|
||||
TokensInput: tokensIn,
|
||||
TokensOutput: tokensOut,
|
||||
CostUSD: costUSD,
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
const tbl = "agent_network_consumption"
|
||||
err := s.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "account_id"},
|
||||
{Name: "dim_kind"},
|
||||
{Name: "dim_id"},
|
||||
{Name: "window_seconds"},
|
||||
{Name: "window_start_utc"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(map[string]any{
|
||||
"tokens_input": gorm.Expr(tbl+".tokens_input + ?", tokensIn),
|
||||
"tokens_output": gorm.Expr(tbl+".tokens_output + ?", tokensOut),
|
||||
"cost_usd": gorm.Expr(tbl+".cost_usd + ?", costUSD),
|
||||
"updated_at": time.Now().UTC(),
|
||||
}),
|
||||
}).Create(&row).Error
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to increment agent network consumption: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to increment agent network consumption")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAgentNetworkConsumption returns the consumption row for the exact
|
||||
// window key. Returns a zero-valued row (not found mapped to zero) so
|
||||
// callers can use the result as the headroom basis without nil checks.
|
||||
func (s *SqlStore) GetAgentNetworkConsumption(
|
||||
ctx context.Context,
|
||||
lockStrength LockingStrength,
|
||||
accountID string,
|
||||
kind agentNetworkTypes.ConsumptionDimension,
|
||||
dimID string,
|
||||
windowSeconds int64,
|
||||
windowStart time.Time,
|
||||
) (*agentNetworkTypes.Consumption, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
var row agentNetworkTypes.Consumption
|
||||
result := tx.Take(&row,
|
||||
"account_id = ? AND dim_kind = ? AND dim_id = ? AND window_seconds = ? AND window_start_utc = ?",
|
||||
accountID, kind, dimID, windowSeconds, windowStart.UTC())
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return &agentNetworkTypes.Consumption{
|
||||
AccountID: accountID,
|
||||
DimensionKind: kind,
|
||||
DimensionID: dimID,
|
||||
WindowSeconds: windowSeconds,
|
||||
WindowStartUTC: windowStart.UTC(),
|
||||
}, nil
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get agent network consumption: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get agent network consumption")
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// GetAgentNetworkConsumptionBatch reads many consumption counters for one
|
||||
// account in a single query, returning a map keyed by the exact
|
||||
// ConsumptionKey. Missing counters are simply absent from the map (callers
|
||||
// treat absence as a zero counter). Replaces the per-cap point reads the
|
||||
// policy selector previously issued one at a time.
|
||||
func (s *SqlStore) GetAgentNetworkConsumptionBatch(
|
||||
ctx context.Context,
|
||||
lockStrength LockingStrength,
|
||||
accountID string,
|
||||
keys []agentNetworkTypes.ConsumptionKey,
|
||||
) (map[agentNetworkTypes.ConsumptionKey]*agentNetworkTypes.Consumption, error) {
|
||||
out := make(map[agentNetworkTypes.ConsumptionKey]*agentNetworkTypes.Consumption, len(keys))
|
||||
if len(keys) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Collect the distinct dim ids, windows and window starts so a single
|
||||
// query scopes to exactly the current windows in play, then filter the
|
||||
// returned rows down to the exact requested keys.
|
||||
wanted := make(map[agentNetworkTypes.ConsumptionKey]struct{}, len(keys))
|
||||
dimSet := make(map[string]struct{})
|
||||
winSet := make(map[int64]struct{})
|
||||
startSet := make(map[time.Time]struct{})
|
||||
for _, k := range keys {
|
||||
k.WindowStartUTC = k.WindowStartUTC.UTC()
|
||||
wanted[k] = struct{}{}
|
||||
dimSet[k.DimID] = struct{}{}
|
||||
winSet[k.WindowSeconds] = struct{}{}
|
||||
startSet[k.WindowStartUTC] = struct{}{}
|
||||
}
|
||||
dimIDs := make([]string, 0, len(dimSet))
|
||||
for d := range dimSet {
|
||||
dimIDs = append(dimIDs, d)
|
||||
}
|
||||
windows := make([]int64, 0, len(winSet))
|
||||
for w := range winSet {
|
||||
windows = append(windows, w)
|
||||
}
|
||||
starts := make([]time.Time, 0, len(startSet))
|
||||
for t := range startSet {
|
||||
starts = append(starts, t)
|
||||
}
|
||||
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
var rows []*agentNetworkTypes.Consumption
|
||||
result := tx.Find(&rows,
|
||||
"account_id = ? AND dim_id IN ? AND window_seconds IN ? AND window_start_utc IN ?",
|
||||
accountID, dimIDs, windows, starts)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to batch-get agent network consumption: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get agent network consumption")
|
||||
}
|
||||
for _, row := range rows {
|
||||
k := agentNetworkTypes.ConsumptionKey{
|
||||
Kind: row.DimensionKind,
|
||||
DimID: row.DimensionID,
|
||||
WindowSeconds: row.WindowSeconds,
|
||||
WindowStartUTC: row.WindowStartUTC.UTC(),
|
||||
}
|
||||
if _, ok := wanted[k]; ok {
|
||||
out[k] = row
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// IncrementAgentNetworkConsumptionBatch applies the same usage delta to every
|
||||
// supplied counter inside a single transaction, so all per-(dimension, window)
|
||||
// counters a served request books are written atomically in one round-trip
|
||||
// instead of one upsert per counter. Keys are deduplicated by the caller.
|
||||
func (s *SqlStore) IncrementAgentNetworkConsumptionBatch(
|
||||
ctx context.Context,
|
||||
accountID string,
|
||||
keys []agentNetworkTypes.ConsumptionKey,
|
||||
tokensIn, tokensOut int64,
|
||||
costUSD float64,
|
||||
) error {
|
||||
if accountID == "" {
|
||||
return status.Errorf(status.InvalidArgument, "account_id must be set")
|
||||
}
|
||||
if tokensIn < 0 || tokensOut < 0 || costUSD < 0 || math.IsNaN(costUSD) || math.IsInf(costUSD, 0) {
|
||||
return status.Errorf(status.InvalidArgument, "consumption deltas must be non-negative and finite")
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
const tbl = "agent_network_consumption"
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
for _, k := range keys {
|
||||
if k.DimID == "" || k.WindowSeconds <= 0 {
|
||||
return status.Errorf(status.InvalidArgument, "dim_id and window_seconds must be set")
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
row := agentNetworkTypes.Consumption{
|
||||
AccountID: accountID,
|
||||
DimensionKind: k.Kind,
|
||||
DimensionID: k.DimID,
|
||||
WindowSeconds: k.WindowSeconds,
|
||||
WindowStartUTC: k.WindowStartUTC.UTC(),
|
||||
TokensInput: tokensIn,
|
||||
TokensOutput: tokensOut,
|
||||
CostUSD: costUSD,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "account_id"},
|
||||
{Name: "dim_kind"},
|
||||
{Name: "dim_id"},
|
||||
{Name: "window_seconds"},
|
||||
{Name: "window_start_utc"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(map[string]any{
|
||||
"tokens_input": gorm.Expr(tbl+".tokens_input + ?", tokensIn),
|
||||
"tokens_output": gorm.Expr(tbl+".tokens_output + ?", tokensOut),
|
||||
"cost_usd": gorm.Expr(tbl+".cost_usd + ?", costUSD),
|
||||
"updated_at": now,
|
||||
}),
|
||||
}).Create(&row).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to batch-increment agent network consumption: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to increment agent network consumption")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListAgentNetworkConsumption returns every consumption row recorded
|
||||
// for the account, ordered by window_start descending. Backs the
|
||||
// dashboard's basic counter view.
|
||||
func (s *SqlStore) ListAgentNetworkConsumption(
|
||||
ctx context.Context,
|
||||
lockStrength LockingStrength,
|
||||
accountID string,
|
||||
) ([]*agentNetworkTypes.Consumption, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
var rows []*agentNetworkTypes.Consumption
|
||||
result := tx.
|
||||
Order("window_start_utc DESC").
|
||||
Find(&rows, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list agent network consumption: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to list agent network consumption")
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// GetAccountAgentNetworkBudgetRules returns every account-level budget rule for
|
||||
// the account.
|
||||
func (s *SqlStore) GetAccountAgentNetworkBudgetRules(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.AccountBudgetRule, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var rules []*agentNetworkTypes.AccountBudgetRule
|
||||
result := tx.Find(&rules, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get agent network budget rules from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get agent network budget rules from store")
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// GetAgentNetworkBudgetRuleByID returns a single budget rule scoped to the
|
||||
// account, or a NotFound error.
|
||||
func (s *SqlStore) GetAgentNetworkBudgetRuleByID(ctx context.Context, lockStrength LockingStrength, accountID, ruleID string) (*agentNetworkTypes.AccountBudgetRule, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var rule *agentNetworkTypes.AccountBudgetRule
|
||||
result := tx.Take(&rule, accountAndIDQueryCondition, accountID, ruleID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAgentNetworkBudgetRuleNotFoundError(ruleID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed to get agent network budget rule from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get agent network budget rule from store")
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// SaveAgentNetworkBudgetRule upserts a budget rule.
|
||||
func (s *SqlStore) SaveAgentNetworkBudgetRule(ctx context.Context, rule *agentNetworkTypes.AccountBudgetRule) error {
|
||||
result := s.db.Save(rule)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save agent network budget rule to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save agent network budget rule to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteAgentNetworkBudgetRule removes a budget rule scoped to the account.
|
||||
func (s *SqlStore) DeleteAgentNetworkBudgetRule(ctx context.Context, accountID, ruleID string) error {
|
||||
result := s.db.Delete(&agentNetworkTypes.AccountBudgetRule{}, accountAndIDQueryCondition, accountID, ruleID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete agent network budget rule from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete agent network budget rule from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewAgentNetworkBudgetRuleNotFoundError(ruleID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
202
management/server/store/sql_store_agentnetwork_accesslog_test.go
Normal file
202
management/server/store/sql_store_agentnetwork_accesslog_test.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
)
|
||||
|
||||
// TestAgentNetworkUsage_RealStore_RoundTrip drives CreateAgentNetworkUsage and
|
||||
// CreateAgentNetworkAccessLog through a real sqlite store to prove the schema
|
||||
// migrates and the inserts succeed for both a populated (allowed) entry and a
|
||||
// stripped (denied) entry.
|
||||
func TestAgentNetworkUsage_RealStore_RoundTrip(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
const accountID = "acc-anet-usage-1"
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Populated (allowed) usage row with two authorising groups.
|
||||
usage := &agentNetworkTypes.AgentNetworkUsage{
|
||||
ID: "log-allowed-1",
|
||||
AccountID: accountID,
|
||||
Timestamp: now,
|
||||
UserID: "user-alice",
|
||||
ResolvedProviderID: "prov-openai-1",
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o",
|
||||
SessionID: "sess-round-trip-1",
|
||||
InputTokens: 1200,
|
||||
OutputTokens: 640,
|
||||
TotalTokens: 1840,
|
||||
CostUSD: 0.0231,
|
||||
}
|
||||
usageGroups := []agentNetworkTypes.AgentNetworkUsageGroup{
|
||||
{UsageID: usage.ID, GroupID: "grp-eng", AccountID: accountID},
|
||||
{UsageID: usage.ID, GroupID: "grp-oncall", AccountID: accountID},
|
||||
}
|
||||
require.NoError(t, s.CreateAgentNetworkUsage(ctx, usage, usageGroups), "populated usage insert must succeed")
|
||||
|
||||
// Stripped (denied / 403) usage row: no provider/model/tokens, no groups.
|
||||
denied := &agentNetworkTypes.AgentNetworkUsage{
|
||||
ID: "log-denied-1",
|
||||
AccountID: accountID,
|
||||
Timestamp: now,
|
||||
UserID: "user-bob",
|
||||
}
|
||||
require.NoError(t, s.CreateAgentNetworkUsage(ctx, denied, nil), "stripped usage insert must succeed")
|
||||
|
||||
// Idempotency: re-inserting the same id must not error.
|
||||
require.NoError(t, s.CreateAgentNetworkUsage(ctx, usage, usageGroups), "duplicate usage insert must be idempotent")
|
||||
|
||||
// Access-log row + group children.
|
||||
entry := &agentNetworkTypes.AgentNetworkAccessLog{
|
||||
ID: "log-allowed-1",
|
||||
AccountID: accountID,
|
||||
ServiceID: "agent-net-svc-1",
|
||||
Timestamp: now,
|
||||
UserID: "user-alice",
|
||||
StatusCode: 200,
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o",
|
||||
SessionID: "sess-round-trip-1",
|
||||
InputTokens: 1200,
|
||||
OutputTokens: 640,
|
||||
TotalTokens: 1840,
|
||||
CostUSD: 0.0231,
|
||||
}
|
||||
entryGroups := []agentNetworkTypes.AgentNetworkAccessLogGroup{
|
||||
{LogID: entry.ID, GroupID: "grp-eng", AccountID: accountID},
|
||||
{LogID: entry.ID, GroupID: "grp-oncall", AccountID: accountID},
|
||||
}
|
||||
require.NoError(t, s.CreateAgentNetworkAccessLog(ctx, entry, entryGroups), "access-log insert must succeed")
|
||||
|
||||
// Read back through the filtered list + verify group hydration.
|
||||
logs, total, err := s.GetAgentNetworkAccessLogs(ctx, LockingStrengthNone, accountID, agentNetworkTypes.AgentNetworkAccessLogFilter{Page: 1, PageSize: 50})
|
||||
require.NoError(t, err, "list must succeed")
|
||||
assert.Equal(t, int64(1), total, "one access-log row expected")
|
||||
require.Len(t, logs, 1)
|
||||
assert.ElementsMatch(t, []string{"grp-eng", "grp-oncall"}, logs[0].GroupIDs, "group ids must hydrate")
|
||||
assert.Equal(t, "sess-round-trip-1", logs[0].SessionID, "session id must persist and read back on the access-log row")
|
||||
|
||||
// Session filter narrows the access-log listing to one conversation.
|
||||
sessionID := "sess-round-trip-1"
|
||||
sessLogs, sessTotal, err := s.GetAgentNetworkAccessLogs(ctx, LockingStrengthNone, accountID,
|
||||
agentNetworkTypes.AgentNetworkAccessLogFilter{Page: 1, PageSize: 50, SessionID: &sessionID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), sessTotal, "session filter must match the one row with that session id")
|
||||
require.Len(t, sessLogs, 1)
|
||||
assert.Equal(t, entry.ID, sessLogs[0].ID, "session filter must return the matching log row")
|
||||
|
||||
bogus := "no-such-session"
|
||||
_, emptyTotal, err := s.GetAgentNetworkAccessLogs(ctx, LockingStrengthNone, accountID,
|
||||
agentNetworkTypes.AgentNetworkAccessLogFilter{Page: 1, PageSize: 50, SessionID: &bogus})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), emptyTotal, "unknown session id must match nothing")
|
||||
|
||||
// Session filter also narrows the always-on usage rows.
|
||||
sessUsage, err := s.GetAgentNetworkUsageRows(ctx, LockingStrengthNone, accountID,
|
||||
agentNetworkTypes.AgentNetworkAccessLogFilter{SessionID: &sessionID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, sessUsage, 1, "session filter must narrow usage rows to the matching session")
|
||||
assert.Equal(t, "sess-round-trip-1", sessUsage[0].SessionID, "usage row must carry the session id")
|
||||
}
|
||||
|
||||
// TestAgentNetworkUsageOverview_DailyAggregation drives GetAgentNetworkUsageRows
|
||||
// + AggregateUsageByGranularity end-to-end against a real sqlite store, with
|
||||
// two rows on the same day and one on another, plus a model filter.
|
||||
func TestAgentNetworkUsageOverview_DailyAggregation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
const accountID = "acc-anet-overview-1"
|
||||
day1 := time.Date(2026, 5, 5, 10, 0, 0, 0, time.UTC)
|
||||
day1b := time.Date(2026, 5, 5, 22, 0, 0, 0, time.UTC)
|
||||
day2 := time.Date(2026, 5, 6, 9, 0, 0, 0, time.UTC)
|
||||
|
||||
mk := func(id string, ts time.Time, model string, in, out int64, cost float64) *agentNetworkTypes.AgentNetworkUsage {
|
||||
return &agentNetworkTypes.AgentNetworkUsage{
|
||||
ID: id, AccountID: accountID, Timestamp: ts, Model: model,
|
||||
InputTokens: in, OutputTokens: out, TotalTokens: in + out, CostUSD: cost,
|
||||
}
|
||||
}
|
||||
require.NoError(t, s.CreateAgentNetworkUsage(ctx, mk("u1", day1, "gpt-4o", 100, 50, 0.10), nil))
|
||||
require.NoError(t, s.CreateAgentNetworkUsage(ctx, mk("u2", day1b, "gpt-4o", 200, 80, 0.20), nil))
|
||||
require.NoError(t, s.CreateAgentNetworkUsage(ctx, mk("u3", day2, "claude-3", 10, 5, 0.01), nil))
|
||||
|
||||
rows, err := s.GetAgentNetworkUsageRows(ctx, LockingStrengthNone, accountID, agentNetworkTypes.AgentNetworkAccessLogFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 3, "all three usage rows expected")
|
||||
|
||||
buckets := agentNetworkTypes.AggregateUsageByGranularity(rows, agentNetworkTypes.UsageGranularityDay)
|
||||
require.Len(t, buckets, 2, "two distinct days expected")
|
||||
assert.Equal(t, "2026-05-05", buckets[0].PeriodStart, "oldest-first ordering")
|
||||
assert.Equal(t, int64(300), buckets[0].InputTokens, "same-day input tokens summed")
|
||||
assert.Equal(t, int64(130), buckets[0].OutputTokens)
|
||||
assert.InDelta(t, 0.30, buckets[0].CostUSD, 1e-9, "same-day cost summed")
|
||||
assert.Equal(t, "2026-05-06", buckets[1].PeriodStart)
|
||||
assert.Equal(t, int64(15), buckets[1].TotalTokens)
|
||||
|
||||
// Model filter narrows to a single day.
|
||||
model := "claude-3"
|
||||
filtered, err := s.GetAgentNetworkUsageRows(ctx, LockingStrengthNone, accountID, agentNetworkTypes.AgentNetworkAccessLogFilter{Models: []string{model}})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, filtered, 1, "model filter must narrow rows")
|
||||
assert.Equal(t, "u3", filtered[0].ID)
|
||||
}
|
||||
|
||||
// TestDeleteOldAgentNetworkAccessLogs verifies the retention sweep removes only
|
||||
// access-log rows (and their group children) older than the cutoff, leaving
|
||||
// recent rows — and never touching usage records.
|
||||
func TestDeleteOldAgentNetworkAccessLogs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
const accountID = "acc-anet-retention-1"
|
||||
old := time.Now().UTC().AddDate(0, 0, -40)
|
||||
recent := time.Now().UTC().AddDate(0, 0, -1)
|
||||
|
||||
mkLog := func(id string, ts time.Time) (*agentNetworkTypes.AgentNetworkAccessLog, []agentNetworkTypes.AgentNetworkAccessLogGroup) {
|
||||
return &agentNetworkTypes.AgentNetworkAccessLog{
|
||||
ID: id, AccountID: accountID, ServiceID: "svc", Timestamp: ts, StatusCode: 200, Model: "gpt-4o",
|
||||
}, []agentNetworkTypes.AgentNetworkAccessLogGroup{
|
||||
{LogID: id, GroupID: "grp-eng", AccountID: accountID},
|
||||
}
|
||||
}
|
||||
oldEntry, oldGroups := mkLog("old-1", old)
|
||||
recentEntry, recentGroups := mkLog("recent-1", recent)
|
||||
require.NoError(t, s.CreateAgentNetworkAccessLog(ctx, oldEntry, oldGroups))
|
||||
require.NoError(t, s.CreateAgentNetworkAccessLog(ctx, recentEntry, recentGroups))
|
||||
// A usage row for the old request must survive the access-log sweep.
|
||||
require.NoError(t, s.CreateAgentNetworkUsage(ctx, &agentNetworkTypes.AgentNetworkUsage{
|
||||
ID: "old-1", AccountID: accountID, Timestamp: old, Model: "gpt-4o", InputTokens: 10, TotalTokens: 10,
|
||||
}, nil))
|
||||
|
||||
cutoff := time.Now().UTC().AddDate(0, 0, -30)
|
||||
deleted, err := s.DeleteOldAgentNetworkAccessLogs(ctx, accountID, cutoff)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), deleted, "only the 40-day-old log is deleted")
|
||||
|
||||
logs, total, err := s.GetAgentNetworkAccessLogs(ctx, LockingStrengthNone, accountID, agentNetworkTypes.AgentNetworkAccessLogFilter{Page: 1, PageSize: 50})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), total, "the recent log remains")
|
||||
require.Len(t, logs, 1)
|
||||
assert.Equal(t, "recent-1", logs[0].ID)
|
||||
|
||||
// Usage is untouched by the access-log retention sweep.
|
||||
usage, err := s.GetAgentNetworkUsageRows(ctx, LockingStrengthNone, accountID, agentNetworkTypes.AgentNetworkAccessLogFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, usage, 1, "usage record for the deleted log must survive")
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
)
|
||||
|
||||
// TestAgentNetworkBudgetRule_RealStore_RoundTrip is the GC-0 no-mock guard: it
|
||||
// drives the budget-rule CRUD through a real sqlite store and asserts the full
|
||||
// object — targets and the reused PolicyLimits cap shape — survives the
|
||||
// save → gorm/JSON serialize → reload round-trip, then that delete removes it
|
||||
// and a second delete reports NotFound.
|
||||
func TestAgentNetworkBudgetRule_RealStore_RoundTrip(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err, "real sqlite test store must come up")
|
||||
defer cleanup()
|
||||
|
||||
const accountID = "acc-budgetrule-1"
|
||||
rule := agentNetworkTypes.NewAccountBudgetRule(accountID)
|
||||
rule.Name = "eng-monthly"
|
||||
rule.TargetGroups = []string{"grp-eng", "grp-oncall"}
|
||||
rule.TargetUsers = []string{"user-alice"}
|
||||
rule.Limits = agentNetworkTypes.PolicyLimits{
|
||||
TokenLimit: agentNetworkTypes.PolicyTokenLimit{
|
||||
Enabled: true, GroupCap: 100_000, UserCap: 10_000, WindowSeconds: 2_592_000,
|
||||
},
|
||||
BudgetLimit: agentNetworkTypes.PolicyBudgetLimit{
|
||||
Enabled: true, GroupCapUsd: 500, UserCapUsd: 50, WindowSeconds: 2_592_000,
|
||||
},
|
||||
}
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, rule), "save must succeed")
|
||||
|
||||
got, err := s.GetAgentNetworkBudgetRuleByID(ctx, LockingStrengthNone, accountID, rule.ID)
|
||||
require.NoError(t, err, "get by id must succeed after save")
|
||||
assert.Equal(t, rule.Name, got.Name, "name must round-trip")
|
||||
assert.Equal(t, []string{"grp-eng", "grp-oncall"}, got.TargetGroups, "target groups must round-trip")
|
||||
assert.Equal(t, []string{"user-alice"}, got.TargetUsers, "target users must round-trip")
|
||||
assert.Equal(t, rule.Limits, got.Limits, "the reused PolicyLimits cap shape must round-trip intact")
|
||||
assert.True(t, got.Enabled, "enabled must round-trip")
|
||||
|
||||
list, err := s.GetAccountAgentNetworkBudgetRules(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err, "list must succeed")
|
||||
require.Len(t, list, 1, "exactly the one saved rule must be listed")
|
||||
assert.Equal(t, rule.ID, list[0].ID, "listed rule id must match")
|
||||
|
||||
require.NoError(t, s.DeleteAgentNetworkBudgetRule(ctx, accountID, rule.ID), "delete must succeed")
|
||||
|
||||
_, err = s.GetAgentNetworkBudgetRuleByID(ctx, LockingStrengthNone, accountID, rule.ID)
|
||||
assert.Error(t, err, "get after delete must report not found")
|
||||
|
||||
err = s.DeleteAgentNetworkBudgetRule(ctx, accountID, rule.ID)
|
||||
assert.Error(t, err, "deleting an absent rule must report not found")
|
||||
}
|
||||
|
||||
// TestAgentNetworkBudgetRule_RealStore_ScopedByAccount pins that rules are
|
||||
// account-scoped: a rule under one account is invisible to another.
|
||||
func TestAgentNetworkBudgetRule_RealStore_ScopedByAccount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
|
||||
ruleA := agentNetworkTypes.NewAccountBudgetRule("acc-A")
|
||||
require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, ruleA))
|
||||
|
||||
list, err := s.GetAccountAgentNetworkBudgetRules(ctx, LockingStrengthNone, "acc-B")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, list, "account B must not see account A's budget rule")
|
||||
|
||||
_, err = s.GetAgentNetworkBudgetRuleByID(ctx, LockingStrengthNone, "acc-B", ruleA.ID)
|
||||
assert.Error(t, err, "cross-account get by id must not resolve")
|
||||
}
|
||||
|
||||
// TestAgentNetworkSettings_RealStore_CollectionTogglesRoundTrip pins the GC-0
|
||||
// additive settings columns: the three collection toggles default off on a
|
||||
// fresh row and survive a save/reload at their set values.
|
||||
func TestAgentNetworkSettings_RealStore_CollectionTogglesRoundTrip(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
|
||||
const accountID = "acc-settings-toggles"
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, &agentNetworkTypes.Settings{
|
||||
AccountID: accountID,
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
Subdomain: "violet",
|
||||
}))
|
||||
|
||||
got, err := s.GetAgentNetworkSettings(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, got.EnableLogCollection, "log collection must default off")
|
||||
assert.False(t, got.EnablePromptCollection, "prompt collection must default off")
|
||||
assert.False(t, got.RedactPii, "redact pii must default off")
|
||||
|
||||
got.EnableLogCollection = true
|
||||
got.EnablePromptCollection = true
|
||||
got.RedactPii = true
|
||||
require.NoError(t, s.SaveAgentNetworkSettings(ctx, got))
|
||||
|
||||
reloaded, err := s.GetAgentNetworkSettings(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, reloaded.EnableLogCollection, "log collection must round-trip on")
|
||||
assert.True(t, reloaded.EnablePromptCollection, "prompt collection must round-trip on")
|
||||
assert.True(t, reloaded.RedactPii, "redact pii must round-trip on")
|
||||
}
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
@@ -300,6 +301,11 @@ type Store interface {
|
||||
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
|
||||
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
|
||||
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
|
||||
CreateAgentNetworkAccessLog(ctx context.Context, entry *agentNetworkTypes.AgentNetworkAccessLog, groups []agentNetworkTypes.AgentNetworkAccessLogGroup) error
|
||||
CreateAgentNetworkUsage(ctx context.Context, usage *agentNetworkTypes.AgentNetworkUsage, groups []agentNetworkTypes.AgentNetworkUsageGroup) error
|
||||
GetAgentNetworkAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkAccessLog, int64, error)
|
||||
GetAgentNetworkUsageRows(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkUsage, error)
|
||||
DeleteOldAgentNetworkAccessLogs(ctx context.Context, accountID string, olderThan time.Time) (int64, error)
|
||||
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error)
|
||||
GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error)
|
||||
DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error
|
||||
@@ -329,6 +335,34 @@ type Store interface {
|
||||
GetProxyMetrics(ctx context.Context) (ProxyMetrics, error)
|
||||
|
||||
GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error)
|
||||
|
||||
// Agent Network persistence (providers, policies, guardrails, settings).
|
||||
GetAllAgentNetworkProviders(ctx context.Context, lockStrength LockingStrength) ([]*agentNetworkTypes.Provider, error)
|
||||
GetAccountAgentNetworkProviders(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Provider, error)
|
||||
GetAgentNetworkProviderByID(ctx context.Context, lockStrength LockingStrength, accountID, providerID string) (*agentNetworkTypes.Provider, error)
|
||||
SaveAgentNetworkProvider(ctx context.Context, provider *agentNetworkTypes.Provider) error
|
||||
DeleteAgentNetworkProvider(ctx context.Context, accountID, providerID string) error
|
||||
GetAccountAgentNetworkPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Policy, error)
|
||||
GetAgentNetworkPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*agentNetworkTypes.Policy, error)
|
||||
SaveAgentNetworkPolicy(ctx context.Context, policy *agentNetworkTypes.Policy) error
|
||||
DeleteAgentNetworkPolicy(ctx context.Context, accountID, policyID string) error
|
||||
GetAccountAgentNetworkGuardrails(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Guardrail, error)
|
||||
GetAgentNetworkGuardrailByID(ctx context.Context, lockStrength LockingStrength, accountID, guardrailID string) (*agentNetworkTypes.Guardrail, error)
|
||||
SaveAgentNetworkGuardrail(ctx context.Context, guardrail *agentNetworkTypes.Guardrail) error
|
||||
DeleteAgentNetworkGuardrail(ctx context.Context, accountID, guardrailID string) error
|
||||
GetAgentNetworkSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*agentNetworkTypes.Settings, error)
|
||||
GetAllAgentNetworkSettings(ctx context.Context, lockStrength LockingStrength) ([]*agentNetworkTypes.Settings, error)
|
||||
GetAgentNetworkSettingsByCluster(ctx context.Context, lockStrength LockingStrength, cluster string) ([]*agentNetworkTypes.Settings, error)
|
||||
SaveAgentNetworkSettings(ctx context.Context, settings *agentNetworkTypes.Settings) error
|
||||
IncrementAgentNetworkConsumption(ctx context.Context, accountID string, kind agentNetworkTypes.ConsumptionDimension, dimID string, windowSeconds int64, windowStart time.Time, tokensIn, tokensOut int64, costUSD float64) error
|
||||
IncrementAgentNetworkConsumptionBatch(ctx context.Context, accountID string, keys []agentNetworkTypes.ConsumptionKey, tokensIn, tokensOut int64, costUSD float64) error
|
||||
GetAgentNetworkConsumption(ctx context.Context, lockStrength LockingStrength, accountID string, kind agentNetworkTypes.ConsumptionDimension, dimID string, windowSeconds int64, windowStart time.Time) (*agentNetworkTypes.Consumption, error)
|
||||
GetAgentNetworkConsumptionBatch(ctx context.Context, lockStrength LockingStrength, accountID string, keys []agentNetworkTypes.ConsumptionKey) (map[agentNetworkTypes.ConsumptionKey]*agentNetworkTypes.Consumption, error)
|
||||
ListAgentNetworkConsumption(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Consumption, error)
|
||||
GetAccountAgentNetworkBudgetRules(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.AccountBudgetRule, error)
|
||||
GetAgentNetworkBudgetRuleByID(ctx context.Context, lockStrength LockingStrength, accountID, ruleID string) (*agentNetworkTypes.AccountBudgetRule, error)
|
||||
SaveAgentNetworkBudgetRule(ctx context.Context, rule *agentNetworkTypes.AccountBudgetRule) error
|
||||
DeleteAgentNetworkBudgetRule(ctx context.Context, accountID, ruleID string) error
|
||||
}
|
||||
|
||||
// ProxyMetrics aggregates self-hosted proxy + cluster usage signals
|
||||
|
||||
464
management/server/store/store_mock_agentnetwork.go
Normal file
464
management/server/store/store_mock_agentnetwork.go
Normal file
@@ -0,0 +1,464 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
|
||||
agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types"
|
||||
)
|
||||
|
||||
// GetAllAgentNetworkProviders mocks base method.
|
||||
func (m *MockStore) GetAllAgentNetworkProviders(ctx context.Context, lockStrength LockingStrength) ([]*agentNetworkTypes.Provider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAllAgentNetworkProviders", ctx, lockStrength)
|
||||
ret0, _ := ret[0].([]*agentNetworkTypes.Provider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAllAgentNetworkProviders indicates an expected call of GetAllAgentNetworkProviders.
|
||||
func (mr *MockStoreMockRecorder) GetAllAgentNetworkProviders(ctx, lockStrength interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllAgentNetworkProviders", reflect.TypeOf((*MockStore)(nil).GetAllAgentNetworkProviders), ctx, lockStrength)
|
||||
}
|
||||
|
||||
// GetAccountAgentNetworkProviders mocks base method.
|
||||
func (m *MockStore) GetAccountAgentNetworkProviders(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Provider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAccountAgentNetworkProviders", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*agentNetworkTypes.Provider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAccountAgentNetworkProviders indicates an expected call of GetAccountAgentNetworkProviders.
|
||||
func (mr *MockStoreMockRecorder) GetAccountAgentNetworkProviders(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountAgentNetworkProviders", reflect.TypeOf((*MockStore)(nil).GetAccountAgentNetworkProviders), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetAgentNetworkProviderByID mocks base method.
|
||||
func (m *MockStore) GetAgentNetworkProviderByID(ctx context.Context, lockStrength LockingStrength, accountID, providerID string) (*agentNetworkTypes.Provider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAgentNetworkProviderByID", ctx, lockStrength, accountID, providerID)
|
||||
ret0, _ := ret[0].(*agentNetworkTypes.Provider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAgentNetworkProviderByID indicates an expected call of GetAgentNetworkProviderByID.
|
||||
func (mr *MockStoreMockRecorder) GetAgentNetworkProviderByID(ctx, lockStrength, accountID, providerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkProviderByID", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkProviderByID), ctx, lockStrength, accountID, providerID)
|
||||
}
|
||||
|
||||
// SaveAgentNetworkProvider mocks base method.
|
||||
func (m *MockStore) SaveAgentNetworkProvider(ctx context.Context, provider *agentNetworkTypes.Provider) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveAgentNetworkProvider", ctx, provider)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveAgentNetworkProvider indicates an expected call of SaveAgentNetworkProvider.
|
||||
func (mr *MockStoreMockRecorder) SaveAgentNetworkProvider(ctx, provider interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAgentNetworkProvider", reflect.TypeOf((*MockStore)(nil).SaveAgentNetworkProvider), ctx, provider)
|
||||
}
|
||||
|
||||
// DeleteAgentNetworkProvider mocks base method.
|
||||
func (m *MockStore) DeleteAgentNetworkProvider(ctx context.Context, accountID, providerID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAgentNetworkProvider", ctx, accountID, providerID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAgentNetworkProvider indicates an expected call of DeleteAgentNetworkProvider.
|
||||
func (mr *MockStoreMockRecorder) DeleteAgentNetworkProvider(ctx, accountID, providerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAgentNetworkProvider", reflect.TypeOf((*MockStore)(nil).DeleteAgentNetworkProvider), ctx, accountID, providerID)
|
||||
}
|
||||
|
||||
// GetAccountAgentNetworkPolicies mocks base method.
|
||||
func (m *MockStore) GetAccountAgentNetworkPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Policy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAccountAgentNetworkPolicies", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*agentNetworkTypes.Policy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAccountAgentNetworkPolicies indicates an expected call of GetAccountAgentNetworkPolicies.
|
||||
func (mr *MockStoreMockRecorder) GetAccountAgentNetworkPolicies(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountAgentNetworkPolicies", reflect.TypeOf((*MockStore)(nil).GetAccountAgentNetworkPolicies), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetAgentNetworkPolicyByID mocks base method.
|
||||
func (m *MockStore) GetAgentNetworkPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*agentNetworkTypes.Policy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAgentNetworkPolicyByID", ctx, lockStrength, accountID, policyID)
|
||||
ret0, _ := ret[0].(*agentNetworkTypes.Policy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAgentNetworkPolicyByID indicates an expected call of GetAgentNetworkPolicyByID.
|
||||
func (mr *MockStoreMockRecorder) GetAgentNetworkPolicyByID(ctx, lockStrength, accountID, policyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkPolicyByID", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkPolicyByID), ctx, lockStrength, accountID, policyID)
|
||||
}
|
||||
|
||||
// SaveAgentNetworkPolicy mocks base method.
|
||||
func (m *MockStore) SaveAgentNetworkPolicy(ctx context.Context, policy *agentNetworkTypes.Policy) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveAgentNetworkPolicy", ctx, policy)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveAgentNetworkPolicy indicates an expected call of SaveAgentNetworkPolicy.
|
||||
func (mr *MockStoreMockRecorder) SaveAgentNetworkPolicy(ctx, policy interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAgentNetworkPolicy", reflect.TypeOf((*MockStore)(nil).SaveAgentNetworkPolicy), ctx, policy)
|
||||
}
|
||||
|
||||
// DeleteAgentNetworkPolicy mocks base method.
|
||||
func (m *MockStore) DeleteAgentNetworkPolicy(ctx context.Context, accountID, policyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAgentNetworkPolicy", ctx, accountID, policyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAgentNetworkPolicy indicates an expected call of DeleteAgentNetworkPolicy.
|
||||
func (mr *MockStoreMockRecorder) DeleteAgentNetworkPolicy(ctx, accountID, policyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAgentNetworkPolicy", reflect.TypeOf((*MockStore)(nil).DeleteAgentNetworkPolicy), ctx, accountID, policyID)
|
||||
}
|
||||
|
||||
// GetAccountAgentNetworkGuardrails mocks base method.
|
||||
func (m *MockStore) GetAccountAgentNetworkGuardrails(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Guardrail, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAccountAgentNetworkGuardrails", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*agentNetworkTypes.Guardrail)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAccountAgentNetworkGuardrails indicates an expected call of GetAccountAgentNetworkGuardrails.
|
||||
func (mr *MockStoreMockRecorder) GetAccountAgentNetworkGuardrails(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountAgentNetworkGuardrails", reflect.TypeOf((*MockStore)(nil).GetAccountAgentNetworkGuardrails), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetAgentNetworkGuardrailByID mocks base method.
|
||||
func (m *MockStore) GetAgentNetworkGuardrailByID(ctx context.Context, lockStrength LockingStrength, accountID, guardrailID string) (*agentNetworkTypes.Guardrail, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAgentNetworkGuardrailByID", ctx, lockStrength, accountID, guardrailID)
|
||||
ret0, _ := ret[0].(*agentNetworkTypes.Guardrail)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAgentNetworkGuardrailByID indicates an expected call of GetAgentNetworkGuardrailByID.
|
||||
func (mr *MockStoreMockRecorder) GetAgentNetworkGuardrailByID(ctx, lockStrength, accountID, guardrailID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkGuardrailByID", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkGuardrailByID), ctx, lockStrength, accountID, guardrailID)
|
||||
}
|
||||
|
||||
// SaveAgentNetworkGuardrail mocks base method.
|
||||
func (m *MockStore) SaveAgentNetworkGuardrail(ctx context.Context, guardrail *agentNetworkTypes.Guardrail) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveAgentNetworkGuardrail", ctx, guardrail)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveAgentNetworkGuardrail indicates an expected call of SaveAgentNetworkGuardrail.
|
||||
func (mr *MockStoreMockRecorder) SaveAgentNetworkGuardrail(ctx, guardrail interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAgentNetworkGuardrail", reflect.TypeOf((*MockStore)(nil).SaveAgentNetworkGuardrail), ctx, guardrail)
|
||||
}
|
||||
|
||||
// DeleteAgentNetworkGuardrail mocks base method.
|
||||
func (m *MockStore) DeleteAgentNetworkGuardrail(ctx context.Context, accountID, guardrailID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAgentNetworkGuardrail", ctx, accountID, guardrailID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAgentNetworkGuardrail indicates an expected call of DeleteAgentNetworkGuardrail.
|
||||
func (mr *MockStoreMockRecorder) DeleteAgentNetworkGuardrail(ctx, accountID, guardrailID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAgentNetworkGuardrail", reflect.TypeOf((*MockStore)(nil).DeleteAgentNetworkGuardrail), ctx, accountID, guardrailID)
|
||||
}
|
||||
|
||||
// GetAgentNetworkSettings mocks base method.
|
||||
func (m *MockStore) GetAgentNetworkSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*agentNetworkTypes.Settings, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAgentNetworkSettings", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].(*agentNetworkTypes.Settings)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAgentNetworkSettings indicates an expected call of GetAgentNetworkSettings.
|
||||
func (mr *MockStoreMockRecorder) GetAgentNetworkSettings(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkSettings", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkSettings), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetAgentNetworkSettingsByCluster mocks base method.
|
||||
func (m *MockStore) GetAgentNetworkSettingsByCluster(ctx context.Context, lockStrength LockingStrength, cluster string) ([]*agentNetworkTypes.Settings, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAgentNetworkSettingsByCluster", ctx, lockStrength, cluster)
|
||||
ret0, _ := ret[0].([]*agentNetworkTypes.Settings)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAgentNetworkSettingsByCluster indicates an expected call of GetAgentNetworkSettingsByCluster.
|
||||
func (mr *MockStoreMockRecorder) GetAgentNetworkSettingsByCluster(ctx, lockStrength, cluster interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkSettingsByCluster", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkSettingsByCluster), ctx, lockStrength, cluster)
|
||||
}
|
||||
|
||||
// SaveAgentNetworkSettings mocks base method.
|
||||
func (m *MockStore) SaveAgentNetworkSettings(ctx context.Context, settings *agentNetworkTypes.Settings) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveAgentNetworkSettings", ctx, settings)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveAgentNetworkSettings indicates an expected call of SaveAgentNetworkSettings.
|
||||
func (mr *MockStoreMockRecorder) SaveAgentNetworkSettings(ctx, settings interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAgentNetworkSettings", reflect.TypeOf((*MockStore)(nil).SaveAgentNetworkSettings), ctx, settings)
|
||||
}
|
||||
|
||||
// IncrementAgentNetworkConsumption mocks base method.
|
||||
func (m *MockStore) IncrementAgentNetworkConsumption(ctx context.Context, accountID string, kind agentNetworkTypes.ConsumptionDimension, dimID string, windowSeconds int64, windowStart time.Time, tokensIn, tokensOut int64, costUSD float64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IncrementAgentNetworkConsumption", ctx, accountID, kind, dimID, windowSeconds, windowStart, tokensIn, tokensOut, costUSD)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IncrementAgentNetworkConsumption indicates an expected call of IncrementAgentNetworkConsumption.
|
||||
func (mr *MockStoreMockRecorder) IncrementAgentNetworkConsumption(ctx, accountID, kind, dimID, windowSeconds, windowStart, tokensIn, tokensOut, costUSD interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementAgentNetworkConsumption", reflect.TypeOf((*MockStore)(nil).IncrementAgentNetworkConsumption), ctx, accountID, kind, dimID, windowSeconds, windowStart, tokensIn, tokensOut, costUSD)
|
||||
}
|
||||
|
||||
// GetAgentNetworkConsumption mocks base method.
|
||||
func (m *MockStore) GetAgentNetworkConsumption(ctx context.Context, lockStrength LockingStrength, accountID string, kind agentNetworkTypes.ConsumptionDimension, dimID string, windowSeconds int64, windowStart time.Time) (*agentNetworkTypes.Consumption, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAgentNetworkConsumption", ctx, lockStrength, accountID, kind, dimID, windowSeconds, windowStart)
|
||||
ret0, _ := ret[0].(*agentNetworkTypes.Consumption)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAgentNetworkConsumption indicates an expected call of GetAgentNetworkConsumption.
|
||||
func (mr *MockStoreMockRecorder) GetAgentNetworkConsumption(ctx, lockStrength, accountID, kind, dimID, windowSeconds, windowStart interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkConsumption", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkConsumption), ctx, lockStrength, accountID, kind, dimID, windowSeconds, windowStart)
|
||||
}
|
||||
|
||||
// GetAgentNetworkConsumptionBatch mocks base method.
|
||||
func (m *MockStore) GetAgentNetworkConsumptionBatch(ctx context.Context, lockStrength LockingStrength, accountID string, keys []agentNetworkTypes.ConsumptionKey) (map[agentNetworkTypes.ConsumptionKey]*agentNetworkTypes.Consumption, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAgentNetworkConsumptionBatch", ctx, lockStrength, accountID, keys)
|
||||
ret0, _ := ret[0].(map[agentNetworkTypes.ConsumptionKey]*agentNetworkTypes.Consumption)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAgentNetworkConsumptionBatch indicates an expected call of GetAgentNetworkConsumptionBatch.
|
||||
func (mr *MockStoreMockRecorder) GetAgentNetworkConsumptionBatch(ctx, lockStrength, accountID, keys interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkConsumptionBatch", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkConsumptionBatch), ctx, lockStrength, accountID, keys)
|
||||
}
|
||||
|
||||
// IncrementAgentNetworkConsumptionBatch mocks base method.
|
||||
func (m *MockStore) IncrementAgentNetworkConsumptionBatch(ctx context.Context, accountID string, keys []agentNetworkTypes.ConsumptionKey, tokensIn, tokensOut int64, costUSD float64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IncrementAgentNetworkConsumptionBatch", ctx, accountID, keys, tokensIn, tokensOut, costUSD)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IncrementAgentNetworkConsumptionBatch indicates an expected call of IncrementAgentNetworkConsumptionBatch.
|
||||
func (mr *MockStoreMockRecorder) IncrementAgentNetworkConsumptionBatch(ctx, accountID, keys, tokensIn, tokensOut, costUSD interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementAgentNetworkConsumptionBatch", reflect.TypeOf((*MockStore)(nil).IncrementAgentNetworkConsumptionBatch), ctx, accountID, keys, tokensIn, tokensOut, costUSD)
|
||||
}
|
||||
|
||||
// ListAgentNetworkConsumption mocks base method.
|
||||
func (m *MockStore) ListAgentNetworkConsumption(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Consumption, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAgentNetworkConsumption", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*agentNetworkTypes.Consumption)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAgentNetworkConsumption indicates an expected call of ListAgentNetworkConsumption.
|
||||
func (mr *MockStoreMockRecorder) ListAgentNetworkConsumption(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAgentNetworkConsumption", reflect.TypeOf((*MockStore)(nil).ListAgentNetworkConsumption), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetAccountAgentNetworkBudgetRules mocks base method.
|
||||
func (m *MockStore) GetAccountAgentNetworkBudgetRules(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.AccountBudgetRule, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAccountAgentNetworkBudgetRules", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*agentNetworkTypes.AccountBudgetRule)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAccountAgentNetworkBudgetRules indicates an expected call of GetAccountAgentNetworkBudgetRules.
|
||||
func (mr *MockStoreMockRecorder) GetAccountAgentNetworkBudgetRules(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountAgentNetworkBudgetRules", reflect.TypeOf((*MockStore)(nil).GetAccountAgentNetworkBudgetRules), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetAgentNetworkBudgetRuleByID mocks base method.
|
||||
func (m *MockStore) GetAgentNetworkBudgetRuleByID(ctx context.Context, lockStrength LockingStrength, accountID, ruleID string) (*agentNetworkTypes.AccountBudgetRule, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAgentNetworkBudgetRuleByID", ctx, lockStrength, accountID, ruleID)
|
||||
ret0, _ := ret[0].(*agentNetworkTypes.AccountBudgetRule)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAgentNetworkBudgetRuleByID indicates an expected call of GetAgentNetworkBudgetRuleByID.
|
||||
func (mr *MockStoreMockRecorder) GetAgentNetworkBudgetRuleByID(ctx, lockStrength, accountID, ruleID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkBudgetRuleByID", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkBudgetRuleByID), ctx, lockStrength, accountID, ruleID)
|
||||
}
|
||||
|
||||
// SaveAgentNetworkBudgetRule mocks base method.
|
||||
func (m *MockStore) SaveAgentNetworkBudgetRule(ctx context.Context, rule *agentNetworkTypes.AccountBudgetRule) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveAgentNetworkBudgetRule", ctx, rule)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveAgentNetworkBudgetRule indicates an expected call of SaveAgentNetworkBudgetRule.
|
||||
func (mr *MockStoreMockRecorder) SaveAgentNetworkBudgetRule(ctx, rule interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAgentNetworkBudgetRule", reflect.TypeOf((*MockStore)(nil).SaveAgentNetworkBudgetRule), ctx, rule)
|
||||
}
|
||||
|
||||
// DeleteAgentNetworkBudgetRule mocks base method.
|
||||
func (m *MockStore) DeleteAgentNetworkBudgetRule(ctx context.Context, accountID, ruleID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAgentNetworkBudgetRule", ctx, accountID, ruleID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAgentNetworkBudgetRule indicates an expected call of DeleteAgentNetworkBudgetRule.
|
||||
func (mr *MockStoreMockRecorder) DeleteAgentNetworkBudgetRule(ctx, accountID, ruleID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAgentNetworkBudgetRule", reflect.TypeOf((*MockStore)(nil).DeleteAgentNetworkBudgetRule), ctx, accountID, ruleID)
|
||||
}
|
||||
|
||||
// CreateAgentNetworkAccessLog mocks base method.
|
||||
func (m *MockStore) CreateAgentNetworkAccessLog(ctx context.Context, entry *agentNetworkTypes.AgentNetworkAccessLog, groups []agentNetworkTypes.AgentNetworkAccessLogGroup) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateAgentNetworkAccessLog", ctx, entry, groups)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateAgentNetworkAccessLog indicates an expected call of CreateAgentNetworkAccessLog.
|
||||
func (mr *MockStoreMockRecorder) CreateAgentNetworkAccessLog(ctx, entry, groups interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAgentNetworkAccessLog", reflect.TypeOf((*MockStore)(nil).CreateAgentNetworkAccessLog), ctx, entry, groups)
|
||||
}
|
||||
|
||||
// CreateAgentNetworkUsage mocks base method.
|
||||
func (m *MockStore) CreateAgentNetworkUsage(ctx context.Context, usage *agentNetworkTypes.AgentNetworkUsage, groups []agentNetworkTypes.AgentNetworkUsageGroup) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateAgentNetworkUsage", ctx, usage, groups)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateAgentNetworkUsage indicates an expected call of CreateAgentNetworkUsage.
|
||||
func (mr *MockStoreMockRecorder) CreateAgentNetworkUsage(ctx, usage, groups interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAgentNetworkUsage", reflect.TypeOf((*MockStore)(nil).CreateAgentNetworkUsage), ctx, usage, groups)
|
||||
}
|
||||
|
||||
// GetAgentNetworkAccessLogs mocks base method.
|
||||
func (m *MockStore) GetAgentNetworkAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkAccessLog, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAgentNetworkAccessLogs", ctx, lockStrength, accountID, filter)
|
||||
ret0, _ := ret[0].([]*agentNetworkTypes.AgentNetworkAccessLog)
|
||||
ret1, _ := ret[1].(int64)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GetAgentNetworkAccessLogs indicates an expected call of GetAgentNetworkAccessLogs.
|
||||
func (mr *MockStoreMockRecorder) GetAgentNetworkAccessLogs(ctx, lockStrength, accountID, filter interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkAccessLogs", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkAccessLogs), ctx, lockStrength, accountID, filter)
|
||||
}
|
||||
|
||||
// GetAgentNetworkUsageRows mocks base method.
|
||||
func (m *MockStore) GetAgentNetworkUsageRows(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkUsage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAgentNetworkUsageRows", ctx, lockStrength, accountID, filter)
|
||||
ret0, _ := ret[0].([]*agentNetworkTypes.AgentNetworkUsage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAgentNetworkUsageRows indicates an expected call of GetAgentNetworkUsageRows.
|
||||
func (mr *MockStoreMockRecorder) GetAgentNetworkUsageRows(ctx, lockStrength, accountID, filter interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkUsageRows", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkUsageRows), ctx, lockStrength, accountID, filter)
|
||||
}
|
||||
|
||||
// DeleteOldAgentNetworkAccessLogs mocks base method.
|
||||
func (m *MockStore) DeleteOldAgentNetworkAccessLogs(ctx context.Context, accountID string, olderThan time.Time) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteOldAgentNetworkAccessLogs", ctx, accountID, olderThan)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteOldAgentNetworkAccessLogs indicates an expected call of DeleteOldAgentNetworkAccessLogs.
|
||||
func (mr *MockStoreMockRecorder) DeleteOldAgentNetworkAccessLogs(ctx, accountID, olderThan interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldAgentNetworkAccessLogs", reflect.TypeOf((*MockStore)(nil).DeleteOldAgentNetworkAccessLogs), ctx, accountID, olderThan)
|
||||
}
|
||||
|
||||
// GetAllAgentNetworkSettings mocks base method.
|
||||
func (m *MockStore) GetAllAgentNetworkSettings(ctx context.Context, lockStrength LockingStrength) ([]*agentNetworkTypes.Settings, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAllAgentNetworkSettings", ctx, lockStrength)
|
||||
ret0, _ := ret[0].([]*agentNetworkTypes.Settings)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAllAgentNetworkSettings indicates an expected call of GetAllAgentNetworkSettings.
|
||||
func (mr *MockStoreMockRecorder) GetAllAgentNetworkSettings(ctx, lockStrength interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllAgentNetworkSettings", reflect.TypeOf((*MockStore)(nil).GetAllAgentNetworkSettings), ctx, lockStrength)
|
||||
}
|
||||
@@ -466,15 +466,20 @@ func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp.
|
||||
_ = ln.Close()
|
||||
}()
|
||||
|
||||
var backoff nbtcp.AcceptBackoff
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
|
||||
if ctx.Err() != nil || nbtcp.IsClosedListenerErr(err) {
|
||||
return
|
||||
}
|
||||
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v; backing off", err)
|
||||
if !backoff.Backoff(ctx) {
|
||||
return
|
||||
}
|
||||
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err)
|
||||
continue
|
||||
}
|
||||
backoff.Reset()
|
||||
router.HandleConn(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -533,3 +533,125 @@ MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
|
||||
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
|
||||
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
|
||||
-----END EC PRIVATE KEY-----`)
|
||||
|
||||
// scriptedAcceptListener returns pre-scripted errors from Accept(). Used
|
||||
// to drive the feedRouterFromListener tests without binding a real
|
||||
// socket — the production code path is a netstack-backed listener that
|
||||
// returns gVisor's "endpoint is in invalid state" forever after its
|
||||
// endpoint is destroyed.
|
||||
type scriptedAcceptListener struct {
|
||||
errs chan error
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener {
|
||||
s := &scriptedAcceptListener{
|
||||
errs: make(chan error, len(errs)+1),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
for _, e := range errs {
|
||||
s.errs <- e
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *scriptedAcceptListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case <-s.closed:
|
||||
return nil, net.ErrClosed
|
||||
case err := <-s.errs:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *scriptedAcceptListener) Close() error {
|
||||
select {
|
||||
case <-s.closed:
|
||||
default:
|
||||
close(s.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scriptedAcceptListener) Addr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
|
||||
}
|
||||
|
||||
// errSentinel carries a literal error message so tests can synthesise
|
||||
// the exact gVisor text without importing the netstack package.
|
||||
type errSentinel string
|
||||
|
||||
func (e errSentinel) Error() string { return string(e) }
|
||||
|
||||
// TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint is the
|
||||
// regression guard for the inbound side of the tight-loop bug. The
|
||||
// per-account plain-HTTP feeder must recognise gVisor's "endpoint is in
|
||||
// invalid state" and exit, otherwise it pegs a CPU core and floods the
|
||||
// account-scoped log with the same accept error every iteration.
|
||||
func TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
|
||||
router := nbtcp.NewRouter(logger, nil, addr)
|
||||
|
||||
gvisorErr := &net.OpError{
|
||||
Op: "accept",
|
||||
Net: "tcp",
|
||||
Addr: addr,
|
||||
Err: errSentinel("endpoint is in invalid state"),
|
||||
}
|
||||
ln := newScriptedAcceptListener(gvisorErr)
|
||||
defer ln.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
feedRouterFromListener(context.Background(), ln, router, logger, "acct-1")
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Expected: loop recognised the gVisor error and returned.
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("feedRouterFromListener did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFeedRouterFromListener_BacksOffOnTransientError asserts the
|
||||
// defence-in-depth path: an unknown sticky Accept error must NOT cause
|
||||
// CPU spin. The loop backs off and exits cleanly when ctx is cancelled.
|
||||
func TestFeedRouterFromListener_BacksOffOnTransientError(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
|
||||
router := nbtcp.NewRouter(logger, nil, addr)
|
||||
|
||||
const transientCount = 5
|
||||
errs := make([]error, transientCount)
|
||||
for i := range errs {
|
||||
errs[i] = errSentinel("transient: temporary network error")
|
||||
}
|
||||
ln := newScriptedAcceptListener(errs...)
|
||||
defer ln.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
feedRouterFromListener(ctx, ln, router, logger, "acct-1")
|
||||
}()
|
||||
time.AfterFunc(150*time.Millisecond, cancel)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Expected.
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("feedRouterFromListener did not exit on ctx cancellation — backoff or exit path broken")
|
||||
}
|
||||
|
||||
// Without backoff the 5 scripted errors would burn in microseconds.
|
||||
// With backoff the first delay alone is 5ms, so the loop must take
|
||||
// at least that long even though ctx fires at 150ms.
|
||||
elapsed := time.Since(start)
|
||||
assert.GreaterOrEqual(t, elapsed, 5*time.Millisecond,
|
||||
"loop ran without backing off — would burn CPU in production")
|
||||
}
|
||||
|
||||
@@ -128,6 +128,7 @@ type logEntry struct {
|
||||
BytesDownload int64
|
||||
Protocol Protocol
|
||||
Metadata map[string]string
|
||||
AgentNetwork bool
|
||||
}
|
||||
|
||||
// Protocol identifies the transport protocol of an access log entry.
|
||||
@@ -214,6 +215,54 @@ func (l *Logger) allowDenyLog(serviceID types.ServiceID, reason string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// usageMetadataKeys is the allowlist of metadata retained on a stripped,
|
||||
// usage-only agent-network entry. Mirrors the llm.* / cost.* keys in
|
||||
// proxy/internal/middleware/keys.go — only the dimensions management needs to
|
||||
// record a usage row (provider / model / tokens / cost / groups).
|
||||
var usageMetadataKeys = map[string]struct{}{
|
||||
"llm.provider": {},
|
||||
"llm.model": {},
|
||||
"llm.resolved_provider_id": {},
|
||||
"llm.input_tokens": {},
|
||||
"llm.output_tokens": {},
|
||||
"llm.total_tokens": {},
|
||||
"cost.usd_total": {},
|
||||
"llm.authorising_groups": {},
|
||||
}
|
||||
|
||||
// stripAgentNetworkEntryForUsage returns the entry reduced to what's needed to
|
||||
// record usage/cost: it drops request detail (host / path / source IP) and any
|
||||
// prompt capture, keeping the LLM usage metadata plus the caller identity
|
||||
// (user / auth mechanism) needed for attribution. Shipped when an
|
||||
// agent-network account has log collection disabled but usage must still be
|
||||
// collected. logEntry is passed by value, so mutating it here is safe; Metadata
|
||||
// is replaced with a fresh map rather than mutated in place.
|
||||
func stripAgentNetworkEntryForUsage(entry logEntry) logEntry {
|
||||
entry.Host = ""
|
||||
entry.Path = ""
|
||||
entry.SourceIP = netip.Addr{}
|
||||
// Drop the rest of the per-request telemetry too — a usage-only entry
|
||||
// must carry the LLM usage metadata and caller identity, nothing that
|
||||
// describes the individual request.
|
||||
entry.Method = ""
|
||||
entry.ResponseCode = 0
|
||||
entry.DurationMs = 0
|
||||
entry.BytesUpload = 0
|
||||
entry.BytesDownload = 0
|
||||
entry.Protocol = ""
|
||||
|
||||
if len(entry.Metadata) > 0 {
|
||||
stripped := make(map[string]string, len(usageMetadataKeys))
|
||||
for k := range usageMetadataKeys {
|
||||
if v, ok := entry.Metadata[k]; ok {
|
||||
stripped[k] = v
|
||||
}
|
||||
}
|
||||
entry.Metadata = stripped
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func (l *Logger) log(entry logEntry) {
|
||||
// Fire off the log request in a separate routine.
|
||||
// This increases the possibility of losing a log message
|
||||
@@ -264,6 +313,7 @@ func (l *Logger) log(entry logEntry) {
|
||||
BytesDownload: entry.BytesDownload,
|
||||
Protocol: string(entry.Protocol),
|
||||
Metadata: entry.Metadata,
|
||||
AgentNetwork: entry.AgentNetwork,
|
||||
},
|
||||
}); err != nil {
|
||||
l.logger.WithFields(log.Fields{
|
||||
|
||||
@@ -83,11 +83,23 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
BytesDownload: bytesDownload,
|
||||
Protocol: ProtocolHTTP,
|
||||
Metadata: capturedData.GetMetadata(),
|
||||
AgentNetwork: capturedData.GetAgentNetwork(),
|
||||
}
|
||||
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
|
||||
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceID(), capturedData.GetAccountID())
|
||||
|
||||
l.log(entry)
|
||||
// Emit the access log unless the matched target opted out
|
||||
// (agent-network synth targets do this when the account's
|
||||
// EnableLogCollection toggle is off). For agent-network entries we
|
||||
// still ship a stripped, usage-only record even when suppressed, so
|
||||
// usage/cost is collected regardless of the log-collection toggle;
|
||||
// request detail and prompt capture are dropped before sending.
|
||||
switch {
|
||||
case !capturedData.GetSuppressAccessLog():
|
||||
l.log(entry)
|
||||
case entry.AgentNetwork:
|
||||
l.log(stripAgentNetworkEntryForUsage(entry))
|
||||
}
|
||||
|
||||
// Track usage for cost monitoring (upload + download) by domain
|
||||
l.trackUsage(host, bytesUpload+bytesDownload)
|
||||
|
||||
185
proxy/internal/accesslog/middleware_test.go
Normal file
185
proxy/internal/accesslog/middleware_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// recorderClient is a minimal stub for the access-log gRPCClient interface. It
|
||||
// counts SendAccessLog invocations and signals on every call so tests can
|
||||
// deterministically wait for the goroutine inside Logger.log without sleeping.
|
||||
type recorderClient struct {
|
||||
mu sync.Mutex
|
||||
calls int64
|
||||
lastEntry *proto.AccessLog
|
||||
called chan struct{}
|
||||
}
|
||||
|
||||
func newRecorderClient() *recorderClient {
|
||||
return &recorderClient{called: make(chan struct{}, 16)}
|
||||
}
|
||||
|
||||
func (r *recorderClient) SendAccessLog(_ context.Context, in *proto.SendAccessLogRequest, _ ...grpc.CallOption) (*proto.SendAccessLogResponse, error) {
|
||||
r.mu.Lock()
|
||||
r.calls++
|
||||
r.lastEntry = in.GetLog()
|
||||
r.mu.Unlock()
|
||||
select {
|
||||
case r.called <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return &proto.SendAccessLogResponse{}, nil
|
||||
}
|
||||
|
||||
func (r *recorderClient) callCount() int64 {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.calls
|
||||
}
|
||||
|
||||
// newTestLogger builds a Logger backed by the supplied recorderClient. It is
|
||||
// the same constructor production uses, just with a stub gRPC client — no
|
||||
// mocks, no interface re-implementations.
|
||||
func newTestLogger(t *testing.T, client *recorderClient) *Logger {
|
||||
t.Helper()
|
||||
logger := NewLogger(client, nil, nil)
|
||||
t.Cleanup(logger.Close)
|
||||
return logger
|
||||
}
|
||||
|
||||
// TestMiddleware_SuppressAccessLog_SkipsLogSink asserts the suppression gate.
|
||||
// When the inner handler stamps SuppressAccessLog=true on CapturedData (mirrors
|
||||
// what reverseproxy does when the matched target's DisableAccessLog flag is
|
||||
// set), the middleware must NOT invoke the access-log sink. Bandwidth telemetry
|
||||
// (trackUsage) keeps running — it's the call to SendAccessLog that we gate.
|
||||
func TestMiddleware_SuppressAccessLog_SkipsLogSink(t *testing.T) {
|
||||
client := newRecorderClient()
|
||||
l := newTestLogger(t, client)
|
||||
|
||||
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cd := proxy.CapturedDataFromContext(r.Context())
|
||||
require.NotNil(t, cd, "middleware must inject CapturedData into the request context")
|
||||
cd.SetSuppressAccessLog(true)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(l.Middleware(inner))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
resp, err := http.Get(srv.URL + "/agent-network/v1/chat/completions")
|
||||
require.NoError(t, err, "GET against suppressed target must succeed")
|
||||
require.NoError(t, resp.Body.Close())
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode, "inner handler must run normally")
|
||||
|
||||
// Give the goroutine fence a beat (Logger.log dispatches in a goroutine).
|
||||
// The negative assertion needs a small window: if a send is going to
|
||||
// happen, it happens promptly.
|
||||
select {
|
||||
case <-client.called:
|
||||
t.Fatalf("access-log sink must not be invoked when SuppressAccessLog=true (got %d call(s))", client.callCount())
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(0), client.callCount(),
|
||||
"SendAccessLog must not be called for suppressed requests")
|
||||
}
|
||||
|
||||
// TestMiddleware_SuppressAccessLog_DefaultEmitsLog is the regression sanity:
|
||||
// when nothing sets SuppressAccessLog (the universal default for every
|
||||
// non-agent-network target), the middleware MUST still emit the access-log
|
||||
// entry. This is the guarantee that wires-through to the EnableLogCollection
|
||||
// gate without breaking anyone who isn't opted in.
|
||||
func TestMiddleware_SuppressAccessLog_DefaultEmitsLog(t *testing.T) {
|
||||
client := newRecorderClient()
|
||||
l := newTestLogger(t, client)
|
||||
|
||||
var innerRan atomic.Bool
|
||||
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
innerRan.Store(true)
|
||||
// Intentionally DO NOT touch SuppressAccessLog — mirrors every
|
||||
// non-agent-network target.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(l.Middleware(inner))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
resp, err := http.Get(srv.URL + "/service/healthz")
|
||||
require.NoError(t, err, "GET against default target must succeed")
|
||||
require.NoError(t, resp.Body.Close())
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode, "inner handler must run normally")
|
||||
require.True(t, innerRan.Load(), "inner handler must have run")
|
||||
|
||||
select {
|
||||
case <-client.called:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("SendAccessLog must be invoked for non-suppressed requests, none observed (calls=%d)", client.callCount())
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(1), client.callCount(),
|
||||
"non-suppressed request must produce exactly one access-log send")
|
||||
}
|
||||
|
||||
// TestMiddleware_SuppressAccessLog_PreservesUsageTracking proves the gate is
|
||||
// surgical: with SuppressAccessLog=true the access-log send is skipped, but
|
||||
// the per-domain usage tracker still records the bytes transferred. This is
|
||||
// the cost-monitoring guarantee called out in the gate's comment.
|
||||
func TestMiddleware_SuppressAccessLog_PreservesUsageTracking(t *testing.T) {
|
||||
client := newRecorderClient()
|
||||
l := newTestLogger(t, client)
|
||||
|
||||
payload := []byte("ok")
|
||||
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cd := proxy.CapturedDataFromContext(r.Context())
|
||||
require.NotNil(t, cd, "middleware must inject CapturedData")
|
||||
cd.SetSuppressAccessLog(true)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(payload)
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(l.Middleware(inner))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
resp, err := http.Get(srv.URL + "/agent-network/v1/chat/completions")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
// Allow trackUsage to land — it runs synchronously after l.log(entry) is
|
||||
// (would have been) called.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
l.usageMux.Lock()
|
||||
usage, present := l.domainUsage[hostNoPort(srv.URL)]
|
||||
l.usageMux.Unlock()
|
||||
require.True(t, present, "domain usage must be tracked even when the access-log is suppressed")
|
||||
assert.Greater(t, usage.bytesTransferred, int64(0), "bytesTransferred must include the response payload")
|
||||
assert.Equal(t, int64(0), client.callCount(),
|
||||
"SendAccessLog must remain suppressed across the response write")
|
||||
}
|
||||
|
||||
// hostNoPort extracts the host name from an httptest server URL. The
|
||||
// middleware strips the port before keying domain usage, so the test mirrors
|
||||
// that to look the entry up.
|
||||
func hostNoPort(url string) string {
|
||||
// httptest URLs are always "http://127.0.0.1:PORT".
|
||||
const prefix = "http://"
|
||||
host := url[len(prefix):]
|
||||
for i := 0; i < len(host); i++ {
|
||||
if host[i] == ':' || host[i] == '/' {
|
||||
return host[:i]
|
||||
}
|
||||
}
|
||||
return host
|
||||
}
|
||||
@@ -297,6 +297,109 @@ func TestProtect_SessionCookieGroupsPropagate(t *testing.T) {
|
||||
assert.Equal(t, groups, capturedData.GetUserGroups(), "CapturedData groups must be retained after handler completes")
|
||||
}
|
||||
|
||||
// stubTunnelValidator implements SessionValidator for the tunnel-peer
|
||||
// path. ValidateTunnelPeer returns a fixed response so tests can assert
|
||||
// how the proxy maps it onto CapturedData, and records whether the
|
||||
// fast-path actually reached management.
|
||||
type stubTunnelValidator struct {
|
||||
called bool
|
||||
resp *proto.ValidateTunnelPeerResponse
|
||||
}
|
||||
|
||||
func (s *stubTunnelValidator) ValidateSession(context.Context, *proto.ValidateSessionRequest, ...grpc.CallOption) (*proto.ValidateSessionResponse, error) {
|
||||
return nil, errors.New("not used in this test")
|
||||
}
|
||||
|
||||
func (s *stubTunnelValidator) ValidateTunnelPeer(context.Context, *proto.ValidateTunnelPeerRequest, ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
s.called = true
|
||||
return s.resp, nil
|
||||
}
|
||||
|
||||
// TestProtect_PrivateService_TunnelPeerGroupsPropagate locks the agent-network
|
||||
// auth path end-to-end at the proxy edge: a Private service must route through
|
||||
// ValidateTunnelPeer and lift the returned peer_group_ids onto CapturedData so
|
||||
// the llm_router group-authorisation pass can see them. Regression guard for
|
||||
// the failure that surfaces downstream as llm_policy.no_authorised_provider —
|
||||
// i.e. a synthesised service that reaches the proxy without private=true (so
|
||||
// this path is skipped) leaves UserGroups empty and every request is denied.
|
||||
func TestProtect_PrivateService_TunnelPeerGroupsPropagate(t *testing.T) {
|
||||
groups := []string{"grp-admins", "grp-users"}
|
||||
names := []string{"Admins", "Users"}
|
||||
validator := &stubTunnelValidator{resp: &proto.ValidateTunnelPeerResponse{
|
||||
Valid: true,
|
||||
UserId: "user-1",
|
||||
UserEmail: "user@example.com",
|
||||
SessionToken: "tunnel-session-token",
|
||||
PeerGroupIds: groups,
|
||||
PeerGroupNames: names,
|
||||
}}
|
||||
mw := NewMiddleware(log.StandardLogger(), validator, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
// Private service: no operator schemes — auth gates solely on the tunnel peer.
|
||||
require.NoError(t, mw.AddDomain("agent.example.com", nil, kp.PublicKey, time.Hour, "acct-1", "svc-1", nil, true))
|
||||
|
||||
cd := proxy.NewCapturedData("")
|
||||
cd.SetClientIP(netip.MustParseAddr("100.90.1.14")) // CGNAT tunnel source
|
||||
|
||||
var seenGroups []string
|
||||
var seenUser string
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c := proxy.CapturedDataFromContext(r.Context())
|
||||
require.NotNil(t, c, "captured data must be present in request context")
|
||||
seenGroups = c.GetUserGroups()
|
||||
seenUser = c.GetUserID()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
|
||||
return PeerIdentity{}, true
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "http://agent.example.com/v1/chat/completions", nil)
|
||||
req.RemoteAddr = "100.90.1.14:5000"
|
||||
req = req.WithContext(WithTunnelLookup(proxy.WithCapturedData(req.Context(), cd), lookup))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code, "private service must authorise a tunnel peer the validator accepts")
|
||||
assert.Equal(t, groups, seenGroups, "ValidateTunnelPeer peer_group_ids must reach CapturedData.UserGroups for llm_router authorisation")
|
||||
assert.Equal(t, "user-1", seenUser, "tunnel-peer principal must reach CapturedData")
|
||||
assert.Equal(t, groups, cd.GetUserGroups(), "groups must persist on CapturedData after the handler returns")
|
||||
}
|
||||
|
||||
// TestProtect_PrivateService_TunnelPeerDenied verifies the deny path: when
|
||||
// ValidateTunnelPeer rejects the peer, a Private service 403s and never reaches
|
||||
// the upstream handler (no fall-through to unauthenticated pass-through).
|
||||
func TestProtect_PrivateService_TunnelPeerDenied(t *testing.T) {
|
||||
validator := &stubTunnelValidator{resp: &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "not_in_group",
|
||||
}}
|
||||
mw := NewMiddleware(log.StandardLogger(), validator, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
require.NoError(t, mw.AddDomain("agent.example.com", nil, kp.PublicKey, time.Hour, "acct-1", "svc-1", nil, true))
|
||||
|
||||
cd := proxy.NewCapturedData("")
|
||||
cd.SetClientIP(netip.MustParseAddr("100.90.1.14"))
|
||||
|
||||
reached := false
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reached = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
|
||||
return PeerIdentity{}, true
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "http://agent.example.com/v1/chat/completions", nil)
|
||||
req.RemoteAddr = "100.90.1.14:5000"
|
||||
req = req.WithContext(WithTunnelLookup(proxy.WithCapturedData(req.Context(), cd), lookup))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, rec.Code, "private service must 403 when the tunnel peer is rejected")
|
||||
assert.False(t, reached, "denied private request must not reach the upstream handler")
|
||||
}
|
||||
|
||||
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
@@ -1228,22 +1331,6 @@ func TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked(t *testing.T) {
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code, "PIN-only domain should serve the login page on plain HTTP")
|
||||
}
|
||||
|
||||
// stubTunnelValidator records ValidateTunnelPeer calls so a test can
|
||||
// assert whether the fast-path reached management.
|
||||
type stubTunnelValidator struct {
|
||||
called bool
|
||||
resp *proto.ValidateTunnelPeerResponse
|
||||
}
|
||||
|
||||
func (s *stubTunnelValidator) ValidateSession(context.Context, *proto.ValidateSessionRequest, ...grpc.CallOption) (*proto.ValidateSessionResponse, error) {
|
||||
return nil, errors.New("not used in this test")
|
||||
}
|
||||
|
||||
func (s *stubTunnelValidator) ValidateTunnelPeer(context.Context, *proto.ValidateTunnelPeerRequest, ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
s.called = true
|
||||
return s.resp, nil
|
||||
}
|
||||
|
||||
// TestProtect_TunnelPeerFastPath_RequiresInboundMarker guards the
|
||||
// anti-spoof gate: a request with an RFC1918 source IP arriving on the
|
||||
// public listener (no TunnelLookupFromContext attached) must not be
|
||||
|
||||
196
proxy/internal/llm/anthropic.go
Normal file
196
proxy/internal/llm/anthropic.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AnthropicParser implements the Parser interface for the Anthropic Messages
|
||||
// and Completions APIs. Detection is substring-based to tolerate upstream
|
||||
// path rewrites.
|
||||
type AnthropicParser struct{}
|
||||
|
||||
var anthropicPathHints = []string{
|
||||
"/v1/messages",
|
||||
"/v1/complete",
|
||||
}
|
||||
|
||||
// Provider returns ProviderAnthropic.
|
||||
func (AnthropicParser) Provider() Provider { return ProviderAnthropic }
|
||||
|
||||
// ProviderName returns the stable label used for metrics and metadata.
|
||||
func (AnthropicParser) ProviderName() string { return "anthropic" }
|
||||
|
||||
// DetectFromURL reports whether the given request path looks like an
|
||||
// Anthropic API endpoint. The match is case-insensitive and substring-based.
|
||||
func (AnthropicParser) DetectFromURL(path string) bool {
|
||||
lower := strings.ToLower(path)
|
||||
for _, hint := range anthropicPathHints {
|
||||
if strings.Contains(lower, hint) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type anthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
Stream *bool `json:"stream"`
|
||||
System json.RawMessage `json:"system"`
|
||||
Messages []anthropicMessage `json:"messages"`
|
||||
// Legacy /v1/complete endpoint.
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type anthropicMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
// ParseRequest extracts the model name and streaming flag from an Anthropic
|
||||
// request body. Unknown or missing fields leave the corresponding struct
|
||||
// members zero-valued.
|
||||
func (AnthropicParser) ParseRequest(body []byte) (RequestFacts, error) {
|
||||
var req anthropicRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return RequestFacts{}, fmt.Errorf("decode anthropic request: %w: %v", ErrMalformedRequest, err)
|
||||
}
|
||||
return RequestFacts{
|
||||
Model: req.Model,
|
||||
Stream: ptrDeref(req.Stream),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type anthropicResponse struct {
|
||||
Usage struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
// CacheReadInputTokens and CacheCreationInputTokens are
|
||||
// ADDITIVE to InputTokens (not subset), each billed at its
|
||||
// own rate by the cost meter. cache_read is the cheaper
|
||||
// read-from-cache rate, cache_creation is the more
|
||||
// expensive write-to-cache rate.
|
||||
CacheReadInputTokens int64 `json:"cache_read_input_tokens"`
|
||||
CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
// ParseResponse decodes the non-streaming Anthropic response envelope. Status
|
||||
// codes other than 200 are treated as non-LLM responses so the caller can
|
||||
// skip cost accounting without aborting the request.
|
||||
func (AnthropicParser) ParseResponse(status int, contentType string, body []byte) (Usage, error) {
|
||||
if status != 200 {
|
||||
return Usage{}, fmt.Errorf("anthropic status %d: %w", status, ErrNotLLMResponse)
|
||||
}
|
||||
if isEventStream(contentType) {
|
||||
return Usage{}, ErrStreamingUnsupported
|
||||
}
|
||||
if !isJSON(contentType) {
|
||||
return Usage{}, fmt.Errorf("anthropic content-type %q: %w", contentType, ErrNotLLMResponse)
|
||||
}
|
||||
|
||||
var resp anthropicResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return Usage{}, fmt.Errorf("decode anthropic response: %w: %v", ErrMalformedResponse, err)
|
||||
}
|
||||
return Usage{
|
||||
InputTokens: resp.Usage.InputTokens,
|
||||
OutputTokens: resp.Usage.OutputTokens,
|
||||
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens + resp.Usage.CacheReadInputTokens + resp.Usage.CacheCreationInputTokens,
|
||||
CachedInputTokens: resp.Usage.CacheReadInputTokens,
|
||||
CacheCreationTokens: resp.Usage.CacheCreationInputTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExtractPrompt returns the user-visible prompt text from an Anthropic
|
||||
// request body. Handles the Messages API (system + messages[]) and the
|
||||
// legacy /v1/complete prompt string. Returns "" on any decode failure.
|
||||
func (AnthropicParser) ExtractPrompt(body []byte) string {
|
||||
var req anthropicRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
if len(req.System) > 0 {
|
||||
if s := decodeStringOrJoin(req.System); s != "" {
|
||||
b.WriteString("system: ")
|
||||
b.WriteString(s)
|
||||
}
|
||||
}
|
||||
for _, m := range req.Messages {
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
if m.Role != "" {
|
||||
b.WriteString(m.Role)
|
||||
b.WriteString(": ")
|
||||
}
|
||||
b.WriteString(decodeStringOrJoin(m.Content))
|
||||
}
|
||||
if b.Len() == 0 && req.Prompt != "" {
|
||||
b.WriteString(req.Prompt)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// ExtractSessionID is the body-side fallback for Anthropic. Claude Code's
|
||||
// authoritative session marker is the X-Claude-Code-Session-Id request
|
||||
// header (handled by the request-parser middleware); this only mines the
|
||||
// optional metadata.user_id for an embedded "...session_<uuid>" marker.
|
||||
// metadata.user_id on its own is a USER identifier, not a session, so the
|
||||
// whole value is deliberately NOT used — returning it would mislabel every
|
||||
// request from a user as one session. Returns "" when no session marker is
|
||||
// present.
|
||||
func (AnthropicParser) ExtractSessionID(body []byte) string {
|
||||
var req struct {
|
||||
Metadata struct {
|
||||
UserID string `json:"user_id"`
|
||||
} `json:"metadata"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return ""
|
||||
}
|
||||
if idx := strings.LastIndex(req.Metadata.UserID, "session_"); idx >= 0 {
|
||||
if session := req.Metadata.UserID[idx+len("session_"):]; session != "" {
|
||||
return session
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type anthropicMessageResponse struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
// Legacy /v1/complete response.
|
||||
Completion string `json:"completion"`
|
||||
}
|
||||
|
||||
// ExtractCompletion returns the assistant text from a non-streaming Anthropic
|
||||
// Messages or Completions response. Returns "" when status/content-type
|
||||
// indicate the body is not parseable or no text part is present.
|
||||
func (AnthropicParser) ExtractCompletion(status int, contentType string, body []byte) string {
|
||||
if status != 200 || isEventStream(contentType) || !isJSON(contentType) {
|
||||
return ""
|
||||
}
|
||||
var resp anthropicMessageResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, part := range resp.Content {
|
||||
if part.Text == "" {
|
||||
continue
|
||||
}
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString(part.Text)
|
||||
}
|
||||
if b.Len() == 0 {
|
||||
return resp.Completion
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
169
proxy/internal/llm/anthropic_test.go
Normal file
169
proxy/internal/llm/anthropic_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAnthropicDetectFromURL(t *testing.T) {
|
||||
p := AnthropicParser{}
|
||||
|
||||
cases := map[string]bool{
|
||||
"/v1/messages": true,
|
||||
"/v1/complete": true,
|
||||
"/V1/Messages": true,
|
||||
"/proxy/v1/messages?x": true,
|
||||
"/v1/chat/completions": false,
|
||||
"": false,
|
||||
}
|
||||
for path, want := range cases {
|
||||
assert.Equal(t, want, p.DetectFromURL(path), "DetectFromURL(%q)", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropicParseRequest(t *testing.T) {
|
||||
p := AnthropicParser{}
|
||||
|
||||
t.Run("stream true", func(t *testing.T) {
|
||||
facts, err := p.ParseRequest([]byte(`{"model":"claude-sonnet-4-5","stream":true}`))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "claude-sonnet-4-5", facts.Model, "model extracted")
|
||||
assert.True(t, facts.Stream, "stream flag honoured")
|
||||
})
|
||||
|
||||
t.Run("stream default", func(t *testing.T) {
|
||||
facts, err := p.ParseRequest([]byte(`{"model":"claude-sonnet-4-5"}`))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, facts.Stream, "missing stream flag defaults to false")
|
||||
})
|
||||
|
||||
t.Run("malformed", func(t *testing.T) {
|
||||
_, err := p.ParseRequest([]byte(`{"model":`))
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrMalformedRequest), "sentinel wrapped")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnthropicParseResponse(t *testing.T) {
|
||||
p := AnthropicParser{}
|
||||
|
||||
t.Run("happy fixture", func(t *testing.T) {
|
||||
body, err := os.ReadFile(filepath.Join("fixtures", "anthropic_messages.json"))
|
||||
require.NoError(t, err, "fixture must be readable")
|
||||
|
||||
usage, err := p.ParseResponse(200, "application/json", body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(123), usage.InputTokens, "input tokens extracted")
|
||||
assert.Equal(t, int64(45), usage.OutputTokens, "output tokens extracted")
|
||||
assert.Equal(t, int64(168), usage.TotalTokens, "total computed as sum")
|
||||
})
|
||||
|
||||
t.Run("streaming rejected", func(t *testing.T) {
|
||||
_, err := p.ParseResponse(200, "text/event-stream", []byte(""))
|
||||
require.ErrorIs(t, err, ErrStreamingUnsupported, "SSE responses must use the scanner")
|
||||
})
|
||||
|
||||
t.Run("non-200", func(t *testing.T) {
|
||||
_, err := p.ParseResponse(429, "application/json", []byte(`{}`))
|
||||
require.ErrorIs(t, err, ErrNotLLMResponse, "non-200 rejected as non-LLM")
|
||||
})
|
||||
|
||||
t.Run("non-json content type", func(t *testing.T) {
|
||||
_, err := p.ParseResponse(200, "text/html", []byte(`{}`))
|
||||
require.ErrorIs(t, err, ErrNotLLMResponse, "text/html treated as non-LLM")
|
||||
})
|
||||
|
||||
t.Run("malformed body", func(t *testing.T) {
|
||||
_, err := p.ParseResponse(200, "application/json", []byte(`{`))
|
||||
require.ErrorIs(t, err, ErrMalformedResponse, "bad JSON yields malformed error")
|
||||
})
|
||||
|
||||
// Anthropic's two cache fields are ADDITIVE to input_tokens (not
|
||||
// subset). The parser must surface them so the cost meter can
|
||||
// bill each bucket at its own configured rate. Total includes
|
||||
// every bucket so downstream attribution sees the full token
|
||||
// volume the request consumed.
|
||||
t.Run("cache_read_input_tokens surfaces as CachedInputTokens (additive)", func(t *testing.T) {
|
||||
body := []byte(`{"usage":{"input_tokens":256,"output_tokens":200,"cache_read_input_tokens":768}}`)
|
||||
usage, err := p.ParseResponse(200, "application/json", body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(256), usage.InputTokens, "regular input remains separate from cache buckets")
|
||||
assert.Equal(t, int64(768), usage.CachedInputTokens, "cache_read maps onto CachedInputTokens — same field carries OpenAI cached subset and Anthropic cache reads")
|
||||
assert.Zero(t, usage.CacheCreationTokens)
|
||||
assert.Equal(t, int64(256+200+768), usage.TotalTokens, "total includes every input bucket plus output — cache reads are billable tokens")
|
||||
})
|
||||
|
||||
t.Run("cache_creation_input_tokens surfaces as CacheCreationTokens (additive)", func(t *testing.T) {
|
||||
body := []byte(`{"usage":{"input_tokens":256,"output_tokens":200,"cache_creation_input_tokens":512}}`)
|
||||
usage, err := p.ParseResponse(200, "application/json", body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(256), usage.InputTokens)
|
||||
assert.Zero(t, usage.CachedInputTokens)
|
||||
assert.Equal(t, int64(512), usage.CacheCreationTokens, "cache_creation surfaces — meter applies the write-rate multiplier")
|
||||
assert.Equal(t, int64(256+200+512), usage.TotalTokens)
|
||||
})
|
||||
|
||||
t.Run("both cache buckets present", func(t *testing.T) {
|
||||
body := []byte(`{"usage":{"input_tokens":256,"output_tokens":200,"cache_read_input_tokens":768,"cache_creation_input_tokens":512}}`)
|
||||
usage, err := p.ParseResponse(200, "application/json", body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(768), usage.CachedInputTokens)
|
||||
assert.Equal(t, int64(512), usage.CacheCreationTokens)
|
||||
assert.Equal(t, int64(256+200+768+512), usage.TotalTokens, "all four buckets sum into total")
|
||||
})
|
||||
|
||||
t.Run("absent cache fields leave counts at zero", func(t *testing.T) {
|
||||
body := []byte(`{"usage":{"input_tokens":100,"output_tokens":50}}`)
|
||||
usage, err := p.ParseResponse(200, "application/json", body)
|
||||
require.NoError(t, err)
|
||||
assert.Zero(t, usage.CachedInputTokens, "no cache_read field = no cached count")
|
||||
assert.Zero(t, usage.CacheCreationTokens, "no cache_creation field = no creation count")
|
||||
assert.Equal(t, int64(150), usage.TotalTokens, "back to the simple in+out total when no cache buckets present")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnthropicExtractPrompt_Messages(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-7","system":"be brief","messages":[{"role":"user","content":"hi"},{"role":"assistant","content":"yes"}]}`)
|
||||
got := AnthropicParser{}.ExtractPrompt(body)
|
||||
require.Contains(t, got, "system: be brief", "system surfaces with role label")
|
||||
require.Contains(t, got, "user: hi", "user message surfaces")
|
||||
require.Contains(t, got, "assistant: yes", "assistant message surfaces")
|
||||
}
|
||||
|
||||
func TestAnthropicExtractPrompt_LegacyComplete(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-2","prompt":"\n\nHuman: hi\n\nAssistant:"}`)
|
||||
got := AnthropicParser{}.ExtractPrompt(body)
|
||||
require.Contains(t, got, "Human: hi", "legacy prompt string surfaces")
|
||||
}
|
||||
|
||||
func TestAnthropicExtractSessionID(t *testing.T) {
|
||||
t.Run("claude code session suffix", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-opus-4-8","metadata":{"user_id":"user_abc123_account_def456_session_9f8e7d6c"},"messages":[]}`)
|
||||
assert.Equal(t, "9f8e7d6c", AnthropicParser{}.ExtractSessionID(body), "session_<id> suffix must be extracted from metadata.user_id")
|
||||
})
|
||||
t.Run("plain user_id is not treated as a session", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-opus-4-8","metadata":{"user_id":"acme-team"},"messages":[]}`)
|
||||
assert.Equal(t, "", AnthropicParser{}.ExtractSessionID(body), "a user identifier without a session marker must NOT be used as a session id")
|
||||
})
|
||||
t.Run("no metadata yields empty", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-opus-4-8","messages":[{"role":"user","content":"hi"}]}`)
|
||||
assert.Equal(t, "", AnthropicParser{}.ExtractSessionID(body), "absent metadata.user_id yields no session id")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnthropicExtractCompletion_Messages(t *testing.T) {
|
||||
body, err := os.ReadFile(filepath.Join("fixtures", "anthropic_messages.json"))
|
||||
require.NoError(t, err)
|
||||
got := AnthropicParser{}.ExtractCompletion(200, "application/json", body)
|
||||
require.NotEmpty(t, got, "anthropic fixture has assistant text")
|
||||
}
|
||||
|
||||
func TestAnthropicExtractCompletion_Streaming(t *testing.T) {
|
||||
got := AnthropicParser{}.ExtractCompletion(200, "text/event-stream", []byte(""))
|
||||
require.Empty(t, got, "streaming responses are skipped")
|
||||
}
|
||||
189
proxy/internal/llm/bedrock.go
Normal file
189
proxy/internal/llm/bedrock.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProviderNameBedrock is the stable label for the AWS Bedrock parser, used as
|
||||
// the llm.provider metadata value and the cost-meter formula selector.
|
||||
const ProviderNameBedrock = "bedrock"
|
||||
|
||||
// BedrockParser implements the Parser interface for the AWS Bedrock runtime.
|
||||
// Bedrock carries the model in the URL path (/model/{id}/{action}); the request
|
||||
// middleware extracts it there, so this parser focuses on the response shapes:
|
||||
// the vendor-native InvokeModel body (e.g. Anthropic's snake_case usage) and the
|
||||
// unified Converse body (camelCase usage).
|
||||
type BedrockParser struct{}
|
||||
|
||||
var bedrockPathHints = []string{"/invoke", "/converse"}
|
||||
|
||||
// Provider returns ProviderBedrock.
|
||||
func (BedrockParser) Provider() Provider { return ProviderBedrock }
|
||||
|
||||
// ProviderName returns the stable label used for metrics and metadata.
|
||||
func (BedrockParser) ProviderName() string { return ProviderNameBedrock }
|
||||
|
||||
// DetectFromURL reports whether the path is a Bedrock runtime model endpoint.
|
||||
func (BedrockParser) DetectFromURL(path string) bool {
|
||||
lower := strings.ToLower(path)
|
||||
if !strings.HasPrefix(lower, "/model/") {
|
||||
return false
|
||||
}
|
||||
for _, hint := range bedrockPathHints {
|
||||
if strings.Contains(lower, hint) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ParseRequest is a no-op for Bedrock: the model lives in the URL path, not the
|
||||
// body, and the streaming flag is derived from the path action. The request
|
||||
// middleware handles both via parseBedrockPath, so this returns empty facts.
|
||||
func (BedrockParser) ParseRequest([]byte) (RequestFacts, error) {
|
||||
return RequestFacts{}, nil
|
||||
}
|
||||
|
||||
// bedrockResponse captures token usage from both Bedrock response shapes:
|
||||
// InvokeModel (vendor-native; Anthropic uses snake_case + additive cache
|
||||
// buckets) and Converse (camelCase, with a precomputed total).
|
||||
type bedrockResponse struct {
|
||||
Usage struct {
|
||||
// InvokeModel (Anthropic-on-Bedrock) — snake_case.
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheReadInputTokens int64 `json:"cache_read_input_tokens"`
|
||||
CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"`
|
||||
// Converse — camelCase.
|
||||
InputTokensCamel int64 `json:"inputTokens"`
|
||||
OutputTokensCamel int64 `json:"outputTokens"`
|
||||
TotalTokensCamel int64 `json:"totalTokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
// ParseResponse decodes the non-streaming Bedrock response envelope, handling
|
||||
// both the InvokeModel and Converse usage shapes. Non-200 / non-JSON bodies are
|
||||
// treated as non-LLM responses so the caller skips cost accounting.
|
||||
func (BedrockParser) ParseResponse(status int, contentType string, body []byte) (Usage, error) {
|
||||
if status != 200 {
|
||||
return Usage{}, fmt.Errorf("bedrock status %d: %w", status, ErrNotLLMResponse)
|
||||
}
|
||||
if isAWSEventStream(contentType) || isEventStream(contentType) {
|
||||
return Usage{}, ErrStreamingUnsupported
|
||||
}
|
||||
if !isJSON(contentType) {
|
||||
return Usage{}, fmt.Errorf("bedrock content-type %q: %w", contentType, ErrNotLLMResponse)
|
||||
}
|
||||
|
||||
var resp bedrockResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return Usage{}, fmt.Errorf("decode bedrock response: %w: %v", ErrMalformedResponse, err)
|
||||
}
|
||||
inTok := firstNonZero(resp.Usage.InputTokens, resp.Usage.InputTokensCamel)
|
||||
outTok := firstNonZero(resp.Usage.OutputTokens, resp.Usage.OutputTokensCamel)
|
||||
total := resp.Usage.TotalTokensCamel
|
||||
if total == 0 {
|
||||
total = inTok + outTok + resp.Usage.CacheReadInputTokens + resp.Usage.CacheCreationInputTokens
|
||||
}
|
||||
return Usage{
|
||||
InputTokens: inTok,
|
||||
OutputTokens: outTok,
|
||||
TotalTokens: total,
|
||||
CachedInputTokens: resp.Usage.CacheReadInputTokens,
|
||||
CacheCreationTokens: resp.Usage.CacheCreationInputTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExtractPrompt returns the user-visible prompt from a Bedrock request body,
|
||||
// handling both the InvokeModel (Anthropic Messages: system + messages[]) and
|
||||
// Converse (messages[].content[].text) shapes. Returns "" on decode failure.
|
||||
func (BedrockParser) ExtractPrompt(body []byte) string {
|
||||
var req struct {
|
||||
System json.RawMessage `json:"system"`
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
if s := decodeStringOrJoin(req.System); s != "" {
|
||||
b.WriteString("system: ")
|
||||
b.WriteString(s)
|
||||
}
|
||||
for _, m := range req.Messages {
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
if m.Role != "" {
|
||||
b.WriteString(m.Role)
|
||||
b.WriteString(": ")
|
||||
}
|
||||
b.WriteString(decodeStringOrJoin(m.Content))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// ExtractCompletion returns the assistant text from a non-streaming Bedrock
|
||||
// response, handling InvokeModel (Anthropic content[].text) and Converse
|
||||
// (output.message.content[].text).
|
||||
func (BedrockParser) ExtractCompletion(status int, contentType string, body []byte) string {
|
||||
if status != 200 || isAWSEventStream(contentType) || !isJSON(contentType) {
|
||||
return ""
|
||||
}
|
||||
var resp struct {
|
||||
Content []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
Output struct {
|
||||
Message struct {
|
||||
Content []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"output"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
appendText := func(text string) {
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString(text)
|
||||
}
|
||||
for _, p := range resp.Content {
|
||||
appendText(p.Text)
|
||||
}
|
||||
for _, p := range resp.Output.Message.Content {
|
||||
appendText(p.Text)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// ExtractSessionID has no Bedrock-native marker; session grouping relies on the
|
||||
// request headers handled by the middleware. Returns "".
|
||||
func (BedrockParser) ExtractSessionID([]byte) string { return "" }
|
||||
|
||||
// firstNonZero returns a when non-zero, else b. Folds the snake_case and
|
||||
// camelCase usage variants into a single value.
|
||||
func firstNonZero(a, b int64) int64 {
|
||||
if a != 0 {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// isAWSEventStream reports whether contentType is the AWS binary event-stream
|
||||
// framing used by Bedrock's streaming endpoints.
|
||||
func isAWSEventStream(contentType string) bool {
|
||||
return strings.Contains(strings.ToLower(contentType), "application/vnd.amazon.eventstream")
|
||||
}
|
||||
65
proxy/internal/llm/bedrock_test.go
Normal file
65
proxy/internal/llm/bedrock_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBedrockParser_ParseResponse_Invoke(t *testing.T) {
|
||||
body := []byte(`{"usage":{"input_tokens":13,"output_tokens":5,"cache_read_input_tokens":2,"cache_creation_input_tokens":4}}`)
|
||||
u, err := BedrockParser{}.ParseResponse(200, "application/json", body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(13), u.InputTokens, "invoke input tokens")
|
||||
require.Equal(t, int64(5), u.OutputTokens, "invoke output tokens")
|
||||
require.Equal(t, int64(2), u.CachedInputTokens, "invoke cache-read tokens")
|
||||
require.Equal(t, int64(4), u.CacheCreationTokens, "invoke cache-creation tokens")
|
||||
require.Equal(t, int64(13+5+2+4), u.TotalTokens, "invoke total is additive")
|
||||
}
|
||||
|
||||
func TestBedrockParser_ParseResponse_Converse(t *testing.T) {
|
||||
body := []byte(`{"output":{"message":{"content":[{"text":"pong"}]}},"usage":{"inputTokens":11,"outputTokens":3,"totalTokens":14}}`)
|
||||
u, err := BedrockParser{}.ParseResponse(200, "application/json", body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(11), u.InputTokens, "converse camelCase input tokens")
|
||||
require.Equal(t, int64(3), u.OutputTokens, "converse camelCase output tokens")
|
||||
require.Equal(t, int64(14), u.TotalTokens, "converse uses provider total")
|
||||
}
|
||||
|
||||
func TestBedrockParser_ParseResponse_StreamingUnsupported(t *testing.T) {
|
||||
_, err := BedrockParser{}.ParseResponse(200, "application/vnd.amazon.eventstream", []byte("binary"))
|
||||
require.ErrorIs(t, err, ErrStreamingUnsupported, "event-stream must route to the streaming accumulator")
|
||||
}
|
||||
|
||||
func TestBedrockParser_ParseResponse_NonSuccess(t *testing.T) {
|
||||
_, err := BedrockParser{}.ParseResponse(404, "application/json", []byte(`{"message":"gated"}`))
|
||||
require.ErrorIs(t, err, ErrNotLLMResponse, "non-200 is not an LLM response")
|
||||
}
|
||||
|
||||
func TestBedrockParser_ExtractCompletion(t *testing.T) {
|
||||
invoke := BedrockParser{}.ExtractCompletion(200, "application/json", []byte(`{"content":[{"text":"a"},{"text":"b"}]}`))
|
||||
require.Equal(t, "a\nb", invoke, "invoke completion joins content parts")
|
||||
|
||||
converse := BedrockParser{}.ExtractCompletion(200, "application/json", []byte(`{"output":{"message":{"content":[{"text":"x"}]}}}`))
|
||||
require.Equal(t, "x", converse, "converse completion reads output.message.content")
|
||||
}
|
||||
|
||||
func TestBedrockParser_ExtractPrompt(t *testing.T) {
|
||||
invoke := BedrockParser{}.ExtractPrompt([]byte(`{"messages":[{"role":"user","content":"hi"}]}`))
|
||||
require.Equal(t, "user: hi", invoke, "invoke prompt reads anthropic content string")
|
||||
|
||||
converse := BedrockParser{}.ExtractPrompt([]byte(`{"messages":[{"role":"user","content":[{"text":"hello"}]}]}`))
|
||||
require.Equal(t, "user: hello", converse, "converse prompt reads content parts")
|
||||
}
|
||||
|
||||
func TestBedrockParser_DetectFromURL(t *testing.T) {
|
||||
require.True(t, BedrockParser{}.DetectFromURL("/model/eu.anthropic.claude/invoke"), "invoke path")
|
||||
require.True(t, BedrockParser{}.DetectFromURL("/model/x/converse-stream"), "converse-stream path")
|
||||
require.False(t, BedrockParser{}.DetectFromURL("/v1/chat/completions"), "openai path is not bedrock")
|
||||
}
|
||||
|
||||
func TestBedrockParser_RegisteredByName(t *testing.T) {
|
||||
p, ok := ParserByName(ProviderNameBedrock)
|
||||
require.True(t, ok, "bedrock parser is registered")
|
||||
require.Equal(t, ProviderNameBedrock, p.ProviderName())
|
||||
}
|
||||
31
proxy/internal/llm/errors.go
Normal file
31
proxy/internal/llm/errors.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package llm
|
||||
|
||||
import "errors"
|
||||
|
||||
// Sentinel errors returned by parsers and the pricing loader. Callers use
|
||||
// errors.Is to branch on a condition without coupling to parser internals.
|
||||
var (
|
||||
// ErrUnknownProvider indicates no parser claimed the request path.
|
||||
ErrUnknownProvider = errors.New("llmobs: unknown provider")
|
||||
|
||||
// ErrUnsupportedModel indicates the response parsed successfully but the
|
||||
// model is absent from the pricing table. Token counts are still valid.
|
||||
ErrUnsupportedModel = errors.New("llmobs: unsupported model")
|
||||
|
||||
// ErrNotLLMResponse indicates the response is not a JSON success body
|
||||
// that a non-streaming parser can consume (non-200 or wrong content type).
|
||||
ErrNotLLMResponse = errors.New("llmobs: not an LLM response")
|
||||
|
||||
// ErrStreamingUnsupported indicates the caller passed an SSE response to
|
||||
// a non-streaming parser. Streaming is handled separately via the SSE
|
||||
// scanner.
|
||||
ErrStreamingUnsupported = errors.New("llmobs: streaming response requires SSE scanner")
|
||||
|
||||
// ErrMalformedResponse indicates the response body could not be decoded
|
||||
// as the provider-specific JSON schema.
|
||||
ErrMalformedResponse = errors.New("llmobs: malformed response body")
|
||||
|
||||
// ErrMalformedRequest indicates the request body could not be decoded as
|
||||
// the provider-specific JSON schema.
|
||||
ErrMalformedRequest = errors.New("llmobs: malformed request body")
|
||||
)
|
||||
17
proxy/internal/llm/fixtures/anthropic_messages.json
Normal file
17
proxy/internal/llm/fixtures/anthropic_messages.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"id": "msg_abc",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-sonnet-4-5",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Hello, world!"
|
||||
}
|
||||
],
|
||||
"stop_reason": "end_turn",
|
||||
"usage": {
|
||||
"input_tokens": 123,
|
||||
"output_tokens": 45
|
||||
}
|
||||
}
|
||||
21
proxy/internal/llm/fixtures/anthropic_stream.txt
Normal file
21
proxy/internal/llm/fixtures/anthropic_stream.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
event: message_start
|
||||
data: {"type":"message_start","message":{"id":"msg_abc","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[],"stop_reason":null,"usage":{"input_tokens":123,"output_tokens":1}}}
|
||||
|
||||
event: content_block_start
|
||||
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":", world!"}}
|
||||
|
||||
event: content_block_stop
|
||||
data: {"type":"content_block_stop","index":0}
|
||||
|
||||
event: message_delta
|
||||
data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":45}}
|
||||
|
||||
event: message_stop
|
||||
data: {"type":"message_stop"}
|
||||
|
||||
21
proxy/internal/llm/fixtures/openai_chat_completion.json
Normal file
21
proxy/internal/llm/fixtures/openai_chat_completion.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"id": "chatcmpl-abc",
|
||||
"object": "chat.completion",
|
||||
"created": 1700000000,
|
||||
"model": "gpt-4o-mini",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello, world!"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 123,
|
||||
"completion_tokens": 45,
|
||||
"total_tokens": 168
|
||||
}
|
||||
}
|
||||
24
proxy/internal/llm/fixtures/openai_responses.json
Normal file
24
proxy/internal/llm/fixtures/openai_responses.json
Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"id": "resp_abc",
|
||||
"object": "response",
|
||||
"created_at": 1700000000,
|
||||
"model": "gpt-5.4",
|
||||
"output": [
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "ok"}]
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"input_tokens": 15,
|
||||
"input_tokens_details": {
|
||||
"cached_tokens": 0
|
||||
},
|
||||
"output_tokens": 414,
|
||||
"output_tokens_details": {
|
||||
"reasoning_tokens": 0
|
||||
},
|
||||
"total_tokens": 429
|
||||
}
|
||||
}
|
||||
24
proxy/internal/llm/fixtures/openai_responses_stream.txt
Normal file
24
proxy/internal/llm/fixtures/openai_responses_stream.txt
Normal file
@@ -0,0 +1,24 @@
|
||||
event: response.created
|
||||
data: {"type":"response.created","response":{"id":"resp_abc","object":"response","model":"gpt-5.5","usage":null}}
|
||||
|
||||
event: response.in_progress
|
||||
data: {"type":"response.in_progress","response":{"id":"resp_abc","usage":null}}
|
||||
|
||||
event: response.output_item.added
|
||||
data: {"type":"response.output_item.added","output_index":0,"item":{"type":"message","role":"assistant","content":[]}}
|
||||
|
||||
event: response.content_part.added
|
||||
data: {"type":"response.content_part.added","item_id":"msg_1","output_index":0,"content_index":0,"part":{"type":"output_text","text":""}}
|
||||
|
||||
event: response.output_text.delta
|
||||
data: {"type":"response.output_text.delta","item_id":"msg_1","output_index":0,"content_index":0,"delta":"Hello"}
|
||||
|
||||
event: response.output_text.delta
|
||||
data: {"type":"response.output_text.delta","item_id":"msg_1","output_index":0,"content_index":0,"delta":", world!"}
|
||||
|
||||
event: response.output_text.done
|
||||
data: {"type":"response.output_text.done","item_id":"msg_1","output_index":0,"content_index":0,"text":"Hello, world!"}
|
||||
|
||||
event: response.completed
|
||||
data: {"type":"response.completed","response":{"id":"resp_abc","object":"response","model":"gpt-5.5","usage":{"input_tokens":123,"input_tokens_details":{"cached_tokens":40},"output_tokens":45,"output_tokens_details":{"reasoning_tokens":12},"total_tokens":168}}}
|
||||
|
||||
8
proxy/internal/llm/fixtures/openai_stream.txt
Normal file
8
proxy/internal/llm/fixtures/openai_stream.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
data: {"id":"chatcmpl-abc","object":"chat.completion.chunk","created":1700000000,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-abc","object":"chat.completion.chunk","created":1700000000,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"content":", world!"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-abc","object":"chat.completion.chunk","created":1700000000,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":123,"completion_tokens":45,"total_tokens":168}}
|
||||
|
||||
data: [DONE]
|
||||
|
||||
59
proxy/internal/llm/fixtures/pricing.yaml
Normal file
59
proxy/internal/llm/fixtures/pricing.yaml
Normal file
@@ -0,0 +1,59 @@
|
||||
# Realistic-pricing starter for llm_observability. Drop this into the
|
||||
# directory you point the proxy at via --plugin-data-dir, then reference it
|
||||
# from the target's plugin config:
|
||||
#
|
||||
# plugins:
|
||||
# - id: llm_observability
|
||||
# enabled: true
|
||||
# params:
|
||||
# pricing_path: pricing.yaml
|
||||
#
|
||||
# Values are USD per 1_000 tokens. Public list prices drift; treat this as a
|
||||
# starting point and keep your production copy current.
|
||||
|
||||
openai:
|
||||
# GPT-5 family
|
||||
gpt-5:
|
||||
input_per_1k: 0.00125
|
||||
output_per_1k: 0.01
|
||||
gpt-5-mini:
|
||||
input_per_1k: 0.00025
|
||||
output_per_1k: 0.002
|
||||
gpt-5-nano:
|
||||
input_per_1k: 0.00005
|
||||
output_per_1k: 0.0004
|
||||
gpt-5.4:
|
||||
input_per_1k: 0.00125
|
||||
output_per_1k: 0.01
|
||||
# GPT-4o family
|
||||
gpt-4o:
|
||||
input_per_1k: 0.0025
|
||||
output_per_1k: 0.01
|
||||
gpt-4o-mini:
|
||||
input_per_1k: 0.00015
|
||||
output_per_1k: 0.0006
|
||||
# Embeddings
|
||||
text-embedding-3-large:
|
||||
input_per_1k: 0.00013
|
||||
output_per_1k: 0
|
||||
text-embedding-3-small:
|
||||
input_per_1k: 0.00002
|
||||
output_per_1k: 0
|
||||
|
||||
anthropic:
|
||||
# Claude 4.x family
|
||||
claude-opus-4-7:
|
||||
input_per_1k: 0.015
|
||||
output_per_1k: 0.075
|
||||
claude-sonnet-4-7:
|
||||
input_per_1k: 0.003
|
||||
output_per_1k: 0.015
|
||||
claude-sonnet-4-6:
|
||||
input_per_1k: 0.003
|
||||
output_per_1k: 0.015
|
||||
claude-sonnet-4-5:
|
||||
input_per_1k: 0.003
|
||||
output_per_1k: 0.015
|
||||
claude-haiku-4-5:
|
||||
input_per_1k: 0.0008
|
||||
output_per_1k: 0.004
|
||||
412
proxy/internal/llm/openai.go
Normal file
412
proxy/internal/llm/openai.go
Normal file
@@ -0,0 +1,412 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OpenAIParser implements the Parser interface for OpenAI-compatible APIs.
|
||||
// It recognizes chat.completions, completions, embeddings, and the newer
|
||||
// responses endpoint; any proxy path-prefix stripping is tolerated by the
|
||||
// substring match in DetectFromURL.
|
||||
type OpenAIParser struct{}
|
||||
|
||||
// openAIPathHints are substring patterns that mark a request as
|
||||
// OpenAI-shaped. The bare `/chat/completions` is listed alongside
|
||||
// `/v1/chat/completions` because gateways like Cloudflare AI
|
||||
// Gateway place their own version segment before the provider
|
||||
// slug (gateway/v1/{account}/{gateway}/openai/chat/completions) —
|
||||
// the canonical `/v1/` ends up nowhere near `/chat/completions`,
|
||||
// so the `/v1/chat/completions` hint misses. `/chat/completions`
|
||||
// is OpenAI's API contract: any service accepting OpenAI bodies
|
||||
// serves at this path, so false-positive risk is negligible.
|
||||
// `/completions` (legacy), `/embeddings`, and `/responses` are
|
||||
// kept on the canonical-only path because their bare forms are
|
||||
// too generic to be safe substrings.
|
||||
var openAIPathHints = []string{
|
||||
"/v1/chat/completions",
|
||||
"/v1/completions",
|
||||
"/v1/embeddings",
|
||||
"/v1/responses",
|
||||
"/chat/completions",
|
||||
}
|
||||
|
||||
// Provider returns ProviderOpenAI.
|
||||
func (OpenAIParser) Provider() Provider { return ProviderOpenAI }
|
||||
|
||||
// ProviderName returns the stable label used for metrics and metadata.
|
||||
func (OpenAIParser) ProviderName() string { return "openai" }
|
||||
|
||||
// DetectFromURL reports whether the given request path looks like an OpenAI
|
||||
// API endpoint. The match is case-insensitive and substring-based so that a
|
||||
// reverse proxy prefix strip or rewrite does not defeat detection.
|
||||
func (OpenAIParser) DetectFromURL(path string) bool {
|
||||
lower := strings.ToLower(path)
|
||||
for _, hint := range openAIPathHints {
|
||||
if strings.Contains(lower, hint) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type openAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Stream *bool `json:"stream"`
|
||||
StreamOptions *struct {
|
||||
IncludeUsage *bool `json:"include_usage"`
|
||||
} `json:"stream_options"`
|
||||
// Chat Completions / Completions: messages[].content (string or array of
|
||||
// content parts). Responses API: input is either a string or an array of
|
||||
// items with content parts. We use json.RawMessage to defer parsing each
|
||||
// shape independently.
|
||||
Messages []openAIMessage `json:"messages"`
|
||||
Prompt json.RawMessage `json:"prompt"`
|
||||
Input json.RawMessage `json:"input"`
|
||||
}
|
||||
|
||||
type openAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
// ParseRequest extracts the model name and streaming flag from an OpenAI
|
||||
// request body. Unknown or missing fields leave the corresponding struct
|
||||
// members zero-valued.
|
||||
func (OpenAIParser) ParseRequest(body []byte) (RequestFacts, error) {
|
||||
var req openAIRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return RequestFacts{}, fmt.Errorf("decode openai request: %w: %v", ErrMalformedRequest, err)
|
||||
}
|
||||
return RequestFacts{
|
||||
Model: req.Model,
|
||||
Stream: ptrDeref(req.Stream),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// openAIResponse accepts both naming conventions in a single struct because
|
||||
// OpenAI's older Chat Completions API uses prompt_tokens/completion_tokens
|
||||
// while the newer Responses API (/v1/responses) uses input_tokens/output_tokens
|
||||
// (aligned with Anthropic). Pointer fields let us tell "absent" from "zero".
|
||||
//
|
||||
// PromptTokensDetails.CachedTokens (Chat Completions) and
|
||||
// InputTokensDetails.CachedTokens (Responses API) carry the SUBSET of
|
||||
// prompt/input tokens that hit the prompt cache. Cost-meter applies the
|
||||
// discount rate to that subset and the regular rate to the remainder so
|
||||
// we never double-bill the cached portion.
|
||||
type openAIResponse struct {
|
||||
Usage struct {
|
||||
PromptTokens *int64 `json:"prompt_tokens"`
|
||||
CompletionTokens *int64 `json:"completion_tokens"`
|
||||
InputTokens *int64 `json:"input_tokens"`
|
||||
OutputTokens *int64 `json:"output_tokens"`
|
||||
TotalTokens *int64 `json:"total_tokens"`
|
||||
PromptTokensDetails *struct {
|
||||
CachedTokens *int64 `json:"cached_tokens"`
|
||||
} `json:"prompt_tokens_details"`
|
||||
InputTokensDetails *struct {
|
||||
CachedTokens *int64 `json:"cached_tokens"`
|
||||
} `json:"input_tokens_details"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
// ParseResponse decodes the non-streaming OpenAI response envelope. Status
|
||||
// codes other than 200 are treated as non-LLM responses so the caller can
|
||||
// skip cost accounting without aborting the request.
|
||||
func (OpenAIParser) ParseResponse(status int, contentType string, body []byte) (Usage, error) {
|
||||
if status != 200 {
|
||||
return Usage{}, fmt.Errorf("openai status %d: %w", status, ErrNotLLMResponse)
|
||||
}
|
||||
if isEventStream(contentType) {
|
||||
return Usage{}, ErrStreamingUnsupported
|
||||
}
|
||||
if !isJSON(contentType) {
|
||||
return Usage{}, fmt.Errorf("openai content-type %q: %w", contentType, ErrNotLLMResponse)
|
||||
}
|
||||
|
||||
var resp openAIResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return Usage{}, fmt.Errorf("decode openai response: %w: %v", ErrMalformedResponse, err)
|
||||
}
|
||||
|
||||
// Responses-API names take precedence when present; fall back to the older
|
||||
// Chat Completions names. This handles both endpoints transparently
|
||||
// without forcing a per-route configuration.
|
||||
u := Usage{
|
||||
InputTokens: pickInt64(resp.Usage.InputTokens, resp.Usage.PromptTokens),
|
||||
OutputTokens: pickInt64(resp.Usage.OutputTokens, resp.Usage.CompletionTokens),
|
||||
TotalTokens: derefInt64(resp.Usage.TotalTokens),
|
||||
CachedInputTokens: openAICachedTokens(resp),
|
||||
}
|
||||
if u.TotalTokens == 0 && (u.InputTokens > 0 || u.OutputTokens > 0) {
|
||||
u.TotalTokens = u.InputTokens + u.OutputTokens
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// openAICachedTokens returns the cached-prompt subset reported by
|
||||
// either the Responses-API (input_tokens_details.cached_tokens) or
|
||||
// the Chat-Completions API (prompt_tokens_details.cached_tokens).
|
||||
// Responses-API takes precedence when both are populated.
|
||||
func openAICachedTokens(resp openAIResponse) int64 {
|
||||
// Responses-API details are authoritative when present: an explicit
|
||||
// cached_tokens of 0 must be honored, not treated as missing and
|
||||
// overridden by the Chat-Completions field (which would overstate cache).
|
||||
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens != nil {
|
||||
return derefInt64(resp.Usage.InputTokensDetails.CachedTokens)
|
||||
}
|
||||
if resp.Usage.PromptTokensDetails != nil {
|
||||
return derefInt64(resp.Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// ExtractPrompt returns the user-visible prompt text from an OpenAI request.
|
||||
// Handles chat.completions (messages[].content), legacy completions (prompt
|
||||
// string), and the Responses API (input as string or content-part array).
|
||||
// Returns "" when nothing extractable is found.
|
||||
func (OpenAIParser) ExtractPrompt(body []byte) string {
|
||||
var req openAIRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return ""
|
||||
}
|
||||
if len(req.Messages) > 0 {
|
||||
return joinMessages(req.Messages)
|
||||
}
|
||||
if len(req.Input) > 0 {
|
||||
return extractResponsesInput(req.Input)
|
||||
}
|
||||
if len(req.Prompt) > 0 {
|
||||
return decodeStringOrJoin(req.Prompt)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractResponsesInput handles the Responses API `input` field. It is one
|
||||
// of three shapes: a plain string, an array of message items
|
||||
// ({role, content: string | [parts]}) as sent by Codex and the Responses
|
||||
// SDK, or a flat array of content parts ({type, text/input_text}). Message
|
||||
// items are flattened to "role: text" lines; items without extractable text
|
||||
// (reasoning blocks, tool calls) are skipped.
|
||||
func extractResponsesInput(raw json.RawMessage) string {
|
||||
if s, ok := tryDecodeString(raw); ok {
|
||||
return s
|
||||
}
|
||||
var items []struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
Text string `json:"text"`
|
||||
InputText string `json:"input_text"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &items); err != nil {
|
||||
return extractContentParts(raw)
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, it := range items {
|
||||
var text string
|
||||
switch {
|
||||
case len(it.Content) > 0:
|
||||
text = decodeStringOrJoin(it.Content)
|
||||
case it.Text != "":
|
||||
text = it.Text
|
||||
case it.InputText != "":
|
||||
text = it.InputText
|
||||
}
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
if it.Role != "" {
|
||||
b.WriteString(it.Role)
|
||||
b.WriteString(": ")
|
||||
}
|
||||
b.WriteString(text)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// ExtractSessionID reads the OpenAI session marker. Codex (the Responses
|
||||
// API client) stamps client_metadata.session_id on every request body;
|
||||
// plain chat.completions traffic carries no session id and yields "".
|
||||
func (OpenAIParser) ExtractSessionID(body []byte) string {
|
||||
var req struct {
|
||||
ClientMetadata struct {
|
||||
SessionID string `json:"session_id"`
|
||||
} `json:"client_metadata"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return ""
|
||||
}
|
||||
return req.ClientMetadata.SessionID
|
||||
}
|
||||
|
||||
type openAIChatChoice struct {
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
} `json:"message"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type openAIChatResponse struct {
|
||||
Choices []openAIChatChoice `json:"choices"`
|
||||
// Responses API: output[].content[].text
|
||||
Output []struct {
|
||||
Type string `json:"type"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
Text string `json:"text"`
|
||||
} `json:"output"`
|
||||
OutputText string `json:"output_text"`
|
||||
}
|
||||
|
||||
// ExtractCompletion returns the assistant text from a non-streaming OpenAI
|
||||
// response. Handles chat.completions (choices[].message.content), legacy
|
||||
// completions (choices[].text), and Responses API (output[].content[].text
|
||||
// or the convenience output_text field).
|
||||
func (OpenAIParser) ExtractCompletion(status int, contentType string, body []byte) string {
|
||||
if status != 200 || isEventStream(contentType) || !isJSON(contentType) {
|
||||
return ""
|
||||
}
|
||||
var resp openAIChatResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return ""
|
||||
}
|
||||
if resp.OutputText != "" {
|
||||
return resp.OutputText
|
||||
}
|
||||
for _, c := range resp.Choices {
|
||||
if len(c.Message.Content) > 0 {
|
||||
if s := decodeStringOrJoin(c.Message.Content); s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
if c.Text != "" {
|
||||
return c.Text
|
||||
}
|
||||
}
|
||||
for _, o := range resp.Output {
|
||||
if o.Text != "" {
|
||||
return o.Text
|
||||
}
|
||||
if len(o.Content) > 0 {
|
||||
if s := extractContentParts(o.Content); s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// joinMessages flattens a chat.completions messages array into a single
|
||||
// "role: content" string per message, separated by newlines. Roles surface
|
||||
// system/user/assistant context which is useful for log review.
|
||||
func joinMessages(msgs []openAIMessage) string {
|
||||
var b strings.Builder
|
||||
for i, m := range msgs {
|
||||
if i > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
if m.Role != "" {
|
||||
b.WriteString(m.Role)
|
||||
b.WriteString(": ")
|
||||
}
|
||||
b.WriteString(decodeStringOrJoin(m.Content))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// extractContentParts handles the Responses-API content shape, which is
|
||||
// either a single string or an array of {type, text} parts. text and
|
||||
// input_text both carry user-facing content.
|
||||
func extractContentParts(raw json.RawMessage) string {
|
||||
if s, ok := tryDecodeString(raw); ok {
|
||||
return s
|
||||
}
|
||||
var parts []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
InputText string `json:"input_text"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
// Last-ditch: array of strings.
|
||||
var arr []string
|
||||
if json.Unmarshal(raw, &arr) == nil {
|
||||
return strings.Join(arr, "\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, p := range parts {
|
||||
var text string
|
||||
switch {
|
||||
case p.Text != "":
|
||||
text = p.Text
|
||||
case p.InputText != "":
|
||||
text = p.InputText
|
||||
}
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString(text)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// decodeStringOrJoin accepts either a JSON string or a content-parts array
|
||||
// (chat.completions multimodal) and returns a flat string. Multimodal parts
|
||||
// are separated by newlines; non-text parts are skipped.
|
||||
func decodeStringOrJoin(raw json.RawMessage) string {
|
||||
if s, ok := tryDecodeString(raw); ok {
|
||||
return s
|
||||
}
|
||||
return extractContentParts(raw)
|
||||
}
|
||||
|
||||
func tryDecodeString(raw json.RawMessage) (string, bool) {
|
||||
if len(raw) == 0 {
|
||||
return "", false
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// pickInt64 returns the first non-nil pointer's value. Used to prefer one
|
||||
// naming convention while transparently falling back to another.
|
||||
func pickInt64(preferred, fallback *int64) int64 {
|
||||
if preferred != nil {
|
||||
return *preferred
|
||||
}
|
||||
return derefInt64(fallback)
|
||||
}
|
||||
|
||||
func derefInt64(v *int64) int64 {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
func ptrDeref(b *bool) bool {
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
return *b
|
||||
}
|
||||
|
||||
func isEventStream(contentType string) bool {
|
||||
return strings.Contains(strings.ToLower(contentType), "text/event-stream")
|
||||
}
|
||||
|
||||
func isJSON(contentType string) bool {
|
||||
lower := strings.ToLower(contentType)
|
||||
return strings.Contains(lower, "application/json") || strings.Contains(lower, "+json")
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user