From 3be90f06b25d2e7ff044263f39e01a13e854a65a Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 1 Jul 2026 12:31:46 +0300 Subject: [PATCH 01/19] [management] Add peer expiration reason to activity meta (#6619) --- management/server/account.go | 6 +++--- management/server/peer.go | 11 ++++++++++- management/server/user.go | 8 +++++--- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 34220ed3f..2c57c4637 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -689,7 +689,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID) - if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil { + if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers, peerExpirationSessionExpired); err != nil { log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID) return peerSchedulerRetryInterval, true } @@ -724,7 +724,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID) - if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil { + if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers, peerExpirationInactivity); err != nil { log.Errorf("failed updating account peers while expiring peers for account %s", accountID) return peerSchedulerRetryInterval, true } @@ -1949,7 +1949,7 @@ func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, account } } if len(peers) > 0 { - err := am.expireAndUpdatePeers(ctx, accountID, peers) + err := am.expireAndUpdatePeers(ctx, accountID, peers, peerExpirationValidationFailed) if err != nil { log.WithContext(ctx).Errorf("failed to expire and update invalidated peers for account %s: %v", accountID, err) return diff --git a/management/server/peer.go b/management/server/peer.go index 440e90044..32bf9feea 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -34,7 +34,16 @@ import ( "github.com/netbirdio/netbird/version" ) -const remoteJobsMinVer = "0.64.0" +type peerExpirationReason string + +const ( + remoteJobsMinVer = "0.64.0" + + peerExpirationSessionExpired peerExpirationReason = "session expiration" + peerExpirationInactivity peerExpirationReason = "inactivity timeout" + peerExpirationValidationFailed peerExpirationReason = "failed integration validation" + peerExpirationUserBlocked peerExpirationReason = "blocked owner account" +) // GetPeers returns peers visible to the user within an account. // Users with "peers:read" see all peers. Otherwise, users see only their own peers, or none if restricted by account settings. diff --git a/management/server/user.go b/management/server/user.go index 666d6d178..b4b9ebe01 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -675,7 +675,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if len(peersToExpire) > 0 { - if err := am.expireAndUpdatePeers(ctx, accountID, peersToExpire); err != nil { + if err := am.expireAndUpdatePeers(ctx, accountID, peersToExpire, peerExpirationUserBlocked); err != nil { log.WithContext(ctx).Errorf("failed update expired peers: %s", err) return nil, err } @@ -1118,7 +1118,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a } // expireAndUpdatePeers expires all peers of the given user and updates them in the account -func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { +func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer, reason peerExpirationReason) error { log.WithContext(ctx).Debugf("Expiring %d peers for account %s", len(peers), accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { @@ -1145,10 +1145,12 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if err := am.Store.SavePeerStatus(ctx, accountID, peer.ID, *peer.Status); err != nil { return err } + meta := peer.EventMeta(dnsDomain) + meta["reason"] = string(reason) am.StoreEvent( ctx, peer.UserID, peer.ID, accountID, - activity.PeerLoginExpired, peer.EventMeta(dnsDomain), + activity.PeerLoginExpired, meta, ) } From 92a66cdd202407e4114c64ebfea42b6ed50c5377 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 1 Jul 2026 12:45:14 +0200 Subject: [PATCH 02/19] [management,proxy,client] 0.74.0 version (#6563) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [management,proxy] Agent network: per-account LLM gateway (policy, metering, multi-provider) (#6555) * [agent-network] Shared proto, OpenAPI schema, and generated types * [agent-network] Management: store, manager, synthesizer, policy engine, provider catalog, HTTP/gRPC API Adds the account-scoped agent-network module: provider/policy/budget CRUD and store, the reverse-proxy service synthesizer, policy selection + limit enforcement, the provider catalog (incl. Vertex AI and AWS Bedrock entries), and the management HTTP + proxy gRPC surfaces. * [management] Fix agent-network proxy-peer fan-out on affected-peer recompute The affected-peers resolver loaded only persisted reverse-proxy services, but agent-network services are synthesized on demand and never persisted. As a result the embedded proxy peer was never folded into the affected set when a client's group changed, so the proxy received no network-map update for a newly authorised client and rejected its handshake until a full resync (restart). loadProxyServices now merges the synthesized agent-network services (injected via a registration hook to avoid an import cycle), so proxy peers learn newly authorised clients immediately. * [proxy] Reverse-proxy middleware framework, chain, and request plumbing The per-target middleware chain (slots, dispatcher, mutation gate, metadata merger), body capture, access-log terminal sink, and the proxy wiring that builds + runs chains for synthesized agent-network services. * [proxy] LLM parsers, pricing, and builtin middlewares (OpenAI, Anthropic, Vertex AI, AWS Bedrock) Request/response parsers and SSE/event-stream metering, the embedded pricing table, and the builtin middleware set: request parser, router, policy limit-check/record, cost meter, guardrail, identity inject, response parser. Includes the path-routed providers — Google Vertex AI (keyfile:: service-account OAuth minting) and AWS Bedrock (bearer auth, invoke/converse/streaming, optional /bedrock prefix) — plus the Models allowlist and unmeterable-publisher deny. * [proxy] IPv6 in-place apply and TCP accept-loop hardening on netstack listeners * [agent-network] End-to-end test suite, module docs, and deployment preset * [agent-network] Fix codespell typos and exclude false positives - labelgen word pool: vermillion -> vermilion, racoon -> raccoon. - codespell ignore list: add flate (Go compress/flate package), recordin (a test-local identifier), and unparseable (a valid alternative spelling used consistently across identifiers + a metadata-value constant). * [management] Set LastSeen on injected proxy peer in realstack test (MySQL strict-mode) The injected embedded proxy peer had a PeerStatus with a zero LastSeen, which serializes to '0000-00-00' and is rejected by MySQL in strict mode (SQLite tolerates it). Set LastSeen to a valid time so SaveAccount succeeds on both engines. * [agent-network] Remove e2e shell-script suite from this branch The end-to-end shell scripts under scripts/e2e/ are maintained in a separate testing suite and are not part of this change set. * [agent-network] Polish module docs: remove internal review scaffolding, fix links, verify diagrams Strip PR-review framing, commit references, absolute paths, and stale internal references from the agent-network module docs; fix broken relative links; verify all diagrams against the current architecture. Remove the internal AI-reviewer prompt file. * [management] Refine session expiration handling to support 3-state encoding for SSO deadlines * [agent-network] Relocate agentnetwork package to internals/modules Move management/server/agentnetwork (and its catalog/, labelgen/, types/ subpackages) to management/internals/modules/agentnetwork, alongside the reverse-proxy module, and rewrite all importers. Pure relocation: package names, the synthesizer + affectedpeers registration hook, and store access (shared store.Store) are unchanged, so no import cycle is introduced (affectedpeers still depends only on the agentnetwork/types leaf). * [agent-network] Co-locate HTTP handlers in the module (RegisterEndpoints) Move the agent-network HTTP handlers from server/http/handlers/agentnetwork into the module at internals/modules/agentnetwork/handlers (package handlers) and rename the entrypoint AddEndpoints -> RegisterEndpoints, matching the reverse-proxy module convention. Wiring in http/handler.go updated accordingly. * Update getting started to point to rc when agent network enabled * Add a reference to a commercial license * Fix docs localhost link * Fix docs localhost link * Add private services domain note * [management] Add agent-network telemetry metrics (#6561) Surface agent-network adoption and usage in the self-hosted metrics worker: distinct accounts, providers, policies, budget rules, accounts with log collection enabled, and aggregated input/output tokens plus cost. Tokens and cost are summed from agent_network_request_usage (the always-written per-request ledger) so the figures are accurate regardless of the log-collection toggle and carry no double-counting. All values come from a handful of indexed aggregate queries run only on the worker's periodic tick. Adds store.AgentNetworkMetrics with GetAgentNetworkMetrics on the Store interface, the SqlStore implementation, and a zero-valued FileStore stub. * Update NetBird server and proxy image versions to 0.74.0-rc.2 * [management,proxy] Reduce agent-network cognitive complexity (#6566) Address the SonarCloud quality-gate findings in new agent-network code by extracting focused helpers. No behavior change. - synthesizer.go: split buildIdentityInjectConfigJSON into per-shape rule builders; extract mergeGuardrail from mergeGuardrails to cut nesting depth. - llm_identity_inject: extract injectionEmitsAnything validation predicate from New. - llm_response_parser/streaming.go: extract applyOpenAIStreamUsage and applyAnthropicStreamUsage (via a named anthropicStreamUsage type) and simplify the OpenAI scanner loop. - reverseproxy.go: decompose ServeHTTP into serveRouteError, buildTargetContext, serveDirect, serveWithChain, captureRequestForChain, serveDeny, newResponseWriter, observeResponse, and forwardUpstream, preserving the defer ordering so response observation still reads the captured writer before it is released. * [management] Move agent-network access-log ingest into the agentnetwork module (#6568) The agent-network access-log ingest path (metaKey wire contract, flatten, usage derivation, and the dual-write of the usage ledger + settings-gated full row) lived in the reverseproxy accesslogs manager, even though the agentnetwork module already owns the rest of that domain — types, read (ListAccessLogs / GetUsageOverview), the budget-counter writes, and retention cleanup. Move it next to the rest: a stateless agentnetwork.IngestAccessLog(ctx, store, entry) that the reverseproxy SaveAccessLog delegates to when the entry is agent-network. Removes the agentNetworkTypes import from the reverseproxy manager. No behavior change; the write/read table separation is unchanged. Adds real-store coverage for the disable->enable log-collection toggle (usage ledger always written, full row gated) plus the metadata parse and group-dedup helpers, which previously had no dedicated tests. * Add session view support in the access log * [management,proxy] Container-based agent-network e2e harness (#6577) * [e2e] Add container-based agent-network e2e harness (Pillar 1) Introduce a self-contained, OIDC-free e2e harness that stands up NetBird in containers, so suites no longer depend on the hand-maintained Tilt stack or a real IdP. - harness brings up the combined server (management + signal + relay + STUN + embedded IdP) in a single container built from combined/Dockerfile.multistage, and mints an admin PAT through the unauthenticated /api/setup bootstrap (NB_SETUP_PAT_ENABLED). API access goes through the existing shared/management/client/rest typed client. - the image is built via the docker CLI (BuildKit) so the Dockerfile's cache mounts are honored; testcontainers then runs the tagged image. - everything is behind the `e2e` build tag so normal builds and unit tests never pull in testcontainers. Adds BuildKit cache mounts to combined/Dockerfile.multistage so source changes recompile incrementally rather than from scratch. Pillar 1 proven by TestCombinedBootstrap: server builds, boots, mints a PAT, and the PAT authenticates a real management API call. * [e2e] Add management-side agent-network scenarios (Pillar 2) Port the API-driven agent-network scenarios from the bash suites to Go, sharing one combined server per package run (TestMain) with each test owning its resource cleanup. Drives the /api/agent-network/* endpoints through the shared REST client's NewRequest primitive with the generated api types. Scenarios: - provider lifecycle (create/get/list/delete + 404 after delete) - provider validation (missing api_key, unknown catalog id → 4xx) - settings collection-toggle round-trip with cluster/subdomain immutability - policy window floor (reject <60s enabled limit, accept at 60s) - consumption read endpoint returns an array All deterministic and dependency-free (dummy provider keys; no upstream calls), so they run headless in CI. * [e2e] Add live chat-through-proxy scenario (Pillar 3) Stand up the full agent-network data path in containers and drive a real chat-completion through the gateway: - harness: a shared docker network (combined server reachable by alias), a proxy container built from the published reverse-proxy image (NB_PROXY_PRIVATE, NB_PROXY_ALLOW_INSECURE, NB_RELAY_TRANSPORT=ws to match the combined server's WS-multiplexed relay) with a generated self-signed wildcard cert, and a netbird client container that joins via a setup key. - the combined image, proxy image, and client image default to the published rc.2 releases (overridable via NB_E2E_*_IMAGE; a bare local tag is built from source instead). Geolocation download is disabled so the server starts without external fetches. - one shared domain is used for the management exposed address, the proxy domain, and the agent-network cluster; the proxy token is minted via the server CLI (global) to match the manual install. TestChatCompletionThroughProxy provisions provider+policy+group+setup key, runs proxy+client, drives an OpenAI chat-completion through the tunnel, and asserts a 200 plus the ingested access-log row. Requires OPENAI_TOKEN (skips otherwise). The provider must be created with enabled=true explicitly — the create default is false despite the API doc. * [e2e] Run the live chat scenario across a provider matrix Replace the single-provider chat test with a data-driven matrix that runs the same scenario through every provider whose credentials are present in the environment (keys/URLs sourced from ~/.llm-keys locally, Actions secrets in CI): - OpenAI (chat), Anthropic (messages), Vercel, OpenRouter, Cloudflare (OpenAI-compatible gateways), and Bedrock (path-routed, bearer, via the messages shape) — covering both wire shapes and the gateway routing. - all providers are created enabled with a unique model string so the proxy's connect-time snapshot carries them all and model->provider routing is unambiguous (provider toggles after connect don't reconcile to a connected proxy). - the client supports both wire shapes (/v1/chat/completions and /v1/messages); Cloudflare gets the openai provider segment appended to its gateway URL. Each provider must return 200 through the tunnel and produce an ingested access-log row. Vertex is intentionally excluded from the uniform matrix: it needs a bespoke rawPredict request shape rather than the shared chat/messages path, so it warrants a dedicated scenario. * [ci] Add manual workflow for the agent-network e2e suite The e2e suite (build tag `e2e`) stands up the combined server + proxy + client in Docker and drives live chat-completions, so it is slow and needs provider credentials. Gate it out of normal CI (it already is, via the build tag) and run it on demand via workflow_dispatch. Provider scenarios skip when their secret is unset, so it degrades gracefully. * [e2e] Add Vertex to the provider matrix; run e2e on ubuntu-latest Vertex (Anthropic-on-Vertex) doesn't share the chat/messages wire shapes: the model travels in a rawPredict path and the proxy mints the service account's OAuth token. Add a Vertex client method that posts /v1/projects//locations//publishers/anthropic/models/:rawPredict with the Vertex anthropic_version body, and wire it into the matrix as a path-routed provider (created without a models array). It is keyed off GOOGLE_VERTEX_SA_BASE64 + GOOGLE_VERTEX_PROJECT (region defaults to "global", model to a pinned claude snapshot, both overridable). Also bump the e2e workflow runner to ubuntu-latest and add the Vertex secrets. * Add docker/docker and docker/go-connections as direct dependencies in go.mod * [ci] Trigger agent-network e2e workflow on push to main and pull requests * [e2e] Fix proxy cert permission denied on Linux CI runners The proxy bind-mounts a temp dir of self-signed certs. MkdirTemp creates it 0700 and the key was 0600, which Docker Desktop on macOS ignores but a non-root proxy container on Linux runners cannot traverse/read, so the cert watcher failed with "open /certs/tls.crt: permission denied" and the container exited. Widen the cert dir to 0755 and write the throwaway key 0644 so the proxy uid can read the bind-mounted material. * [e2e] Build images from source by default instead of pulling rc.2 The agent-network code under test lives in this branch, so the e2e should exercise it rather than a frozen published release. Flip the harness default: combined/proxy/client are now built from their in-repo Dockerfiles (combined/Dockerfile.multistage, proxy/Dockerfile.multistage, e2e/harness/Dockerfile.client) under local tags. Pulling a published image stays available by setting NB_E2E_*_IMAGE to a registry reference. Builds now go through buildx --load so the Dockerfile cache mounts are honored and the result is loaded for testcontainers. The CI workflow adds a container-driver builder and a local layer cache (NB_E2E_BUILDX_CACHE) persisted via actions/cache, which caches the base/apt/dep-download layers across runs. The Go compile still re-runs each time, as BuildKit mount caches cannot be exported to the GitHub cache. * [e2e] Cover real providers in lifecycle + assert real consumption metering - TestProviderLifecycle now runs per available real provider (create → get → list → delete → 404) instead of a single dummy provider, exercising each catalog's create and field round-trip. Create is offline, so it stays fast and burns no provider quota; falls back to a synthetic OpenAI provider when no keys are set. - TestProvidersMatrix attaches a token limit (high caps, 60s window) to its policy, which switches on usage metering, and asserts consumption rows are recorded with positive token counts after the live traffic. Consumption is account-scoped (keyed by source group / user and window, not per provider), so the assertion is aggregate. - TestProviderValidation gains invalid-upstream and blank-name cases. Create validation is uniform across catalogs (no per-provider required-field rules), so per-provider rejection cases would be redundant. * [e2e] Assert session id propagates per provider Each matrix request now sends a unique session id as the universal x-session-id header and asserts it round-trips into that provider's access-log row. This guards the session-grouping contract end to end for every provider (header extraction runs in llm_request_parser ahead of the parser-specific body extraction, so it is provider-agnostic). * [e2e] Drop accidentally committed sync-phases dashboard netbird-sync-phases.json was swept into the Pillar 1 commit by a broad git add; it belongs to the unrelated sync-phases metrics work, not this e2e harness. Remove it from the branch so the PR diff is scoped to the e2e changes. * [e2e] Revert accidentally committed sync-phase ingest spec The netbird_sync_phase measurement spec in metrics ingest was swept into the Pillar 1 commit; it belongs to the unrelated sync-phases metrics work, not this e2e harness. Its emission side never landed here, so the spec was orphaned anyway. Restore ingest/main.go to its origin/main state. * Fix golint issues * Fix sonar * Add access log session test * Fix access log tests --------- Co-authored-by: braginini Co-authored-by: Zoltan Papp --- .github/workflows/agent-network-e2e.yml | 69 + .github/workflows/golangci-lint.yml | 2 +- combined/Dockerfile.multistage | 10 +- docs/agent-networks/00-overview.md | 109 + docs/agent-networks/01-end-to-end-flows.md | 217 ++ docs/agent-networks/README.md | 66 + docs/agent-networks/modules/10-shared-api.md | 105 + .../modules/20-management-store.md | 112 + .../modules/21-management-agentnetwork.md | 225 ++ .../modules/22-management-handlers-wiring.md | 203 ++ .../modules/30-proxy-middleware-framework.md | 215 ++ .../modules/31-proxy-middleware-builtin.md | 365 +++ .../modules/32-proxy-llm-parsers.md | 392 ++++ .../modules/33-proxy-runtime.md | 194 ++ docs/agent-networks/modules/40-dashboard.md | 228 ++ .../modules/50-path-routed-providers.md | 251 ++ e2e/agentnetwork/bootstrap_test.go | 30 + e2e/agentnetwork/chat_test.go | 281 +++ e2e/agentnetwork/main_test.go | 46 + e2e/agentnetwork/management_test.go | 221 ++ e2e/harness/Dockerfile.client | 24 + e2e/harness/agentnetwork.go | 130 ++ e2e/harness/bootstrap.go | 47 + e2e/harness/cert.go | 66 + e2e/harness/client.go | 256 ++ e2e/harness/combined.go | 243 ++ e2e/harness/config.go | 26 + e2e/harness/doc.go | 13 + e2e/harness/paths.go | 29 + e2e/harness/proxy.go | 122 + go.mod | 6 +- .../network_map/controller/controller.go | 36 +- .../network_map/controller/repository.go | 10 + .../modules/agentnetwork/accesslog_ingest.go | 215 ++ .../accesslog_ingest_realstore_test.go | 124 + .../accesslog_sessions_realstore_test.go | 343 +++ .../agentnetwork/affectedpeers_hook.go | 15 + .../modules/agentnetwork/catalog/catalog.go | 749 ++++++ .../handlers/access_log_handler.go | 134 ++ .../agentnetwork/handlers/budget_handler.go | 172 ++ .../handlers/budget_handler_test.go | 131 ++ .../handlers/consumption_handler.go | 53 + .../handlers/guardrails_handler.go | 171 ++ .../agentnetwork/handlers/handlers_test.go | 256 ++ .../agentnetwork/handlers/policies_handler.go | 228 ++ .../handlers/providers_handler.go | 217 ++ .../agentnetwork/handlers/settings_handler.go | 74 + .../modules/agentnetwork/labelgen/labelgen.go | 66 + .../agentnetwork/labelgen/labelgen_test.go | 101 + .../modules/agentnetwork/labelgen/words.go | 136 ++ .../internals/modules/agentnetwork/manager.go | 911 ++++++++ .../modules/agentnetwork/policyselect.go | 660 ++++++ .../policyselect_account_realstore_test.go | 181 ++ .../policyselect_realstore_test.go | 214 ++ .../modules/agentnetwork/policyselect_test.go | 641 +++++ .../modules/agentnetwork/reconcile.go | 131 ++ .../modules/agentnetwork/reconcile_test.go | 232 ++ .../modules/agentnetwork/synthesizer.go | 1083 +++++++++ .../synthesizer_guardrail_realstore_test.go | 178 ++ ...nthesizer_log_collection_realstore_test.go | 70 + ...ynthesizer_parser_redact_realstore_test.go | 145 ++ .../synthesizer_realstore_test.go | 174 ++ .../modules/agentnetwork/synthesizer_test.go | 1098 +++++++++ .../modules/agentnetwork/types/accesslog.go | 289 +++ .../agentnetwork/types/accesslogfilter.go | 249 ++ .../modules/agentnetwork/types/budgetrule.go | 106 + .../modules/agentnetwork/types/consumption.go | 69 + .../agentnetwork/types/consumption_test.go | 141 ++ .../modules/agentnetwork/types/guardrail.go | 120 + .../modules/agentnetwork/types/policy.go | 192 ++ .../modules/agentnetwork/types/provider.go | 252 ++ .../modules/agentnetwork/types/settings.go | 78 + .../modules/agentnetwork/types/usage.go | 47 + .../agentnetwork/types/usageoverview.go | 96 + .../modules/agentnetwork/wire_shape_test.go | 109 + management/internals/modules/peers/manager.go | 56 +- .../reverseproxy/accesslogs/accesslogentry.go | 5 + .../accesslogs/manager/manager.go | 9 +- .../modules/reverseproxy/service/service.go | 109 +- management/internals/server/boot.go | 28 +- management/internals/server/modules.go | 19 + management/internals/shared/grpc/proxy.go | 200 +- management/server/activity/codes.go | 49 + .../server/affectedpeers/proxy_synth_test.go | 95 + management/server/affectedpeers/resolver.go | 37 +- .../agentnetwork_budgetrule_realstack_test.go | 126 + .../agentnetwork_proxypeer_restart_test.go | 199 ++ .../server/agentnetwork_realstack_test.go | 212 ++ management/server/http/handler.go | 7 +- .../testing/testing_tools/channel/channel.go | 4 +- management/server/metrics/selfhosted.go | 16 + management/server/metrics/selfhosted_test.go | 15 + .../server/permissions/modules/module.go | 2 + management/server/store/file_store.go | 6 + management/server/store/sql_store.go | 339 +++ .../server/store/sql_store_agentnetwork.go | 664 ++++++ .../sql_store_agentnetwork_accesslog_test.go | 302 +++ .../sql_store_agentnetwork_budgetrule_test.go | 112 + management/server/store/store.go | 66 + .../server/store/store_mock_agentnetwork.go | 495 ++++ proxy/inbound.go | 9 +- proxy/inbound_test.go | 122 + proxy/internal/accesslog/logger.go | 50 + proxy/internal/accesslog/middleware.go | 14 +- proxy/internal/accesslog/middleware_test.go | 185 ++ proxy/internal/auth/middleware_test.go | 119 +- proxy/internal/llm/anthropic.go | 196 ++ proxy/internal/llm/anthropic_test.go | 169 ++ proxy/internal/llm/bedrock.go | 189 ++ proxy/internal/llm/bedrock_test.go | 65 + proxy/internal/llm/errors.go | 31 + .../llm/fixtures/anthropic_messages.json | 17 + .../llm/fixtures/anthropic_stream.txt | 21 + .../llm/fixtures/openai_chat_completion.json | 21 + .../llm/fixtures/openai_responses.json | 24 + .../llm/fixtures/openai_responses_stream.txt | 24 + proxy/internal/llm/fixtures/openai_stream.txt | 8 + proxy/internal/llm/fixtures/pricing.yaml | 59 + proxy/internal/llm/openai.go | 412 ++++ proxy/internal/llm/openai_test.go | 255 ++ proxy/internal/llm/parser.go | 112 + proxy/internal/llm/parser_test.go | 54 + .../llm/pricing/defaults_coverage_test.go | 65 + .../llm/pricing/defaults_pricing.yaml | 264 +++ proxy/internal/llm/pricing/pricing.go | 449 ++++ proxy/internal/llm/pricing/pricing_other.go | 20 + proxy/internal/llm/pricing/pricing_test.go | 432 ++++ proxy/internal/llm/pricing/pricing_unix.go | 68 + proxy/internal/llm/sse.go | 117 + proxy/internal/llm/sse_test.go | 175 ++ proxy/internal/metrics/metrics.go | 9 + proxy/internal/middleware/bodypolicy.go | 63 + proxy/internal/middleware/bodytap/request.go | 344 +++ proxy/internal/middleware/bodytap/response.go | 189 ++ .../middleware/bodytap/routing_scan_test.go | 86 + .../agentnetwork_chain_integration_test.go | 318 +++ proxy/internal/middleware/builtin/all_test.go | 40 + proxy/internal/middleware/builtin/builtin.go | 93 + .../middleware/builtin/cost_meter/factory.go | 88 + .../builtin/cost_meter/middleware.go | 193 ++ .../builtin/cost_meter/middleware_test.go | 459 ++++ .../builtin/llm_guardrail/factory.go | 82 + .../builtin/llm_guardrail/middleware.go | 183 ++ .../builtin/llm_guardrail/middleware_test.go | 219 ++ .../builtin/llm_guardrail/redact.go | 75 + .../builtin/llm_guardrail/redact_test.go | 217 ++ .../builtin/llm_identity_inject/factory.go | 108 + .../builtin/llm_identity_inject/middleware.go | 439 ++++ .../llm_identity_inject/middleware_test.go | 666 ++++++ .../builtin/llm_limit_check/factory.go | 38 + .../builtin/llm_limit_check/middleware.go | 196 ++ .../llm_limit_check/middleware_test.go | 186 ++ .../builtin/llm_limit_record/factory.go | 35 + .../builtin/llm_limit_record/middleware.go | 144 ++ .../llm_limit_record/middleware_test.go | 191 ++ .../llm_request_parser/bedrock_test.go | 55 + .../builtin/llm_request_parser/factory.go | 71 + .../builtin/llm_request_parser/middleware.go | 453 ++++ .../llm_request_parser/middleware_test.go | 418 ++++ .../builtin/llm_response_parser/factory.go | 43 + .../builtin/llm_response_parser/gzip_test.go | 133 ++ .../builtin/llm_response_parser/middleware.go | 339 +++ .../llm_response_parser/middleware_test.go | 433 ++++ .../responses_stream_test.go | 69 + .../builtin/llm_response_parser/streaming.go | 270 +++ .../llm_response_parser/streaming_bedrock.go | 110 + .../streaming_bedrock_test.go | 74 + .../llm_response_parser/streaming_test.go | 169 ++ .../middleware/builtin/llm_router/factory.go | 106 + .../builtin/llm_router/middleware.go | 793 +++++++ .../builtin/llm_router/middleware_test.go | 840 +++++++ .../builtin/llm_router/path_routed_test.go | 159 ++ proxy/internal/middleware/chain.go | 320 +++ proxy/internal/middleware/chain_test.go | 370 +++ proxy/internal/middleware/decision.go | 81 + proxy/internal/middleware/dispatcher.go | 189 ++ proxy/internal/middleware/headerpolicy.go | 99 + proxy/internal/middleware/keys.go | 86 + proxy/internal/middleware/manager.go | 412 ++++ proxy/internal/middleware/metadata.go | 99 + proxy/internal/middleware/metrics.go | 171 ++ proxy/internal/middleware/middleware.go | 47 + proxy/internal/middleware/redaction.go | 79 + proxy/internal/middleware/registry.go | 121 + proxy/internal/middleware/spec.go | 44 + proxy/internal/middleware/types.go | 253 ++ .../agent_network_chain_realstack_test.go | 321 +++ proxy/internal/proxy/context.go | 43 +- proxy/internal/proxy/reverseproxy.go | 456 +++- proxy/internal/proxy/reverseproxy_test.go | 43 + proxy/internal/proxy/servicemapping.go | 16 + proxy/internal/proxy/strip_prefix_test.go | 30 + proxy/internal/roundtrip/netbird.go | 74 +- proxy/internal/tcp/accept.go | 85 + proxy/internal/tcp/accept_test.go | 142 ++ proxy/internal/tcp/router.go | 15 +- proxy/internal/tcp/router_test.go | 129 ++ proxy/middleware_register.go | 16 + proxy/middleware_translate.go | 165 ++ proxy/middleware_translate_test.go | 246 ++ proxy/server.go | 173 +- shared/management/http/api/openapi.yml | 2061 +++++++++++++++++ shared/management/http/api/types.gen.go | 954 ++++++++ shared/management/proto/management.pb.go | 24 +- shared/management/proto/proxy_service.pb.go | 1934 +++++++++++----- shared/management/proto/proxy_service.proto | 125 + .../management/proto/proxy_service_grpc.pb.go | 88 + shared/management/status/error.go | 20 + 208 files changed, 39957 insertions(+), 688 deletions(-) create mode 100644 .github/workflows/agent-network-e2e.yml create mode 100644 docs/agent-networks/00-overview.md create mode 100644 docs/agent-networks/01-end-to-end-flows.md create mode 100644 docs/agent-networks/README.md create mode 100644 docs/agent-networks/modules/10-shared-api.md create mode 100644 docs/agent-networks/modules/20-management-store.md create mode 100644 docs/agent-networks/modules/21-management-agentnetwork.md create mode 100644 docs/agent-networks/modules/22-management-handlers-wiring.md create mode 100644 docs/agent-networks/modules/30-proxy-middleware-framework.md create mode 100644 docs/agent-networks/modules/31-proxy-middleware-builtin.md create mode 100644 docs/agent-networks/modules/32-proxy-llm-parsers.md create mode 100644 docs/agent-networks/modules/33-proxy-runtime.md create mode 100644 docs/agent-networks/modules/40-dashboard.md create mode 100644 docs/agent-networks/modules/50-path-routed-providers.md create mode 100644 e2e/agentnetwork/bootstrap_test.go create mode 100644 e2e/agentnetwork/chat_test.go create mode 100644 e2e/agentnetwork/main_test.go create mode 100644 e2e/agentnetwork/management_test.go create mode 100644 e2e/harness/Dockerfile.client create mode 100644 e2e/harness/agentnetwork.go create mode 100644 e2e/harness/bootstrap.go create mode 100644 e2e/harness/cert.go create mode 100644 e2e/harness/client.go create mode 100644 e2e/harness/combined.go create mode 100644 e2e/harness/config.go create mode 100644 e2e/harness/doc.go create mode 100644 e2e/harness/paths.go create mode 100644 e2e/harness/proxy.go create mode 100644 management/internals/modules/agentnetwork/accesslog_ingest.go create mode 100644 management/internals/modules/agentnetwork/accesslog_ingest_realstore_test.go create mode 100644 management/internals/modules/agentnetwork/accesslog_sessions_realstore_test.go create mode 100644 management/internals/modules/agentnetwork/affectedpeers_hook.go create mode 100644 management/internals/modules/agentnetwork/catalog/catalog.go create mode 100644 management/internals/modules/agentnetwork/handlers/access_log_handler.go create mode 100644 management/internals/modules/agentnetwork/handlers/budget_handler.go create mode 100644 management/internals/modules/agentnetwork/handlers/budget_handler_test.go create mode 100644 management/internals/modules/agentnetwork/handlers/consumption_handler.go create mode 100644 management/internals/modules/agentnetwork/handlers/guardrails_handler.go create mode 100644 management/internals/modules/agentnetwork/handlers/handlers_test.go create mode 100644 management/internals/modules/agentnetwork/handlers/policies_handler.go create mode 100644 management/internals/modules/agentnetwork/handlers/providers_handler.go create mode 100644 management/internals/modules/agentnetwork/handlers/settings_handler.go create mode 100644 management/internals/modules/agentnetwork/labelgen/labelgen.go create mode 100644 management/internals/modules/agentnetwork/labelgen/labelgen_test.go create mode 100644 management/internals/modules/agentnetwork/labelgen/words.go create mode 100644 management/internals/modules/agentnetwork/manager.go create mode 100644 management/internals/modules/agentnetwork/policyselect.go create mode 100644 management/internals/modules/agentnetwork/policyselect_account_realstore_test.go create mode 100644 management/internals/modules/agentnetwork/policyselect_realstore_test.go create mode 100644 management/internals/modules/agentnetwork/policyselect_test.go create mode 100644 management/internals/modules/agentnetwork/reconcile.go create mode 100644 management/internals/modules/agentnetwork/reconcile_test.go create mode 100644 management/internals/modules/agentnetwork/synthesizer.go create mode 100644 management/internals/modules/agentnetwork/synthesizer_guardrail_realstore_test.go create mode 100644 management/internals/modules/agentnetwork/synthesizer_log_collection_realstore_test.go create mode 100644 management/internals/modules/agentnetwork/synthesizer_parser_redact_realstore_test.go create mode 100644 management/internals/modules/agentnetwork/synthesizer_realstore_test.go create mode 100644 management/internals/modules/agentnetwork/synthesizer_test.go create mode 100644 management/internals/modules/agentnetwork/types/accesslog.go create mode 100644 management/internals/modules/agentnetwork/types/accesslogfilter.go create mode 100644 management/internals/modules/agentnetwork/types/budgetrule.go create mode 100644 management/internals/modules/agentnetwork/types/consumption.go create mode 100644 management/internals/modules/agentnetwork/types/consumption_test.go create mode 100644 management/internals/modules/agentnetwork/types/guardrail.go create mode 100644 management/internals/modules/agentnetwork/types/policy.go create mode 100644 management/internals/modules/agentnetwork/types/provider.go create mode 100644 management/internals/modules/agentnetwork/types/settings.go create mode 100644 management/internals/modules/agentnetwork/types/usage.go create mode 100644 management/internals/modules/agentnetwork/types/usageoverview.go create mode 100644 management/internals/modules/agentnetwork/wire_shape_test.go create mode 100644 management/server/affectedpeers/proxy_synth_test.go create mode 100644 management/server/agentnetwork_budgetrule_realstack_test.go create mode 100644 management/server/agentnetwork_proxypeer_restart_test.go create mode 100644 management/server/agentnetwork_realstack_test.go create mode 100644 management/server/store/sql_store_agentnetwork.go create mode 100644 management/server/store/sql_store_agentnetwork_accesslog_test.go create mode 100644 management/server/store/sql_store_agentnetwork_budgetrule_test.go create mode 100644 management/server/store/store_mock_agentnetwork.go create mode 100644 proxy/internal/accesslog/middleware_test.go create mode 100644 proxy/internal/llm/anthropic.go create mode 100644 proxy/internal/llm/anthropic_test.go create mode 100644 proxy/internal/llm/bedrock.go create mode 100644 proxy/internal/llm/bedrock_test.go create mode 100644 proxy/internal/llm/errors.go create mode 100644 proxy/internal/llm/fixtures/anthropic_messages.json create mode 100644 proxy/internal/llm/fixtures/anthropic_stream.txt create mode 100644 proxy/internal/llm/fixtures/openai_chat_completion.json create mode 100644 proxy/internal/llm/fixtures/openai_responses.json create mode 100644 proxy/internal/llm/fixtures/openai_responses_stream.txt create mode 100644 proxy/internal/llm/fixtures/openai_stream.txt create mode 100644 proxy/internal/llm/fixtures/pricing.yaml create mode 100644 proxy/internal/llm/openai.go create mode 100644 proxy/internal/llm/openai_test.go create mode 100644 proxy/internal/llm/parser.go create mode 100644 proxy/internal/llm/parser_test.go create mode 100644 proxy/internal/llm/pricing/defaults_coverage_test.go create mode 100644 proxy/internal/llm/pricing/defaults_pricing.yaml create mode 100644 proxy/internal/llm/pricing/pricing.go create mode 100644 proxy/internal/llm/pricing/pricing_other.go create mode 100644 proxy/internal/llm/pricing/pricing_test.go create mode 100644 proxy/internal/llm/pricing/pricing_unix.go create mode 100644 proxy/internal/llm/sse.go create mode 100644 proxy/internal/llm/sse_test.go create mode 100644 proxy/internal/middleware/bodypolicy.go create mode 100644 proxy/internal/middleware/bodytap/request.go create mode 100644 proxy/internal/middleware/bodytap/response.go create mode 100644 proxy/internal/middleware/bodytap/routing_scan_test.go create mode 100644 proxy/internal/middleware/builtin/agentnetwork_chain_integration_test.go create mode 100644 proxy/internal/middleware/builtin/all_test.go create mode 100644 proxy/internal/middleware/builtin/builtin.go create mode 100644 proxy/internal/middleware/builtin/cost_meter/factory.go create mode 100644 proxy/internal/middleware/builtin/cost_meter/middleware.go create mode 100644 proxy/internal/middleware/builtin/cost_meter/middleware_test.go create mode 100644 proxy/internal/middleware/builtin/llm_guardrail/factory.go create mode 100644 proxy/internal/middleware/builtin/llm_guardrail/middleware.go create mode 100644 proxy/internal/middleware/builtin/llm_guardrail/middleware_test.go create mode 100644 proxy/internal/middleware/builtin/llm_guardrail/redact.go create mode 100644 proxy/internal/middleware/builtin/llm_guardrail/redact_test.go create mode 100644 proxy/internal/middleware/builtin/llm_identity_inject/factory.go create mode 100644 proxy/internal/middleware/builtin/llm_identity_inject/middleware.go create mode 100644 proxy/internal/middleware/builtin/llm_identity_inject/middleware_test.go create mode 100644 proxy/internal/middleware/builtin/llm_limit_check/factory.go create mode 100644 proxy/internal/middleware/builtin/llm_limit_check/middleware.go create mode 100644 proxy/internal/middleware/builtin/llm_limit_check/middleware_test.go create mode 100644 proxy/internal/middleware/builtin/llm_limit_record/factory.go create mode 100644 proxy/internal/middleware/builtin/llm_limit_record/middleware.go create mode 100644 proxy/internal/middleware/builtin/llm_limit_record/middleware_test.go create mode 100644 proxy/internal/middleware/builtin/llm_request_parser/bedrock_test.go create mode 100644 proxy/internal/middleware/builtin/llm_request_parser/factory.go create mode 100644 proxy/internal/middleware/builtin/llm_request_parser/middleware.go create mode 100644 proxy/internal/middleware/builtin/llm_request_parser/middleware_test.go create mode 100644 proxy/internal/middleware/builtin/llm_response_parser/factory.go create mode 100644 proxy/internal/middleware/builtin/llm_response_parser/gzip_test.go create mode 100644 proxy/internal/middleware/builtin/llm_response_parser/middleware.go create mode 100644 proxy/internal/middleware/builtin/llm_response_parser/middleware_test.go create mode 100644 proxy/internal/middleware/builtin/llm_response_parser/responses_stream_test.go create mode 100644 proxy/internal/middleware/builtin/llm_response_parser/streaming.go create mode 100644 proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go create mode 100644 proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock_test.go create mode 100644 proxy/internal/middleware/builtin/llm_response_parser/streaming_test.go create mode 100644 proxy/internal/middleware/builtin/llm_router/factory.go create mode 100644 proxy/internal/middleware/builtin/llm_router/middleware.go create mode 100644 proxy/internal/middleware/builtin/llm_router/middleware_test.go create mode 100644 proxy/internal/middleware/builtin/llm_router/path_routed_test.go create mode 100644 proxy/internal/middleware/chain.go create mode 100644 proxy/internal/middleware/chain_test.go create mode 100644 proxy/internal/middleware/decision.go create mode 100644 proxy/internal/middleware/dispatcher.go create mode 100644 proxy/internal/middleware/headerpolicy.go create mode 100644 proxy/internal/middleware/keys.go create mode 100644 proxy/internal/middleware/manager.go create mode 100644 proxy/internal/middleware/metadata.go create mode 100644 proxy/internal/middleware/metrics.go create mode 100644 proxy/internal/middleware/middleware.go create mode 100644 proxy/internal/middleware/redaction.go create mode 100644 proxy/internal/middleware/registry.go create mode 100644 proxy/internal/middleware/spec.go create mode 100644 proxy/internal/middleware/types.go create mode 100644 proxy/internal/proxy/agent_network_chain_realstack_test.go create mode 100644 proxy/internal/proxy/strip_prefix_test.go create mode 100644 proxy/internal/tcp/accept.go create mode 100644 proxy/internal/tcp/accept_test.go create mode 100644 proxy/middleware_register.go create mode 100644 proxy/middleware_translate.go create mode 100644 proxy/middleware_translate_test.go diff --git a/.github/workflows/agent-network-e2e.yml b/.github/workflows/agent-network-e2e.yml new file mode 100644 index 000000000..c041bfbfa --- /dev/null +++ b/.github/workflows/agent-network-e2e.yml @@ -0,0 +1,69 @@ +name: Agent Network E2E + +on: + push: + branches: + - main + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + e2e: + name: Agent Network E2E + if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository + runs-on: ubuntu-latest + timeout-minutes: 45 + steps: + - name: Checkout code + uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 + with: + persist-credentials: false + + - name: Install Go + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0 + with: + go-version-file: "go.mod" + + # Container-driver builder so the harness can build the combined/proxy/ + # client images from source with a local layer cache. + - name: Set up Buildx + uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4.1.0 + + # Persist the Docker layer cache across runs. This caches the base, apt, + # and go-mod-download layers; the Go compile still re-runs, as BuildKit + # mount caches cannot be exported to the GitHub cache. + - name: Cache Docker layers + uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 + with: + path: /tmp/.buildx-cache + key: ${{ runner.os }}-anet-e2e-buildx-${{ hashFiles('go.sum', 'combined/Dockerfile.multistage', 'proxy/Dockerfile.multistage', 'e2e/harness/Dockerfile.client') }} + restore-keys: | + ${{ runner.os }}-anet-e2e-buildx- + + - name: Run agent-network e2e + env: + # Build the images from source (this branch's code) with the shared + # local layer cache. + NB_E2E_BUILDX_CACHE: /tmp/.buildx-cache + # Provider credentials. Each provider scenario skips if its + # token (and URL, for gateways) is unset, so partial coverage is fine. + OPENAI_TOKEN: ${{ secrets.E2E_OPENAI_TOKEN }} + ANTHROPIC_TOKEN: ${{ secrets.E2E_ANTHROPIC_TOKEN }} + VERCEL_URL: ${{ secrets.E2E_VERCEL_URL }} + VERCEL_TOKEN: ${{ secrets.E2E_VERCEL_TOKEN }} + OPENROUTER_URL: ${{ secrets.E2E_OPENROUTER_URL }} + OPENROUTER_TOKEN: ${{ secrets.E2E_OPENROUTER_TOKEN }} + CLOUDFLARE_URL: ${{ secrets.E2E_CLOUDFLARE_URL }} + CLOUDFLARE_TOKEN: ${{ secrets.E2E_CLOUDFLARE_TOKEN }} + AWS_BEARER_TOKEN_BEDROCK: ${{ secrets.E2E_AWS_BEARER_TOKEN_BEDROCK }} + AWS_REGION: ${{ secrets.E2E_AWS_REGION }} + # Vertex (Anthropic-on-Vertex): SA + project required; region defaults + # to "global", model to a pinned claude snapshot. + GOOGLE_VERTEX_SA_BASE64: ${{ secrets.E2E_GOOGLE_VERTEX_SA_BASE64 }} + GOOGLE_VERTEX_PROJECT: ${{ secrets.E2E_GOOGLE_VERTEX_PROJECT }} + GOOGLE_VERTEX_REGION: ${{ secrets.E2E_GOOGLE_VERTEX_REGION }} + GOOGLE_VERTEX_MODEL: ${{ secrets.E2E_GOOGLE_VERTEX_MODEL }} + run: go test -tags e2e -timeout 40m -v ./e2e/... diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 5d26d678d..511e54b62 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -21,7 +21,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals + ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals,flate,recordin,unparseable skip: go.mod,go.sum,**/proxy/web/** golangci: strategy: diff --git a/combined/Dockerfile.multistage b/combined/Dockerfile.multistage index ef3d68c6e..79746819d 100644 --- a/combined/Dockerfile.multistage +++ b/combined/Dockerfile.multistage @@ -5,12 +5,16 @@ WORKDIR /app RUN apt-get update && apt-get install -y gcc libc6-dev git && rm -rf /var/lib/apt/lists/* COPY go.mod go.sum ./ -RUN go mod download +RUN --mount=type=cache,target=/go/pkg/mod go mod download COPY . . -# Build with version info from git (matching goreleaser ldflags) -RUN CGO_ENABLED=1 GOOS=linux go build \ +# Build with version info from git (matching goreleaser ldflags). +# BuildKit cache mounts persist the module + build caches across image builds, +# so a source change recompiles incrementally instead of from scratch. +RUN --mount=type=cache,target=/go/pkg/mod \ + --mount=type=cache,target=/root/.cache/go-build \ + CGO_ENABLED=1 GOOS=linux go build \ -ldflags="-s -w \ -X github.com/netbirdio/netbird/version.version=$(git describe --tags --always --dirty 2>/dev/null || echo 'dev') \ -X main.commit=$(git rev-parse --short HEAD 2>/dev/null || echo 'unknown') \ diff --git a/docs/agent-networks/00-overview.md b/docs/agent-networks/00-overview.md new file mode 100644 index 000000000..0d76e44a0 --- /dev/null +++ b/docs/agent-networks/00-overview.md @@ -0,0 +1,109 @@ +# Agent Networks — overview + +Single-entry point. Feature scope, the module map, and the cross-cutting +topics worth keeping in mind, with links into every per-module guide. + +## TL;DR + +Agent Networks introduces an **LLM-aware reverse-proxy middleware system** +plus **account-level controls** (budget rules, log collection toggles, +PII redaction). The management server synthesises a per-peer middleware +chain that the proxy executes on every LLM request; the chain enforces +quotas, injects identity, redacts PII, parses tokens/cost, and emits +access-log entries. The dashboard exposes the surface as a single **AI +Observability** page with four tabs. + +- **Backend** lives in this repo, primarily under + `management/server/agentnetwork`, `proxy/internal/middleware`, and + `proxy/internal/llm`, with wire contracts in `shared/management`. +- **Dashboard** lives in the dashboard repo under + `src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`. + +## Reading order + +| # | Doc | Why | +|---|-----|-----| +| 1 | [01-end-to-end-flows.md](01-end-to-end-flows.md) | Get the three big diagrams in your head first. | +| 2 | [modules/10-shared-api.md](modules/10-shared-api.md) | Wire contracts — every other module either produces or consumes these. | +| 3 | [modules/21-management-agentnetwork.md](modules/21-management-agentnetwork.md) | The largest module; everything the proxy executes originates here. | +| 4 | [modules/30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md) | The generic plugin system on the proxy side. | +| 5 | [modules/31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md) | The 8 LLM middlewares that ride on the framework. | +| 6 | Everything else in any order. | | + +## Module map + +11 modules. Each is described in detail in its own file under +[`modules/`](modules/). + +| # | Module | Risk | BC impact | +|---|--------|------|-----------| +| 10 | [shared/api](modules/10-shared-api.md) — proto + OpenAPI | Low | Additive only | +| 20 | [management/store](modules/20-management-store.md) — SQL persistence | Medium | Auto-migrate (additive) | +| 21 | [management/agentnetwork](modules/21-management-agentnetwork.md) — domain layer + synthesizer | **High** | Additive | +| 22 | [management/handlers + wiring](modules/22-management-handlers-wiring.md) — HTTP API + gRPC delivery | Medium | Additive | +| 30 | [proxy/middleware-framework](modules/30-proxy-middleware-framework.md) — generic plugin system | High | Additive | +| 31 | [proxy/middleware-builtin](modules/31-proxy-middleware-builtin.md) — 8 LLM middlewares | High | Additive | +| 32 | [proxy/llm-parsers](modules/32-proxy-llm-parsers.md) — SDK adapters + pricing | Medium | Additive | +| 33 | [proxy/runtime](modules/33-proxy-runtime.md) — translate + serve + access-log | High | Additive (touches hot path) | +| 40 | [dashboard](modules/40-dashboard.md) — UI for everything above | Medium | Sidebar reshape | +| 50 | [path-routed-providers](modules/50-path-routed-providers.md) — Vertex AI + Bedrock | Medium | Additive (new catalog entries) | + +The largest and highest-risk module is `management/agentnetwork`: it is +the single writer of the middleware chain the proxy executes. + +## Cross-cutting topics + +These are the items most likely to bite production. Each is fully +documented in the linked module guide. + +1. **Capture-pointer semantics** (`*bool` for `capture_prompt` and + `capture_completion`): nil = legacy emit, false = suppress, true = + emit. nil-vs-false must be handled at every JSON hop. See + [21-management-agentnetwork.md](modules/21-management-agentnetwork.md) + and [31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md). +2. **`ProxyMapping.Private` preservation** on per-proxy live updates. + Failure mode: `auth` skips `ValidateTunnelPeer` → + `CapturedData.UserGroups` empty → `llm_router` denies. See + [33-proxy-runtime.md](modules/33-proxy-runtime.md). +3. **respInput carrying `UserEmail`/`UserGroups`/`UserGroupNames` onto + the response leg** in `reverseproxy.go`. Load-bearing wire that lets + `llm_limit_record` ship non-empty `group_ids` on `RecordLLMUsage`. See + [33-proxy-runtime.md](modules/33-proxy-runtime.md). +4. **Min-wins all-must-pass budget rule semantics**. Every matching + rule's remaining quota must be > 0 for the request to proceed; one + exhausted rule blocks the whole call. Documented in + [21-management-agentnetwork.md](modules/21-management-agentnetwork.md) + and the `llm_limit_check` middleware in + [31-proxy-middleware-builtin.md](modules/31-proxy-middleware-builtin.md). +5. **body-tap memory bounds**: per-direction 1 MiB cap, shared 256 MiB + budget, `LimitReader(r.Body, limit+1)` for truncation detection with + `replayReadCloser` fallback so upstream still sees the full body. + `cloneInputFor` deep-copies the body up to 16 times per chain — a + perf hot-spot. See + [30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md). +6. **UpstreamRewrite.AuthHeader bypasses the header denylist** + deliberately. The runtime consumer only unpacks it via the + trusted upstream-build path. See + [30-proxy-middleware-framework.md](modules/30-proxy-middleware-framework.md). +7. **`disable_access_log` default-false semantics**: the synth target + sets it true, all other targets leave it false. See + [10-shared-api.md](modules/10-shared-api.md). +8. **String-typed `decision` / `deny_code`** on + `CheckLLMPolicyLimitsResponse` — would benefit from enum pinning + before external consumers integrate. See + [10-shared-api.md](modules/10-shared-api.md). + +## Explicit non-goals + +- **Reaper / GC pass over stale synth services** — designed but cut from + scope. +- **URL-sync for tab state on AI Observability** — read path is wired + (`?tab=`) but write path isn't. Future work. +- **CI golden-file regen-and-diff for `types.gen.go` / + `proxy_service.pb.go`** — would catch codegen drift; not yet in place. + +## Where to read the code + +Per-module file scopes are listed in each module guide. Behaviour is +covered by Go tests co-located with each package (and an end-to-end +chain integration test under `proxy/internal/proxy`). diff --git a/docs/agent-networks/01-end-to-end-flows.md b/docs/agent-networks/01-end-to-end-flows.md new file mode 100644 index 000000000..7264f3768 --- /dev/null +++ b/docs/agent-networks/01-end-to-end-flows.md @@ -0,0 +1,217 @@ +# End-to-end flows + +Three cross-module mermaid diagrams. Each per-module guide repeats the +slice that's relevant to its own scope — these are the canonical +top-down views. + +- [Flow A — Config → runtime (synth + deliver)](#flow-a--config--runtime-synth--deliver) +- [Flow B — Request lifecycle through the LLM chain](#flow-b--request-lifecycle-through-the-llm-chain) +- [Flow C — Budget rule feedback loop](#flow-c--budget-rule-feedback-loop) + +--- + +## Flow A — Config → runtime (synth + deliver) + +How an operator's change to a Provider, Policy, Guardrail, Budget Rule, +or Settings record ends up as live middleware on a peer's proxy. + +```mermaid +sequenceDiagram + autonumber + actor Op as Operator + participant UI as Dashboard + participant HTTP as management/handlers + participant Mgr as agentnetwork.Manager + participant Store as management/store (SQL) + participant Ctl as network_map.Controller + participant Synth as agentnetwork.SynthesizeServices + participant Grpc as management gRPC + participant Proxy as netbird-proxy + participant Xlate as middleware_translate + participant Chain as middleware.Chain + + Op->>UI: edit provider/policy/budget/settings + UI->>HTTP: REST PUT/POST /api/agent-network/* + HTTP->>Mgr: SaveProvider / SavePolicy / SaveBudgetRule / SaveSettings + Mgr->>Store: persist (gorm) + Mgr-->>Ctl: account change event (Network-Map dirty) + loop per connected peer + Ctl->>Synth: SynthesizeServices(ctx, store, accountID) + Synth->>Store: load providers, policies, guardrails, budget rules, settings + Synth-->>Synth: build per-peer Service list + Note over Synth: each Service has a middleware
chain with capture_prompt /
capture_completion / redact_pii
baked from account settings + Synth-->>Ctl: []rpservice.Service + Ctl->>Grpc: NetworkMap push (services + middleware configs) + end + Grpc-->>Proxy: NetworkMap stream + Proxy->>Xlate: translate proto MiddlewareConfig → runtime Spec + Xlate->>Chain: register / replace per-service chain + Note over Chain: chain replacement is live
(no proxy restart, in-flight
requests unaffected) +``` + +**Notes on the diagram** + +- The `network_map.Controller` synthesises on every push, not on a + timer. A single config change costs O(connected peers × policies × + providers) per push. See [`modules/22-management-handlers-wiring.md`](modules/22-management-handlers-wiring.md). +- `SynthesizeServices` is the single source of truth for the wire + format the proxy executes. Anything the proxy does that the + synthesiser didn't request is a bug. See + [`modules/21-management-agentnetwork.md`](modules/21-management-agentnetwork.md). +- The translate step (step 13) is the only place that knows the + middleware-ID strings on the proxy side. It must reject unknown IDs; + silently dropping middlewares would create a security gap (e.g. + missing `llm_limit_check` ⇒ unbounded spend). See + [`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md). + +--- + +## Flow B — Request lifecycle through the LLM chain + +What happens when an agent on the client peer sends a chat-completion / +messages request through the synthesised reverse-proxy. + +```mermaid +sequenceDiagram + autonumber + actor Agent as Agent (local) + participant Px as netbird-proxy + participant Auth as auth middleware + participant Map as service-mapping + participant Req as llm_request_parser + participant Rt as llm_router + participant Chk as llm_limit_check + participant Inj as llm_identity_inject + participant Grd as llm_guardrail + participant Up as upstream LLM + participant Resp as llm_response_parser + participant Cost as cost_meter + participant Rec as llm_limit_record + participant Log as access-log + participant MgmtGrpc as management gRPC + + Agent->>Px: POST /v1/chat/completions (OpenAI / Anthropic) + Px->>Auth: identify peer (user, groups) + Auth->>Map: resolve service from Host + path + Map-->>Req: dispatch chain in slot order + + Req->>Req: parse body → provider, model, prompt, token estimate + Note over Req: capture_prompt gates raw_prompt
capture (nil = legacy emit,
false = drop, true = emit) + Req->>Rt: pass metadata + Rt->>Chk: route to upstream candidate + + Chk->>MgmtGrpc: CheckLLMPolicyLimits(provider, model, est_tokens, groups, user) + MgmtGrpc-->>Chk: decision = allow / deny + deny_code + alt decision == deny + Chk-->>Log: emit access-log with deny_code
(if EnableLogCollection) + Chk-->>Agent: 429 (or 403 per deny_code) + else decision == allow + Chk->>Inj: continue + Inj->>Inj: inject NetBird identity headers per provider config + Inj->>Grd: continue + Grd->>Grd: enforce model allowlist + Grd->>Up: forward (over WireGuard) + Up-->>Resp: response (JSON or SSE stream) + Resp->>Resp: parse usage tokens, completion + Note over Resp: capture_completion gates raw
completion capture + Resp->>Cost: tokens + Cost->>Cost: lookup pricing.yaml + compute cost + Cost->>Rec: tokens + cost + Rec->>MgmtGrpc: RecordLLMUsage(provider, model, prompt_t, completion_t, cost, groups, user) + Rec-->>Log: emit access-log entry
(if EnableLogCollection) + Log-->>Agent: 200 + body (streamed if SSE) + end +``` + +**Notes on the diagram** + +- The chain runs in synth-defined order. Re-ordering middlewares + changes invariants — `llm_limit_check` must precede `llm_router` so + a denied request never hits upstream, and `llm_limit_record` must + pair with `llm_limit_check` so a successful check is always recorded + (or the rate-limit semantics break). See + [`modules/31-proxy-middleware-builtin.md`](modules/31-proxy-middleware-builtin.md). +- `llm_guardrail` is also where PII redaction happens + (`redact_pii = settings.RedactPii`). Phones, emails, credit cards, + PII names — see `redact.go` for the full set. See + [`modules/31-proxy-middleware-builtin.md`](modules/31-proxy-middleware-builtin.md). +- SSE streaming requires special handling on the response side; the + parser must handle partial chunks without buffering the whole + stream. See [`modules/32-proxy-llm-parsers.md`](modules/32-proxy-llm-parsers.md). +- Access-log emission is gated on `settings.EnableLogCollection`. With + it OFF, neither the deny nor the allow leg writes an entry — the + chain still runs (budget rules are still enforced) but no audit trail + is kept. See + [`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md). + +--- + +## Flow C — Budget rule feedback loop + +How an account's budget rules tighten ceilings on every request and how +consumption flows back into the dashboard. + +```mermaid +flowchart LR + subgraph Operator + DashBud[Dashboard Budget Settings tab] + end + subgraph Mgmt[Management] + Save[POST/PUT /api/agent-network/budget-rules] + Store[(SQL store)] + Synth[SynthesizeServices] + Check[CheckLLMPolicyLimits RPC] + Rec[RecordLLMUsage RPC] + Cons[/api/agent-network/consumption] + end + subgraph Proxy[Proxy] + Chk[llm_limit_check] + RecMw[llm_limit_record] + end + subgraph DashView[Dashboard Budget Dashboard tab] + Panel[AgentConsumptionPanel] + end + + DashBud -->|create / update rules| Save + Save --> Store + Store --> Synth + Synth -->|push synth-services to peer| Proxy + + Chk -->|per request| Check + Check -->|aggregate matching rules
min-wins all-must-pass| Store + Check -->|allow / deny| Chk + + RecMw -->|post-response| Rec + Rec -->|tokens + cost + groups + user| Store + + Store -->|read counters| Cons + Cons --> Panel +``` + +**Notes on the diagram** + +- **min-wins all-must-pass** is the core semantic. A budget rule binds + to (group set, user set) with a (window, ceiling). At check time, + every rule that matches the caller is evaluated; if ANY rule has + zero remaining quota the request is denied. This is the most + surprising semantic for operators — see the invariants section of + [`modules/21-management-agentnetwork.md`](modules/21-management-agentnetwork.md). +- The proxy never makes its own budget decisions. It always asks + management via `CheckLLMPolicyLimits` and reports back via + `RecordLLMUsage`. This keeps account-wide accounting in one place + and avoids per-proxy drift. +- `RecordLLMUsage` must carry `group_ids` and `user_id` so the + decrement hits the right rule(s). The wire that carries those + fields onto the response leg is `respInput` in `reverseproxy.go`. See + [`modules/33-proxy-runtime.md`](modules/33-proxy-runtime.md). +- The dashboard's Budget Dashboard tab polls + `/api/agent-network/consumption` — not gRPC, not WebSocket. Poll + interval lives in `AgentConsumptionPanel.tsx`. See + [`modules/40-dashboard.md`](modules/40-dashboard.md). + +--- + +## Cross-references + +- Per-module guides: [`modules/`](modules/) +- Overview + module map: [`00-overview.md`](00-overview.md) diff --git a/docs/agent-networks/README.md b/docs/agent-networks/README.md new file mode 100644 index 000000000..a7d2d2ab5 --- /dev/null +++ b/docs/agent-networks/README.md @@ -0,0 +1,66 @@ +# Agent Networks — architecture documentation + +A self-contained set of documents describing the agent-networks feature: +an LLM-aware reverse-proxy middleware system plus account-level controls +(budget rules, log collection toggles, PII redaction). The management +server synthesises a per-peer middleware chain that the proxy executes on +every LLM request. + +## What to read first + +1. **[00-overview.md](00-overview.md)** — the single entry point. Feature + scope, the module map, and the cross-cutting topics worth keeping in + mind, with links to every per-module guide. +2. **[01-end-to-end-flows.md](01-end-to-end-flows.md)** — three + high-level mermaid diagrams: config-to-runtime synth/delivery, + per-request lifecycle through the LLM chain, and the budget-rule + feedback loop. +3. **Per-module guides** under `modules/` — one file per package. Each + describes the module boundary, the file-level layout, its own flow + diagrams, the public contracts, the invariants it relies on, and the + areas worth the closest attention. + +## Directory layout + +``` +docs/agent-networks/ +├── README.md # you are here +├── 00-overview.md # feature summary + module map +├── 01-end-to-end-flows.md # cross-module mermaid diagrams +└── modules/ + ├── 10-shared-api.md # proto + OpenAPI wire contracts + ├── 20-management-store.md # SQL persistence layer + ├── 21-management-agentnetwork.md # domain layer + synthesizer (largest) + ├── 22-management-handlers-wiring.md # HTTP API + gRPC delivery + ├── 30-proxy-middleware-framework.md # generic plugin system + ├── 31-proxy-middleware-builtin.md # 8 LLM-aware middlewares + ├── 32-proxy-llm-parsers.md # OpenAI/Anthropic/Bedrock SDKs + pricing + ├── 33-proxy-runtime.md # translate + serve + access-log + ├── 40-dashboard.md # UI for everything above (lives in the dashboard repo) + └── 50-path-routed-providers.md # Vertex AI + Bedrock (path-routed, keyfile:: creds, /bedrock prefix) +``` + +The `40-dashboard.md` module documents code that lives in the **dashboard +repo**, not in this repo. The guide is co-located here so backend readers +see the full picture in one place. + +## How the per-module guides are structured + +Every `modules/*.md` follows the same template so the docs are easy to +scan: + +- **Module boundary** — what this package owns; where it sits in the stack. +- **Files** — path / role. +- **Architecture & flow** — one or more mermaid diagrams. +- **Public contracts** — function signatures, gRPC messages, JSON shapes. +- **Invariants** — semantic guarantees the module relies on or enforces. +- **Things to scrutinize** — split by correctness / security / + concurrency / backward-compat / performance / observability. +- **Test coverage** — the test files that lock down behaviour in this + module. +- **Known limitations / non-goals** — what is intentionally out of scope. +- **Cross-references** — upstream/downstream module links + the + end-to-end flow + the overview. + +See [00-overview.md](00-overview.md) for the module map and the +cross-cutting topics. diff --git a/docs/agent-networks/modules/10-shared-api.md b/docs/agent-networks/modules/10-shared-api.md new file mode 100644 index 000000000..532927b90 --- /dev/null +++ b/docs/agent-networks/modules/10-shared-api.md @@ -0,0 +1,105 @@ +# shared/api — wire contracts (proto + OpenAPI) + +> **Risk level:** Medium — wire-format surface that every other module pins against; backward-compat hinges on field-number discipline more than on logic correctness. +> **Backward-compat impact:** Additive only (new proto fields use unallocated numbers, new RPCs default to `Unimplemented`, new OpenAPI schemas/paths are append-only; no existing field/RPC/schema removed or renumbered). + +## Module boundary +This module owns the cross-process contract surface between management, proxy, and dashboard. Two artefacts: `shared/management/proto/proxy_service.proto` (management↔proxy gRPC) and `shared/management/http/api/openapi.yml` (dashboard/CLI↔management REST). Both have generated companions checked in (`proxy_service.pb.go`, `proxy_service_grpc.pb.go`, `types.gen.go`) which must travel in lockstep with their sources. `shared/management/status/error.go` is in scope only for the four new typed `NotFound` constructors that the new HTTP handlers return. + +Everything downstream — `management/agentnetwork`, `management/server/http/handlers/*`, `proxy/internal/*`, the dashboard SDK — consumes these types verbatim. The concern here is wire stability and codegen reproducibility, not behaviour: behaviour is covered in the management and proxy module guides. + +`management.proto` and `signalexchange.proto` are unchanged. `status/error.go` only receives four additive constructors (lines 208-227); no existing error types are reshaped. + +## Files +| Path | Role | +| ---- | ---- | +| `shared/management/proto/proxy_service.proto` | Source of truth: 2 new RPCs, 1 new message group (`MiddlewareConfig` + slot enum), additive fields on `PathTargetOptions`, `AccessLog`, `RecordLLMUsageRequest` | +| `shared/management/proto/proxy_service.pb.go` | Generated (protoc-gen-go) | +| `shared/management/proto/proxy_service_grpc.pb.go` | Generated; adds `CheckLLMPolicyLimits` + `RecordLLMUsage` client/server stubs and `UnimplementedProxyServiceServer` defaults | +| `shared/management/http/api/openapi.yml` | 15 new `AgentNetwork*` schemas, 9 new path groups under `/api/agent-network/*` | +| `shared/management/http/api/types.gen.go` | Generated (oapi-codegen; see codegen note below) | +| `shared/management/status/error.go` | Four `NotFound` constructors for the new resource kinds (lines 208-227) | + +## Architecture & flow +```mermaid +sequenceDiagram + participant Dash as Dashboard / CLI + participant Mgmt as management (HTTP+gRPC) + participant Px as proxy + + Note over Dash,Mgmt: REST (OpenAPI / types.gen.go) + Dash->>Mgmt: PUT /api/agent-network/providers (AgentNetworkProviderRequest) + Dash->>Mgmt: PUT /api/agent-network/settings (AgentNetworkSettingsRequest) + Dash->>Mgmt: GET /api/agent-network/consumption -> [AgentNetworkConsumption] + + Note over Mgmt,Px: gRPC ProxyService (proxy_service.proto) + Mgmt-->>Px: SyncMappingsResponse{ ProxyMapping.path[*].options.middlewares,
agent_network, disable_access_log, capture_* } + Px->>Mgmt: CheckLLMPolicyLimits(account, user, groups, provider, model) + Mgmt-->>Px: decision=allow|deny + selected_policy_id + attribution_group_id + window_seconds + Px->>Mgmt: RecordLLMUsage(account, user, group_id, group_ids, window_seconds, tokens, cost) + Px->>Mgmt: SendAccessLog(AccessLog{ agent_network=true }) +``` + +The proto changes split into three independent slices: (1) **mapping enrichment** — `PathTargetOptions` grows fields 8-13 so management can ship middleware configs, capture limits, and the agent-network / log-suppression flags down to the proxy without a second RPC; (2) **two new request/response RPCs** (`CheckLLMPolicyLimits`, `RecordLLMUsage`) for per-LLM-request budget arbitration; (3) **observability tag** — `AccessLog.agent_network` so management can route logs to the right surface. + +The OpenAPI side is a thin CRUD surface — every resource (`Provider`, `Policy`, `Guardrail`, `BudgetRule`, `Settings`) follows the same `GET-list / POST / GET / PUT / DELETE` pattern, plus a read-only `/consumption` listing and a catalog endpoint. The `*Request` variants drop server-controlled fields (id, timestamps). `AgentNetworkBudgetRule` deliberately reuses `AgentNetworkPolicyLimits` to keep wire-shape parity with policies. + +## Public contracts added +- gRPC RPCs (`proxy_service.proto:52-57`): `CheckLLMPolicyLimits(CheckLLMPolicyLimitsRequest) → CheckLLMPolicyLimitsResponse`, `RecordLLMUsage(RecordLLMUsageRequest) → RecordLLMUsageResponse`. Both unary; default `UnimplementedProxyServiceServer` returns `codes.Unimplemented` (`proxy_service_grpc.pb.go:283-289`). +- New messages (`proxy_service.proto:145-175,448-502`): `MiddlewareConfig`, `MiddlewareSlot` enum, `CheckLLMPolicyLimitsRequest`/`Response`, `RecordLLMUsageRequest`/`Response`. +- New `PathTargetOptions` fields 8-13 (`proxy_service.proto:124-140`): `capture_max_request_bytes`, `capture_max_response_bytes`, `capture_content_types`, `middlewares`, `agent_network`, `disable_access_log`. All default-false / zero; pre-existing fields 1-7 byte-for-byte unchanged. +- `AccessLog.agent_network = 18` (`proxy_service.proto:258-261`). +- `RecordLLMUsageRequest.group_ids = 8` (`proxy_service.proto:496-498`) — so the record path can fan out to every applicable budget rule's window without a re-lookup. +- 15 new OpenAPI component schemas (`openapi.yml:5072-5829`): `AgentNetworkProvider[Request|Model]`, `AgentNetworkCatalog{Model,Provider,IdentityInjection,HeaderPairInjection,JSONMetadataInjection,ExtraHeader}`, `AgentNetworkPolicy[Request|TokenLimit|BudgetLimit|Limits]`, `AgentNetworkGuardrail[Checks|Request]`, `AgentNetworkConsumption`, `AgentNetworkSettings[Request]`, `AgentNetworkBudgetRule[Request]`. +- 9 new path groups (`openapi.yml:12797-13460`): `/api/agent-network/{consumption,settings,budget-rules,budget-rules/{ruleId},catalog/providers,providers,providers/{providerId},policies,policies/{policyId},guardrails,guardrails/{guardrailId}}`. +- Four typed NotFound errors (`shared/management/status/error.go:208-227`). + +## Invariants +- **Field-number monotonicity.** Every new proto field uses a previously-unallocated number in its message: `PathTargetOptions` 8-13 (was 1-7), `AccessLog` 18 (was 1-17), `RecordLLMUsageRequest` 8. `SendStatusUpdateRequest.inbound_listener = 50` (pre-existing) reserves 50+ for observability extensions, so 8 on `RecordLLMUsageRequest` doesn't conflict. +- **Old proxies stay compatible.** Old management never sends `disable_access_log`/`middlewares`/`agent_network` (zero value → existing behaviour); old proxies that don't decode these fields just drop them silently (proto3 unknown-field semantics) — log emission stays on. No pre-existing field number changed: the proto change is insertions only. +- **Old management stays compatible.** The two new RPCs are registered on the same `management.ProxyService` descriptor; old proxies hitting them get `codes.Unimplemented` from the unimplemented embed (`proxy_service_grpc.pb.go:283-289`), which is the same fallback pattern `SyncMappings` already documents (`proxy_service.proto:20-21`). +- **OpenAPI shapes are append-only.** New schemas are placed at the end of `components.schemas` (line 5072+); new paths at the end of `paths` (line 12797+). No existing schema's `required` list, enum, or property type was changed. +- **`*Request` vs response asymmetry.** Read shapes (`AgentNetworkProvider`, `AgentNetworkPolicy`, `AgentNetworkGuardrail`, `AgentNetworkSettings`, `AgentNetworkBudgetRule`) require `created_at`/`updated_at`; the matching `*Request` shapes do not — server fills them. `AgentNetworkProviderRequest.api_key` is write-only (`openapi.yml:5158-5161` "never returned in responses"); reviewers should confirm the response schema (5072-5138) actually omits `api_key`. + +## Things to scrutinize +### Correctness +- `RecordLLMUsageRequest` carries both `group_id` (singular, the attribution group — field 3) and `group_ids` (plural, full membership — field 8). `b22d5a181` adds field 8 to drive account-budget fan-out; double-check that consumers can't accidentally key counters on the wrong one. Field comments at `proxy_service.proto:489-491` and `496-498` distinguish them but it's the kind of subtle thing a follow-up commit might collapse. +- `PathTargetOptions.disable_access_log` is the only field whose default-false meaning **changes semantics** on the proxy side: false → log (status quo), true → suppress. Synthesizer sets `DisableAccessLog = !settings.EnableLogCollection`, so a missing/default settings row yields `EnableLogCollection=false → DisableAccessLog=true → suppressed`. Worth confirming downstream (`agentnetwork.synthesizer`) that operator-defined private services never inherit this flag — the proto field default protects them, but only if synth code is explicit. +- `CheckLLMPolicyLimitsResponse.decision` is a free-form `string` (`proxy_service.proto:471`) rather than an enum. Only documented values are "allow" / "deny". An enum would prevent typo drift; consider before this RPC ships to external consumers. +- `deny_code` (`proxy_service.proto:478-481`) is documented as "a stable label" but is also a free string. Pin the allowed set somewhere observable to the proxy. + +### Security +- `AgentNetworkProvider.api_key` MUST be write-only. Schema split (request has it at line 5158; response omits it) looks correct, but a regression here leaks the upstream provider credential to every dashboard reader. Check that the handler explicitly zeros it on the response path. +- `extra_values` / `identity_header_*` headers on `AgentNetworkProvider` get stamped onto upstream requests. Description at `openapi.yml:5099` says "values not declared by the catalog are ignored at synth time" — a contract this module documents but the synthesizer must enforce. Confirm the synth module honours it. +- Cluster + subdomain on `AgentNetworkSettings` are documented immutable (`openapi.yml:5686-5694`) and the `AgentNetworkSettingsRequest` (lines 5733-5752) doesn't accept them. Verify the `PUT /api/agent-network/settings` handler can't be tricked by extra JSON keys (oapi-codegen's `additionalProperties: false` is not declared here; spec defaults to permissive). + +### Backward compatibility +- The proto change is field-number additive: every previously numbered field keeps the same name + type, and the change is insertions only (no deletions in `proxy_service.proto`), so this holds at the source-text level. +- `proxy_service_grpc.pb.go` adds two RPC handlers and registers them in `ProxyService_ServiceDesc.Methods` (lines 543-552). The existing entries are unchanged and order-preserving — gRPC method dispatch is name-keyed, so order doesn't matter, but reviewing the diff (no method renamed/dropped) is still worth a glance. +- OpenAPI 3.0 doesn't have a built-in deprecation flow for paths; if any client tooling iterates `paths.*`, the additive routes shouldn't break it, but generated SDKs (especially the dashboard's) need a regen to gain access to `AgentNetwork*`. + +### Codegen pinning +- `generate.sh` (`shared/management/http/api/generate.sh:14`) installs `oapi-codegen@latest` rather than a pinned version. **This is a reproducibility gap** — re-running the script later may produce a different `types.gen.go`. Either pin the version in `generate.sh` (e.g. `@v2.7.0`) or document the pin in a `tools.go`. +- proto codegen has the protoc / protoc-gen-go version stamped in the generated file header (`proxy_service.pb.go:3-4`). +- Regenerate locally and confirm zero diff against the committed `types.gen.go` / `proxy_service.pb.go`. + +## Test coverage +| Test file | Locks down | +| --------- | ---------- | +| None in this scope | The proto and OpenAPI sources are tested transitively by the handler tests (`shared/management/http/handlers/agentnetwork/...`) and by the synthesizer/manager tests (`management/server/agentnetwork/...`). No round-trip serialisation test exists in the `proto/` or `api/` packages themselves. | +| `shared/management/proto/*_test.go` | (absent) | +| `shared/management/http/api/*_test.go` | (absent) | + +Acceptable for codegen artefacts, but a single golden-file test that re-runs `oapi-codegen` and `protoc` in CI and diffs against the checked-in files would close the reproducibility gap noted above. + +## Known limitations / explicit non-goals +- **No deprecation surface.** Old fields/RPCs are kept silently; there is no `[deprecated = true]` annotation on anything. Acceptable here because nothing is being removed. +- **No proto-side validation.** Numeric ranges (e.g. `window_seconds >= 60`, `cost_usd >= 0`, capture-byte clamps) are enforced in the OpenAPI schema via `minimum:` and inside Go code by the proxy/management, but `proto3` itself can't express them; downstream is expected to validate every message. +- **`MiddlewareConfig.config_json` is `bytes`** (`proxy_service.proto:163`) — opaque to the proto layer. Schema validity is the middleware factory's problem. This is a deliberate tradeoff (per the comment at 161-162) but worth flagging: a corrupted/malicious config_json can only fail at proxy apply time, not at the wire-decode step. +- **No catalog endpoint schema for the catalog itself** — the catalog data ships as a `GET /api/agent-network/catalog/providers` returning `[AgentNetworkCatalogProvider]` (`openapi.yml:13024`), but the catalog source-of-truth lives in `management/server/agentnetwork/catalog`, not here. +- The reaper / GC design was cut from scope; no reaper-related types appear here. + +## Cross-references +- Downstream: [management/store](20-management-store.md), [management/agentnetwork](21-management-agentnetwork.md), [management/handlers + wiring](22-management-handlers-wiring.md), [proxy/runtime](33-proxy-runtime.md) +- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md) +- Top-level: [../00-overview.md](../00-overview.md) diff --git a/docs/agent-networks/modules/20-management-store.md b/docs/agent-networks/modules/20-management-store.md new file mode 100644 index 000000000..1acc12611 --- /dev/null +++ b/docs/agent-networks/modules/20-management-store.md @@ -0,0 +1,112 @@ +# management/store — persistence for agent-network entities + +> **Risk level:** Medium — six brand-new tables behind AutoMigrate, one upsert-counter table that runs on the request hot path, and one column carrying an encrypted secret. +> **Backward-compat impact:** Additive (six new tables created by AutoMigrate; the `Store` interface gains 23 methods, but no existing column/index is touched). + +## Module boundary + +This module is the persistence layer for the Agent Network feature. Everything the management server stores about LLM proxying — providers, policies, guardrails, the per-account settings row, a usage-counter table written on every proxied LLM request, and the account-budget rules — flows through the methods added to `store.Store`. The module owns six tables, six entity types from `management/server/agentnetwork/types`, and a single hot-path upsert (`IncrementAgentNetworkConsumption`) consumed by the proxy fleet. + +Out of scope here: the catalog of provider definitions (compiled-in, no DB), the synthesizer/manager built on top of these CRUDs (covered in [21-management-agentnetwork.md](21-management-agentnetwork.md)), and the HTTP handlers that translate API requests into Save/Delete calls. + +## Files + +| Path | Role | +| ---- | ---- | +| `management/server/store/sql_store_agentnetwork.go` | gorm implementations of all 23 store methods | +| `management/server/store/sql_store_agentnetwork_budgetrule_test.go` | round-trip + account-scoping coverage against a real sqlite store | +| `management/server/store/sql_store.go` | one import, six entities appended to the `AutoMigrate` slice (sql_store.go:40, sql_store.go:141-142) | +| `management/server/store/store.go` | 23 methods added to the `Store` interface (store.go:328-354) | +| `management/server/store/store_mock_agentnetwork.go` | mockgen output for the new interface surface | + +## Tables added / migrations + +All six tables are created by `db.AutoMigrate` invoked from `NewSqlStore` at sql_store.go:133-143. There is no hand-rolled SQL migration script — the schema is whatever GORM derives from the struct tags. + +- `agent_network_providers` — `Provider.TableName()` at provider.go:76. PK `id`, index on `account_id`, named index `idx_agent_network_provider` on `provider_id`. Carries an at-rest-encrypted `api_key` and ed25519 `session_private_key` (provider.go:35,56). `extra_values` and `models` are JSON blobs (`serializer:json`). +- `agent_network_policies` — `Policy.TableName()` at policy.go:70. PK `id`, index on `account_id`. JSON columns: `source_groups`, `destination_provider_ids`, `guardrail_ids`, `limits`. +- `agent_network_guardrails` — `Guardrail.TableName()` at guardrail.go:41. PK `id`, index on `account_id`. JSON `checks`. +- `agent_network_settings` — `Settings.TableName()` at settings.go:33. PK `account_id` (one row per account), named index `idx_agent_network_settings_cluster_subdomain` on `subdomain` only — the index name implies a composite, but only one column is tagged. +- `agent_network_consumption` — `Consumption.TableName()` at consumption.go:46. Composite PK across `(account_id, dim_kind, dim_id, window_seconds, window_start_utc)` — the same tuple the upsert keys on. +- `agent_network_budget_rules` — `AccountBudgetRule.TableName()` at budgetrule.go:35. PK `id`, index on `account_id`. JSON `target_groups`, `target_users`, `limits`. + +## CRUD surface added + +Provider, Policy, Guardrail, BudgetRule follow the same pattern: `GetByID`, `GetAccount` (list), `Save` (upsert), `Delete`, with account-scoping enforced by the existing `accountAndIDQueryCondition` / `accountIDCondition` constants (sql_store.go:59-62). Provider additionally exposes `GetAllAgentNetworkProviders` (cross-account, used by the synthesizer). Settings exposes `Get`/`GetByCluster`/`Save` (no delete — one row per account, created on first save). Consumption exposes the upsert `Increment`, a point `Get`, and a cross-window `List`. + +## Architecture & flow + +```mermaid +flowchart LR + handlers["HTTP handlers
(management/server/agentnetwork)"] -->|Save/Delete| iface["Store interface
store.go:328-354"] + manager["agentnetwork.Manager"] -->|Get*| iface + synth["synthesizer
(global)"] -->|GetAllAgentNetworkProviders| iface + proxy["proxy fleet
(hot path)"] -->|IncrementAgentNetworkConsumption| iface + iface --> sql["SqlStore methods
sql_store_agentnetwork.go"] + iface -.gomock.-> mock["MockStore
store_mock_agentnetwork.go"] + sql --> gorm["gorm.DB"] + gorm --> tables[("6 tables
agent_network_*")] + sql --> enc["crypt.FieldEncrypt
(provider only)"] +``` + +Reads decrypt provider secrets in-place; writes do `provider.Copy().EncryptSensitiveData(...)` before `db.Save` so the caller's in-memory object keeps the plaintext `api_key` (sql_store_agentnetwork.go:88-102). Every list/get takes a `LockingStrength` and applies `clause.Locking{Strength: ...}` when non-`None` — matching the rest of the store. The upsert path uses `clause.OnConflict` with `gorm.Expr` server-side increments so concurrent proxy nodes converge without read-modify-write races (sql_store_agentnetwork.go:321-335). + +## Invariants enforced at the store layer + +- **Account scoping.** Every entity-by-ID method keys on `account_id = ? and id = ?`; no cross-tenant leak path through the API is reachable as long as callers always pass the auth'd `accountID` (sql_store_agentnetwork.go:70,141,201,429). +- **NotFound mapping.** `gorm.ErrRecordNotFound` is translated to typed `status.NewAgentNetwork*NotFoundError`; `Delete*` returns NotFound when `RowsAffected == 0` (sql_store_agentnetwork.go:111-113,171-173,231-233,461-463). +- **Provider secret encryption at rest.** `SaveAgentNetworkProvider` always encrypts before persist; `Get*` always decrypts after read. The plaintext `api_key` never reaches the DB through this layer (sql_store_agentnetwork.go:31,54,80,90). +- **Consumption monotonicity.** The upsert only ever issues `col = col + ?` for the three counter columns — no decrement path exists (sql_store_agentnetwork.go:330-332). +- **Window alignment is the caller's responsibility.** The store stamps `WindowStartUTC` as-passed; alignment to epoch happens in `types.WindowStart` at consumption.go:51-58. +- **Settings has no Delete.** Intentional — one row per account, created on first save; the row sticks around for the account lifetime. + +## Things to scrutinize + +### Correctness +- `SaveAgentNetworkProvider` saves the copy (sql_store_agentnetwork.go:95). The caller's in-memory pointer therefore keeps plaintext `api_key` and any `CreatedAt`/`UpdatedAt` gorm autofills land on the copy, not the original. Callers that need synced timestamps must re-fetch. +- `IncrementAgentNetworkConsumption`'s `Create` provides initial counter values (`TokensInput: tokensIn`, etc.) in the row, and on conflict the assignments add the same deltas to the existing values. The insert-vs-update arithmetic is consistent. Cross-check that no engine in use (sqlite, postgres, mysql) silently rejects the `OnConflict` clause — GORM emits engine-specific SQL but `ON DUPLICATE KEY UPDATE` (mysql) vs `ON CONFLICT (...)` (sqlite/postgres) need their unique constraint to match the composite PK on `agent_network_consumption`; it does, by construction. +- `IncrementAgentNetworkConsumption` writes `updated_at: time.Now().UTC()` literally inside the assignments map (sql_store_agentnetwork.go:333) — fine, but it's a Go-side timestamp captured at call time, not a DB-side `now()`. Acceptable for an audit field. +- `GetAgentNetworkConsumption` returns a zero-valued non-nil row on `ErrRecordNotFound` (sql_store_agentnetwork.go:364-371). Document or rename — a typed sentinel error would be more orthodox; callers must know not to error-check. + +### Concurrency / transactions +- Hot-path `IncrementAgentNetworkConsumption` runs outside any explicit transaction; concurrency safety relies entirely on the DB serialising the `ON CONFLICT` upsert against the composite PK. This is correct for postgres and mysql; for sqlite it serialises behind the single writer. +- `SaveAgentNetworkSettings` is a blind upsert with no version/etag — concurrent writes from two operators last-write-wins on the collection-toggle flags (settings.go:23-25). Acceptable for admin-curated state but worth flagging. +- `Save*Provider` uses `db.Save` on a struct with a PK already set — GORM emits UPDATE or INSERT based on row existence. No upsert clause is attached, so a race between two creates with the same generated `xid` (vanishingly unlikely) would surface as a PK violation. + +### Migration safety +- All six tables ride `AutoMigrate` (sql_store.go:141-142). AutoMigrate is additive: new columns get added, but it never drops columns nor narrows types. Three `bool` columns on `agent_network_settings` (`EnableLogCollection`, `EnablePromptCollection`, `RedactPii`) default to false at the GORM/DDL layer for existing rows; the test at sql_store_agentnetwork_budgetrule_test.go:83-112 locks that down on a fresh sqlite. Verify postgres/mysql produce the same default. +- The named index `idx_agent_network_settings_cluster_subdomain` on settings.go:15 is declared on only `subdomain`. Either the cluster column also needs `gorm:"index:idx_agent_network_settings_cluster_subdomain"` to make it composite, or the name is misleading. +- The named index `idx_agent_network_provider` on `Provider.ProviderID` (provider.go:30) is *not* unique and not scoped to account — two providers in the same account with the same `provider_id` are permitted at the DB layer; uniqueness, if any, must live above the store. + +### Backward compatibility +- Net additive. No removed methods, no renamed columns, no schema change to existing tables. Existing deployments running a prior binary continue to work; the first boot of the new binary creates the six tables. +- The `Store` interface grows by 23 methods (store.go:330-354); any non-mock external implementer of `store.Store` will fail to compile. The repo only has `SqlStore` + `MockStore`, both updated. + +### Performance (indexes, N+1) +- All by-account list queries hit the `idx_account_id` per-table index. No N+1: list methods return the full slice in one query. +- `GetAgentNetworkSettingsByCluster` (sql_store_agentnetwork.go:263-277) does a tablescan on `cluster` — no index. Tolerable for the bootstrap label generator (one-shot at provisioning) but worth noting if the call moves onto a hot path. +- `ListAgentNetworkConsumption` returns every row ever recorded for the account (sql_store_agentnetwork.go:382-400) — unbounded growth, no `LIMIT`, no time filter. With one row per (dim, window) per request burst, this table grows fastest of the six; a retention job + a paginated list method are obvious follow-ups. + +## Test coverage + +| Test file | Locks down | +| --------- | ---------- | +| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkBudgetRule_RealStore_RoundTrip` | full save → reload of `AccountBudgetRule` including the JSON-serialised `PolicyLimits`, target slices, double-delete returns NotFound (lines 18-59) | +| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkBudgetRule_RealStore_ScopedByAccount` | cross-account isolation for budget rules (lines 63-78) | +| `sql_store_agentnetwork_budgetrule_test.go::TestAgentNetworkSettings_RealStore_CollectionTogglesRoundTrip` | collection toggles default off, survive save/reload at the set values (lines 83-112) | + +Gap: there is no store-level test for providers (encryption round-trip), policies, guardrails, or `IncrementAgentNetworkConsumption` (concurrent upsert, window-key uniqueness). The consumption upsert is the most performance-sensitive method in this module and the only one without a real-sqlite test. + +## Known limitations / explicit non-goals + +- No retention / GC for `agent_network_consumption`. +- No `Delete` for `Settings` (one row per account, cleared with the account). +- No DB-engine-specific tuning — the same struct tags drive sqlite, mysql, postgres. +- Provider `extra_values` and `models` are JSON blobs; querying inside them is not supported by design. +- `GetAgentNetworkConsumption` "not-found = zero row" contract is convenient but unconventional. + +## Cross-references + +- Upstream: [shared/api](10-shared-api.md), [management/agentnetwork](21-management-agentnetwork.md) +- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md) +- Top-level: [../00-overview.md](../00-overview.md) diff --git a/docs/agent-networks/modules/21-management-agentnetwork.md b/docs/agent-networks/modules/21-management-agentnetwork.md new file mode 100644 index 000000000..b64c1ba20 --- /dev/null +++ b/docs/agent-networks/modules/21-management-agentnetwork.md @@ -0,0 +1,225 @@ +# management/agentnetwork — domain layer + synth pipeline + +> **Risk level:** High — central business logic + budget enforcement + the source of every middleware-chain change the proxy executes. +> **Backward-compat impact:** Additive within the agent-network surface; one **behavioural difference for opted-out accounts** in parser capture (the capture flag is stamped explicitly false instead of being absent — see capture-pointer semantics below). Non-agent-network proxy services are untouched (the synth chain only ships on `agent-net-svc-*` targets). + +## Module boundary + +`management/server/agentnetwork` owns every agent-network entity (providers, policies, guardrails, account budget rules, per-account settings, consumption rows) and **translates them into the in-memory `*rpservice.Service` that the reverse-proxy controller turns into `proto.ProxyMapping`s and pushes to clusters**. It is the *only* writer of the agent-network middleware chain. + +Inside the package: `manager.go` is the CRUD + permissions-gated facade; `synthesizer.go` walks settings + providers + policies + guardrails and emits the per-account service plus every middleware's JSON config; `policyselect.go` runs per-request attribution (min-wins account ceiling, then "drain bigger pool first"); `reconcile.go` diffs successive synth outputs and emits precise Create/Update/Delete proxy-mapping updates plus a peer-map refresh. `labelgen/` mints DNS-safe subdomain labels; `catalog/` is the static provider catalogue; `types/` carries gorm entity structs. The `_realstack_test.go` files in the parent `management/server/` directory exercise the manager + network-map controller end-to-end with no mocks. + +## Files + +| Path | Role | +| ---- | ---- | +| `agentnetwork/manager.go` | Manager interface + CRUD + permission gates + bootstrap-settings + reconcile trigger | +| `agentnetwork/synthesizer.go` | Settings/policy → wire-format synthesis; sole writer of the proxy middleware chain | +| `agentnetwork/policyselect.go` | Per-request policy attribution + account-budget ceiling (min-wins) | +| `agentnetwork/reconcile.go` | Per-account synth diff vs in-memory cache → Create/Update/Delete | +| `agentnetwork/catalog/catalog.go` | Static provider catalogue (auth headers, identity-injection shapes) | +| `agentnetwork/labelgen/{labelgen,words}.go` | DNS-safe subdomain picker + curated wordlist | +| `agentnetwork/types/provider.go` | Provider entity + APIKey + Models + ExtraValues + SessionKeys | +| `agentnetwork/types/policy.go` | Policy entity + `PolicyLimits` (token + budget) | +| `agentnetwork/types/guardrail.go` | Guardrail entity (`ModelAllowlist`, `PromptCapture`) | +| `agentnetwork/types/budgetrule.go` | `AccountBudgetRule` (reuses `PolicyLimits`) | +| `agentnetwork/types/settings.go` | Per-account `Settings` (Cluster, Subdomain, 3 toggles) | +| `agentnetwork/types/consumption.go` | `Consumption` row + `WindowStart` aligner | +| `agentnetwork/{synthesizer,policyselect,reconcile,wire_shape}_*test.go` | See test coverage table | +| `agentnetwork/types/consumption_test.go` | `WindowStart` alignment proofs | +| `agentnetwork/labelgen/labelgen_test.go` | Deterministic picks + exhaustion + fallback | +| `management/server/agentnetwork_realstack_test.go` | No-mock provider CRUD → network-map fan-out | +| `management/server/agentnetwork_budgetrule_realstack_test.go` | No-mock budget-rule CRUD + settings preserve-immutable | + +## Architecture & flow + +### Synthesis (settings/policy → wire format) + +```mermaid +flowchart TD + A[Mutation: provider/policy/guardrail/settings] --> B[managerImpl.reconcile accountID] + B --> C{proxyController nil?} + C -- yes --> D[accountManager.UpdateAccountPeers only] + C -- no --> E[SynthesizeServices] + E --> F[loadSettings — NotFound returns ok=false, no synth] + F --> G[filterEnabledProviders sorted by CreatedAt] + G --> H[filterEnabledPolicies] + H --> I[backfillProviderSessionKeys if missing] + I --> J[indexProviderGroups: providerID -> sorted source groups] + J --> K[buildRouterConfigJSON drops orphan providers] + J --> L[buildIdentityInjectConfigJSON per catalog entry] + H --> M[mergeGuardrails: union allowlist, OR redact] + M --> N[applyAccountCollectionControls account toggle = SOLE capture control] + N --> O[marshalGuardrailConfig] + K --> P[buildMiddlewareChain 8 middleware entries] + L --> P + O --> P + P --> Q[buildAccountService: AccessGroups=union source groups, noop.invalid target] + Q --> R[reconcile.diffMappings vs cache] + R --> S[SendServiceUpdateToCluster CREATE/MODIFY/REMOVE] + R --> T[accountManager.UpdateAccountPeers — fans synth ACLs into network map] +``` + +### Budget rule resolution (min-wins, group+user bound) + +```mermaid +flowchart TD + A[SelectPolicyForRequest in] --> B[checkAccountBudget — runs FIRST, independent of policies] + B --> C[GetAccountAgentNetworkBudgetRules] + C --> D{for each enabled rule} + D --> E{budgetRuleApplies?} + E -- no --> D + E -- yes --> F[attrGroup = lowestIntersect TargetGroups, in.GroupIDs] + F --> G{Token cap enabled?} + G -- yes --> H[evalTokenCap user dim + group dim] + H --> I{exhausted?} + I -- yes --> J[DENY: llm_account.token_cap_exceeded - STOP] + I -- no --> K{Budget cap enabled?} + G -- no --> K + K -- yes --> L[evalBudgetCap user dim + group dim] + L --> M{exhausted?} + M -- yes --> N[DENY: llm_account.budget_cap_exceeded - STOP] + M -- no --> D + K -- no --> D + D --> O[All rules passed -> fall through to per-policy selection] +``` + +Key invariant: **rules are checked sequentially and ANY exhausted rule denies (all-must-pass / min-wins).** Untargeted rules (`len(TargetGroups)==0 && len(TargetUsers)==0`) apply to every caller (`policyselect.go:393`). + +### Policy selection (per-peer, per-request) + +```mermaid +flowchart TD + A[Account-budget gate passed] --> B[GetAccountAgentNetworkPolicies] + B --> C[filterApplicablePolicies enabled + provider match + group intersect] + C --> D{candidates empty?} + D -- yes --> E[Allow, empty SelectedPolicyID] + D -- no --> F[scoreCandidates -> scoreOne per policy] + F --> G[scoreOne: attrGroup + window] + G --> H{any cap exhausted?} + H -- yes --> I[Drop policy; record last deny code] + H -- no --> K[Keep as live candidate] + F --> L{live candidates exist?} + L -- no --> M[Deny with last exhaustion code] + L -- yes --> N[Sort: uncapped wins -> larger group token -> group budget -> user token -> user budget -> oldest CreatedAt] + N --> O[winner = scored 0] + O --> P[Allow + SelectedPolicyID + AttributionGroupID + WindowSeconds] +``` + +End-to-end: a mutation calls `managerImpl.reconcile(ctx, accountID)` (`manager.go:205,239,...`). Reconcile defers an `accountManager.UpdateAccountPeers` so the network-map controller re-runs and `injectAllProxyPolicies` picks up the new access groups; with a `proxyController` wired, it re-synthesizes the service, diffs against `reconcileCache[accountID]` (guarded by `reconcileMu`), and emits proto mappings to the cluster derived from the mapping's domain (`reconcile.go:120`). Synthesis is stateless and idempotent. Sole persistent side effect: `backfillProviderSessionKeys` (`synthesizer.go:249`) mints ed25519 keys on legacy provider rows and writes them back. + +At request time the path is independent: the proxy calls `SelectPolicyForRequest` (`policyselect.go:56`); account-budget ceiling first, then per-policy scoring. Token + budget caps share `evalTokenCap` / `evalBudgetCap` — same primitive for account rules and policy limits, `label` differentiates the deny reason. After a served request, `RecordAccountBudgetUsage` (`policyselect.go:415`) fans deltas to every applicable rule's distinct `(dim_kind, dim_id, window)` tuple, deduplicating to prevent double-count when two rules share target+window. + +## Public contracts + +- **Manager interface** (`manager.go:48-80`): CRUD for `Providers/Policies/Guardrails/BudgetRules`; `GetSettings/UpdateSettings` (cluster + subdomain immutable, only the three toggles mutate); `ListConsumption/RecordConsumption(account, kind, dimID, windowSec, in, out, USD)`; `RecordAccountBudgetUsage(account, user, groups, in, out, USD)`; `SelectPolicyForRequest(ctx, PolicySelectionInput) → *PolicySelectionResult{Allow, SelectedPolicyID, AttributionGroupID, WindowSeconds, DenyCode, DenyReason}`. +- **`PolicySelectionInput`** (`manager.go:85-90`): `{AccountID, UserID, GroupIDs, ProviderID}` — populated by the proxy from CapturedData + `llm_router` resolution. +- **Synthesized middleware chain** (`synthesizer.go:576-657`), order load-bearing — response slot runs reverse-of-slice: + + | Slot | Idx | ID | ConfigJSON shape | CanMutate | + | --- | --- | --- | --- | --- | + | on_request | 0 | `llm_request_parser` | `{"capture_prompt": , "redact_pii"?: true}` | – | + | on_request | 1 | `llm_router` | `{"providers":[{id, models[], upstream_*, auth_header_*, allowed_group_ids[]}]}` | **true** | + | on_request | 2 | `llm_limit_check` | `{}` | – | + | on_request | 3 | `llm_identity_inject` | `{"providers":[{provider_id, header_pair?, json_metadata?, extra_headers?}]}` | **true** | + | on_request | 4 | `llm_guardrail` | `{"model_allowlist"?, "prompt_capture":{enabled,redact_pii}}` | – | + | on_response | 5 | `llm_limit_record` | `{}` (runs LAST at runtime) | – | + | on_response | 6 | `cost_meter` | `{}` | – | + | on_response | 7 | `llm_response_parser` | `{"capture_completion": , "redact_pii"?: true}` | – | +- **Synthesized service shape** (`synthesizer.go:739`): `Mode=HTTP`, `Private=true`, `Domain=.`, `AccessGroups=unionSourceGroups(enabledPolicies)`, one `TargetTypeCluster` target with `Host=noop.invalid:443` (router rewrites per request), `Options.{DirectUpstream,AgentNetwork}=true`, `DisableAccessLog=!settings.EnableLogCollection`, `CaptureMax{Req,Resp}Bytes=1<<20`, `CaptureContentTypes=["application/json","text/event-stream"]`. + +## Invariants + +- **Min-wins / all-must-pass for account budget rules** (`checkAccountBudget`, `policyselect.go:353`): every applicable enabled rule is checked; first exhausted cap denies. Untargeted rules bind every caller. +- **Account toggle is the SOLE control for capture enablement.** `applyAccountCollectionControls` (`synthesizer.go:701`) sets `merged.PromptCapture.Enabled = settings.EnablePromptCollection` *unconditionally*. +- **Capture-pointer semantics on parser configs** — see "Things to scrutinize" below. +- **`EnableLogCollection` ↔ `DisableAccessLog` is the only access-log toggle** (`synthesizer.go:770`). Default off ⇒ access log suppressed. +- **`RedactPii` flows verbatim to BOTH parsers** (`synthesizer.go:584-585`) and is OR'd into the merged guardrail (`synthesizer.go:706`). +- **Cluster and Subdomain are immutable on Settings.** `UpdateSettings` reloads existing row and overlays only the three toggles (`manager.go:558-561`). +- **Orphan providers (no enabled policy authorises them) NEVER reach the router** (`synthesizer.go:351-357`); skipped from `identity_inject` for symmetry. +- **Provider creation refuses empty `api_key`** (`manager.go:175`); **deletion refuses while any policy still references it** (`manager.go:265-273`). +- **Session keypair stability across provider edits** (`manager.go:226-228`) — server-managed, copied through every `UpdateProvider`, never API-surfaced. + +## Things to scrutinize + +### Correctness + +- **Capture-pointer semantics — `*bool` vs `bool`.** Three states, owned by separate sides: + - **Wire JSON this module emits:** `buildParserConfigJSON` (`synthesizer.go:678-693`) *always* stamps the capture field. Agent-network targets ship `"capture_prompt": false` or `"capture_prompt": true` — never absent. Same for `"capture_completion"`. The happy-path test pins `{"capture_prompt":false}` (`synthesizer_test.go:174`). + - **Proxy-side parser config (consumer):** parsers decode into `*bool`. Matrix: + - `nil` (field absent) → **legacy default = emit**. Preserved for non-agent-network callers and pre-existing tests (the backward-compat hook). + - `false` (field present, value false) → **suppress emission entirely**. The behaviour for opted-out agent-network accounts. Without this, `enable_log_collection=true` + `enable_prompt_collection=false` would leak raw user input AND raw model output to the access log. + - `true` → emit normally. + - **Why the synth always stamps a value:** an agent-network mapping omitting the field would hit legacy "always emit" and re-introduce the leak. The `json.Marshal` error fallback at `synthesizer.go:687` degrades to `{}` — comment-claimed unreachable, but if ever fired re-introduces the leak. Consider fail-closed (return literal `{"capture_prompt":false}`) instead. +- **`scoreCandidates` non-cumulative deny code.** Only the *last* exhausted policy's deny code survives (`policyselect.go:188-190`). Iteration order is store's natural order. Auth signal is `len(scored)==0`, so this is informational only — verify no UI depends on "first exhausted policy" semantics. +- **`effectiveWindowSeconds` token-wins tiebreak.** When both halves are enabled with different windows, token's window wins (`policyselect.go:482`). Verify `RecordLLMUsage` increments against the winning window only. +- **`RecordAccountBudgetUsage` dedup.** Two rules with the same `(kind, dim_id, window)` would double-count without the `tuples` map (`policyselect.go:434-449`). Key includes all three dimensions — correct. +- **Fail-closed on bad provider:** unknown catalog id (`synthesizer.go:794-796`) or empty API key (`synthesizer.go:801-803`) drops the **entire** account's synth, not just the bad provider. Confirm matches operator UX. + +### Security + +- **Redact OR-merge:** merged `RedactPii` = account OR guardrail (`synthesizer.go:706`). **Parser-side flag is `settings.RedactPii` only, NOT the OR** — a guardrail-only opt-in does not propagate to parsers. Correct because the account toggle gates capture, but worth noting on the proxy side. +- **Group resolution must not leak across accounts.** Every store call carries `accountID` (`policyselect.go:73, 286, 298, 322, 334, 354`); `lowestIntersect` uses caller's claimed groups only (`policyselect.go:494`). Risk surface is upstream (handler populates `in.GroupIDs`). +- **`UpdateSettings` preserves immutable Cluster + Subdomain** (`manager.go:558`). A client can't rebind the cluster. +- **Provider session keypair backfill writes through `SaveAgentNetworkProvider`** (`synthesizer.go:256`) from a read-shaped call. Idempotent → worst case is a wasted write under concurrent reconcile + snapshot. + +### Concurrency + +- **`reconcileMu`** guards `reconcileCache`. Lock window is narrow — compute diff inside, send outside (`reconcile.go:56-68`). +- **`labelRngMu`** guards `labelRng` because `math/rand.Source` is unsafe for concurrent use (`manager.go:638-640`). +- **Real-store tests** use `store.NewTestStoreFromSQL` with `t.TempDir()` per test — no shared state, no `t.Parallel()`. +- **`RecordAccountBudgetUsage` dedup `tuples` map is per-call;** concurrent calls fan out fully — correct (each request's tokens book once per applicable rule). +- **Deferred `UpdateAccountPeers` runs inline after the proxy push** (`reconcile.go:28-35`); a slow call stretches CRUD response time. + +### Backward compatibility + +- **Capture-pointer semantics (restated):** non-agent-network callers see no field → legacy nil-default emit, identical to pre-PR. Agent-network targets always carry an explicit `capture_*` value. +- **`TestSynthesizeServices_HappyPath` was updated:** request-parser config moved from `{}` to `{"capture_prompt":false}` (`synthesizer_test.go:174`). External snapshot tests against synth output need updating. +- **`MergedGuardrails` retains zeroed `TokenLimits`/`Budget`/`Retention`** even though `Policy.Limits` carries the real values now; `llm_limit_check` is the authoritative enforcement. Comment at `synthesizer.go:940-948` calls this out. + +### Performance + +- **`SynthesizeServices` runs on every controller tick / mutation reconcile.** Cost: 4 store reads + optional per-provider keypair backfill. Sort + index + merge are O(N log N) / O(P × G); dominant cost is JSON marshalling. No nested loops escape these dimensions. +- **`reconcile.diffMappings` is O(N + M)** with N=M=1 per account today — effectively constant. +- **`SynthesizeServicesForCluster`** (`synthesizer.go:71`) walks every account on a cluster; per-account failures are **swallowed** (`synthesizer.go:91-93`) so a single misconfigured account doesn't drop the cluster. Runs per proxy reconnect. + +### Observability + +- **Activity codes:** `AgentNetwork{Provider,Policy,Guardrail,BudgetRule}{Created,Updated,Deleted}`; `AgentNetworkSettingsUpdated` with `log_collection/prompt_collection/redact_pii` payload (`manager.go:567-571`). **No activity code for `SelectPolicyForRequest` denies** — surfaced via proxy access log only (likely intentional given volume). +- **Deny codes** namespaced: `llm_policy.{token,budget}_cap_exceeded`, `llm_account.{token,budget}_cap_exceeded` (`policyselect.go:18-26`). +- **Reconcile failures are logged at warn and swallowed** (`reconcile.go:42-44`). Persistent synth failures (e.g. unknown catalog id) silently keep the proxy out of sync — consider a manager-level synth-health surface if this becomes a support burden. + +## Test coverage + +| Test file | Locks down | +| --------- | ---------- | +| `synthesizer_test.go` | Mock-store: `HappyPath` (8-mw chain ordering, `{"capture_prompt":false}` baseline); `No{Settings,Providers}`; `Disabled{Provider,Policy}_NoService`; `RouterConfigOrdering`; `PolicyCheckConfig_UnionsSourceGroups`; `OrphanProvider_HasEmptyAllowedGroups`; identity-inject for LiteLLM / Bifrost (overrides + partial disable) / Cloudflare / Portkey / Vercel / OpenRouter / generic non-customizable; `GuardrailMerge_AllowlistUnion_LimitsRestrictive`; `BackfillsMissingSessionKeys`; `HTTPUpstream_KeepsExplicitPort`; `UpstreamURLPath_FlowsToRouter`; `UnknownProviderID_FailsClosed`; `EmptyAPIKey_FailsClosed`. | +| `synthesizer_realstore_test.go` | Real-sqlite: `SurvivesStatusToggle` reproduces the disable/re-enable 403 regression; `Reconcile_RealStore_PushesPrivateAfterStatusToggle` extends through reconcile push. | +| `synthesizer_guardrail_realstore_test.go` | `PromptCaptureAccountIsSoleControl`; `PromptCaptureFlowsWhenAccountOptsIn`; `AccountRedactWithoutGuardrailRedact`; `NoGuardrail_CaptureOff`. | +| `synthesizer_log_collection_realstore_test.go` | `LogCollection{Off_SuppressesAccessLog,On_PermitsAccessLog}` — verifies `DisableAccessLog` propagation through `ToProtoMapping`. | +| `synthesizer_parser_redact_realstore_test.go` | **Capture-pointer regression suite:** `ParserConfigsCarryRedactPii`; `ParserConfigsSuppressCaptureWhenLogCollectionOnly` (log=on/prompt=off ⇒ both capture flags false); `ParserConfigsOmitRedactPiiWhenOff`. | +| `policyselect_test.go` | Mock-store: `NoApplicablePolicies`; `AllowWithLowestGroupAttribution`; `LargerPoolWinsAcrossUsageLevels`; `StaysOnLargerPoolAfterPartialDrain`; `FallsThroughToSmallerPoolWhenLargerExhausted`; `TiebreakBy{LargerGroupPool,CreatedAt}`; `DeniesWhenAllExhausted`; `UncappedPolicyAlwaysWinsAgainstCapped`; `DisabledPolicyIgnored`; `StoreErrorPropagates`; `RejectsEmptyAccount`; `SharesGroupCounterAcrossPolicies`; `AntiFallThroughOnLowestGroup`; `BudgetOnlyExhaustionDenies`; `BudgetTighterThanTokenWins`. | +| `policyselect_realstore_test.go` | Real-sqlite regression guard: `NoApplicablePolicies`; `AllowAndLowestGroupAttribution`; `LargerPoolWins_FallsThroughWhenExhausted`; `BudgetCapDenies`; `GroupCounterSharedAcrossPolicies`; `DisabledPolicyIgnored`. | +| `policyselect_account_realstore_test.go` | Account budget rules: `AccountCeilingBindsEvenWithUncappedPolicy` (min-wins); `AccountGroupCeiling`; `AccountTargetUsersBindsOnlyThatUser`; `AccountRuleRecordsToOwnWindow`. | +| `reconcile_test.go` | `FirstSynth_EmitsCreate`; `NoChange_EmitsNothingExtra` (re-push as Modified — verify desired); `PolicyRemoved_EmitsDelete`; `NilProxyController_NoOp`; `EmptyAccountID_NoOp`; `ClusterFromMapping`. | +| `wire_shape_test.go` | `TestSynthesizedService_WireShape` — proto-shape lockdown via `ToProtoMapping`. Catches "service not matching" (mapping reaches proxy but no SNI/HTTP route). Asserts ID, Domain, Mode, AuthToken, `Private`, `Auth.Oidc=false`, one path `/` + `https://noop.invalid/`, 8 middlewares with correct slot enums, router config `auth_header_value="Bearer sk-test-key"`. | +| `labelgen/labelgen_test.go` | `PickUnique_{DeterministicWithSeededRng,AvoidsTakenWordsWhenMostAreReserved,FallsBackWhenAllReserved}`; `UniqueWords_DropsDuplicates`. | +| `types/consumption_test.go` | `WindowStart_{AlignedToUnixEpoch,WithinWindowConverges,AcrossWindowsDiverges,DifferentWindowsHaveDifferentBuckets,SubMinuteAndMinuteAlignment,ZeroWindowReturnsInputUTC}`. Bucket alignment so multi-node reads converge. | +| `agentnetwork_realstack_test.go` | `ProviderCRUD_FansOutToProxyAndClientPeers` — no-mock end-to-end through real account manager + network-map + agentnetwork: provider create propagates the updated map to both proxy peer and client peer with the synth DNS surface. | +| `agentnetwork_budgetrule_realstack_test.go` | `BudgetRuleCRUD_RealManager`; `UpdateSettings_PreservesImmutableAndTogglesCollection`. | + +## Known limitations / explicit non-goals + +- **`MergedGuardrails.TokenLimits/Budget/Retention` emit at zero** (`synthesizer.go:940-948`); real enforcement is `Policy.Limits` via `llm_limit_check`. Future cleanup implied. +- **Session keys picked from first enabled provider by created_at** (`pickServiceSessionKeys`, `synthesizer.go:270`). Existing session cookies survive provider edits only while the first-by-CreatedAt provider stays in place. Document for operators. +- **Reconcile failures silently swallowed** (`reconcile.go:42-44`). Persistent failures keep the proxy out of sync until the next reconcile. +- **`scoreCandidates` exposes only the LAST exhaustion's deny code** when multiple policies are exhausted. +- **`bootstrapSettingsIfNeeded` failure is non-fatal to provider create** (`manager.go:200`): provider lands, synth is no-op until the next provider create retries the bootstrap. +- **Budget rules do not trigger a reconcile** (`manager.go:476-477`). Request-time evaluation only; new rules take effect on the next request without a proxy push. + +## Cross-references + +- **Upstream:** [shared/api](10-shared-api.md), [management/store](20-management-store.md), reverseproxy `service`/`proxy`/`sessionkey` packages, `management/server/permissions` + `activity`. +- **Downstream:** [management/handlers (HTTP wiring)](22-management-handlers-wiring.md), [proxy/middleware-builtin](31-proxy-middleware-builtin.md), network-map controller (`injectAllProxyPolicies` fan-out). +- **End-to-end flow:** [../01-end-to-end-flows.md](../01-end-to-end-flows.md) — "Provider create → reconcile → proxy push → peer map refresh" and "request → policy select → record" diagrams. +- **Top-level:** [../00-overview.md](../00-overview.md) diff --git a/docs/agent-networks/modules/22-management-handlers-wiring.md b/docs/agent-networks/modules/22-management-handlers-wiring.md new file mode 100644 index 000000000..9b8a47445 --- /dev/null +++ b/docs/agent-networks/modules/22-management-handlers-wiring.md @@ -0,0 +1,203 @@ +# management/handlers + wiring — HTTP API + gRPC delivery + +> **Risk level:** Medium — the surface is mostly additive, but two changes are load-bearing: `injectAllProxyPolicies` runs on every per-peer compute, and `shallowCloneMapping` must round-trip `Private` (a missed field silently breaks every MODIFIED). +> **Backward-compat impact:** Additive on the wire (new routes, new RPCs, new proto fields, new gorm column on `AccessLogEntry`). One management-internal break: `nbhttp.NewAPIHandler` gains a trailing `agentNetworkManager` parameter; `nil` is tolerated and silently skips route registration. + +## Module boundary + +This module is the seam between the public Agent Network HTTP API and the proxy fleet that serves agent traffic. North side: a `/api/agent-network/*` surface (providers, policies, guardrails, budget rules, settings, consumption) on the existing gorilla router, delegating to `agentnetwork.Manager`. Handlers are thin — they translate `api.*` ↔ `types.*`, validate shape, forward. RBAC and event emission stay inside the manager (`manager.go:680-682`). + +South side: `ProxyServiceServer` (`proxy.go`) learns to (a) ship synth services to a proxy on initial snapshot, (b) resolve agent-network domains in `getServiceByDomain` for OIDC/session/tunnel-peer flows, (c) gate LLM requests via `CheckLLMPolicyLimits` + `RecordLLMUsage`, (d) preserve `Private` through `shallowCloneMapping` so per-proxy live updates don't silently flip services public. The network_map controller prepends synth services to `account.Services` on every per-peer compute; `accesslogentry.go` gains an indexed `AgentNetwork` column so the dashboard can filter cheaply. + +## Files + +| Path | Role | +| ---- | ---- | +| `handlers/agentnetwork/providers_handler.go` | Catalog + provider CRUD + central `AddEndpoints` | +| `handlers/agentnetwork/policies_handler.go` | Policy CRUD + shared `validatePolicy*` | +| `handlers/agentnetwork/guardrails_handler.go` | Guardrail CRUD | +| `handlers/agentnetwork/budget_handler.go` | Account-level budget rule CRUD | +| `handlers/agentnetwork/settings_handler.go` | GET (200+`null` if unbootstrapped) + PUT toggles | +| `handlers/agentnetwork/consumption_handler.go` | Read-only consumption rows | +| `handlers/agentnetwork/handlers_test.go` | Real-store fixture; wire round-trip + validation | +| `handlers/agentnetwork/budget_handler_test.go` | Budget-rule + settings toggles | +| `server/http/handler.go` | New `agentNetworkManager` arg; conditional `AddEndpoints` | +| `server/permissions/modules/module.go` | New `AgentNetwork` module key | +| `internals/server/boot.go` | Wires synthesiser adapter + limits service into proxy server | +| `internals/server/modules.go` | `AgentNetworkManager()` lazy-create node | +| `internals/controllers/network_map/controller/controller.go` | `injectAllProxyPolicies` replaces 4 `InjectProxyPolicies` calls | +| `internals/controllers/network_map/controller/repository.go` | `SynthesizeAgentNetworkServices` repo method | +| `internals/modules/reverseproxy/service/service.go` | `MiddlewareConfig`, capture limits, `AgentNetwork`, `DisableAccessLog` + proto | +| `internals/modules/reverseproxy/accesslogs/accesslogentry.go` | Indexed `AgentNetwork bool` from proto | +| `internals/shared/grpc/proxy.go` | Synth wiring, 2 RPCs, domain fallback, `Private` in clone | +| `internals/shared/grpc/proxy_clone_test.go` | Locks every `ProxyMapping` field minus `AuthToken` | +| `server/activity/codes.go` | 13 new activity codes (125-137) | + +## HTTP routes added + +All routes inherit the platform's auth middleware. Perms enforced inside `agentnetwork.Manager.requirePermission` (`manager.go:680-682`) on `modules.AgentNetwork`. Permission column shows the `op` passed to `requirePermission` — read = `Read`, etc. + +| Method | Path | Perm | Handler | +| ------ | ---- | ---- | ------- | +| GET | `/agent-network/catalog/providers` | authn only | `providers_handler.go:43` | +| GET | `/agent-network/providers` | read | `providers_handler.go:57` | +| POST | `/agent-network/providers` | create | `providers_handler.go:97` | +| GET | `/agent-network/providers/{providerId}` | read | `providers_handler.go:77` | +| PUT | `/agent-network/providers/{providerId}` | update | `providers_handler.go:132` | +| DELETE | `/agent-network/providers/{providerId}` | delete | `providers_handler.go:172` | +| GET | `/agent-network/policies` | read | `policies_handler.go:32` | +| POST | `/agent-network/policies` | create | `policies_handler.go:72` | +| GET | `/agent-network/policies/{policyId}` | read | `policies_handler.go:52` | +| PUT | `/agent-network/policies/{policyId}` | update | `policies_handler.go:102` | +| DELETE | `/agent-network/policies/{policyId}` | delete | `policies_handler.go:142` | +| GET | `/agent-network/guardrails` | read | `guardrails_handler.go:25` | +| POST | `/agent-network/guardrails` | create | `guardrails_handler.go:65` | +| GET | `/agent-network/guardrails/{guardrailId}` | read | `guardrails_handler.go:45` | +| PUT | `/agent-network/guardrails/{guardrailId}` | update | `guardrails_handler.go:95` | +| DELETE | `/agent-network/guardrails/{guardrailId}` | delete | `guardrails_handler.go:135` | +| GET | `/agent-network/budget-rules` | read | `budget_handler.go:24` | +| POST | `/agent-network/budget-rules` | create | `budget_handler.go:64` | +| GET | `/agent-network/budget-rules/{ruleId}` | read | `budget_handler.go:44` | +| PUT | `/agent-network/budget-rules/{ruleId}` | update | `budget_handler.go:95` | +| DELETE | `/agent-network/budget-rules/{ruleId}` | delete | `budget_handler.go:135` | +| GET | `/agent-network/settings` | read | `settings_handler.go:53` (200+`null` if no row) | +| PUT | `/agent-network/settings` | update | `settings_handler.go:27` | +| GET | `/agent-network/consumption` | read | `consumption_handler.go:21` | + +## gRPC RPCs added (or modified) + +| RPC | Direction | Trigger | +| --- | --------- | ------- | +| `CheckLLMPolicyLimits` | proxy→mgmt unary | Pre-flight gate; returns allow/deny, selected policy, attribution group, window, deny code+reason (`proxy.go:259-301`). `Unimplemented` when limits service is nil. | +| `RecordLLMUsage` | proxy→mgmt unary | Post-flight write of tokens+cost against policy-window dimensions + every applicable account budget rule (`proxy.go:303-349`). `window_seconds==0` ⇒ no policy cap, only account fan-out runs. | +| `GetMappingUpdate`/`SendServiceUpdate` (stream) | mgmt→proxy | Snapshot (`proxy.go:752-780`) now appends `SynthesizeServicesForCluster`. Live updates use `SendServiceUpdateToCluster` + `shallowCloneMapping`. | + +## Architecture & flow + +### HTTP request lifecycle + +```mermaid +sequenceDiagram + participant DB as Dashboard + participant R as gorilla.Router (/api) + participant H as handler (agentnetwork) + participant M as agentnetwork.Manager + participant S as store.Store + participant AM as accountManager (StoreEvent) + + DB->>R: POST /api/agent-network/providers + R->>H: createProvider (auth mw sets UserAuth) + H->>H: GetUserAuthFromContext + validate(req) + H->>M: CreateProvider(userID, provider, bootstrapCluster) + M->>M: requirePermission(AgentNetwork, Create) + M->>S: SaveAgentNetworkProvider + M->>AM: StoreEvent(AgentNetworkProviderCreated) + M-->>H: created provider + H-->>DB: 200 + api.AgentNetworkProvider JSON +``` + +### Synth-service delivery via gRPC + +```mermaid +sequenceDiagram + participant P as Proxy + participant G as ProxyServiceServer + participant SM as service.Manager (persisted) + participant SA as synthesizerAdapter + participant AN as SynthesizeServicesForCluster + participant ST as store.Store + + Note over P,G: Initial snapshot + P->>G: GetMappingUpdate (stream open) + G->>SM: GetServicesForCluster(conn.address) + SM-->>G: persisted []*Service + G->>SA: SynthesizeServicesForCluster(conn.address) + SA->>AN: SynthesizeServicesForCluster(store, clusterAddr) + AN->>ST: walk every account; read providers/policies/settings + AN-->>SA: in-memory []*Service + SA-->>G: []*Service + G->>P: response (persisted + synth) + + Note over G,P: Per-request live update + G->>G: SendServiceUpdateToCluster(update, clusterAddr) + G->>G: shallowCloneMapping(update) %% Private MUST survive + G->>P: response with single mapping +``` + +End-to-end: HTTP write persists rows and emits an activity event; the manager then triggers `proxyController.SendServiceUpdate` so proxies re-render. **The snapshot path is the only one that calls into the synthesiser** — on stream open it pulls persisted services then appends synth services for the cluster. Synth services are never persisted. For OIDC/session/tunnel-peer flows, `getServiceByDomain` falls back to `SynthesizeServicesForCluster(clusterFromDomain(domain))` when persisted lookup misses (`proxy.go:1763-1793`). The network_map contribution is orthogonal: per-peer compute prepends the same synth services to `account.Services` before `InjectProxyPolicies`. + +## Permissions model added + +- `permissions/modules/module.go:22` adds `AgentNetwork Module = "agent_network"`, registered in `All` (`module.go:42`). Standard `operations.{Read,Create,Update,Delete}` matrix. +- Handlers don't call `permissionsManager` directly — they extract `UserAuth` and delegate to `agentnetwork.Manager`, which gates every mutation through `requirePermission` (`manager.go:168, 308, 549`, etc.). Confirm your role-set provider has `agent_network` rows for owner/admin/user/billing-admin before merging. +- `getCatalogProviders` (`providers_handler.go:43`) intentionally skips RBAC — catalog is global static data. + +## Activity codes added + +`activity/codes.go:244-274` adds Activities 125-137 + string/code mappings (`codes.go:428-444`), following `..` (e.g., `agent_network.provider.create`). Audit-log exporters / SIEM forwarders need to know the new codes. + +## Invariants + +- **Synth services are never persisted.** Snapshot appends after `serviceManager.GetServicesForCluster` (`proxy.go:761-770`); network_map prepends before `InjectProxyPolicies` (`controller.go:117-126`). +- **`shallowCloneMapping` must round-trip every `ProxyMapping` field except `AuthToken`** — `proxy_clone_test.go:50-58` enforces via `gproto.Equal`. The bug it guards: a missing `Private` made every MODIFIED arrive `private=false`, the proxy skipped `ValidateTunnelPeer`, `UserGroups` stayed empty, `llm_router` denied `no_authorised_provider`; a restart "fixed" it because the snapshot uses the original mapping. +- **Limit-window floor is 60s** (`policies_handler.go:189-220`); enabled cap with both per-group and per-user at zero is rejected. Budget rules reuse the same validator (`budget_handler.go:170`). +- **Manager is optional at boot.** `NewAPIHandler` registers routes only when non-nil (`handler.go:129`); `ProxyServiceServer` returns `Unimplemented` from both RPCs when limits service is unwired (`proxy.go:262-265, 306-309`). +- **Settings GET on an unbootstrapped account returns 200 + `null`** (`settings_handler.go:65-72`) — not 404. + +## Things to scrutinize + +### Correctness +- **`injectAllProxyPolicies` runs on every per-peer compute**: `controller.go:163, 309, 415, 681`. `sendUpdateAccountPeers` is the target of the buffered fan-out — synth runs once per debounced account-update tick **and** once per direct `UpdateAccountPeer`. Cost is O(providers + policies × users-per-group) per account under `LockingStrengthNone`. No per-account synth cache — verify it fits the buffer interval for your largest tenant. +- **`clusterFromDomain` strips at the first `.`** (`proxy.go:1784-1792`). A zero-dot domain returns `""` and the synth call walks every account. Confirm no path reaches this with a malformed/internal domain. +- **Account-budget `RecordConsumption` fans out even when `window_seconds == 0`** (`proxy.go:341-348`) — intentional. Verify the proxy never sends `RecordLLMUsage` for a request that wasn't actually allowed. + +### Security +- Every handler extracts `UserAuth` via `nbcontext.GetUserAuthFromContext` before any work. Routes live behind the standard `/api` mux; bypass list is not extended. +- `CheckLLMPolicyLimits` / `RecordLLMUsage` ride the existing **proxy → mgmt** gRPC connection auth. No additional token check inside the RPCs — they trust the connection. Confirm the proxy-side token-verification interceptor in this package gates both. +- `RecordLLMUsage` only validates `account_id != ""` (`proxy.go:317-319`). A compromised proxy can attribute cost to any account in its cluster — was already true for prior RPCs but is louder now that data drives denials. + +### Concurrency +- `SetAgentNetworkSynthesizer` / `SetAgentNetworkLimitsService` write under `s.mu.Lock`; read paths copy the interface under read lock (`proxy.go:236-247, 260-263, 304-307`). Same pattern as existing `serviceManager`/`proxyController` setters. +- Manager writes use `LockingStrengthUpdate`; synth reads use `LockingStrengthNone` — read-after-write via the proxy snapshot can observe a stale view by up to one fan-out tick. +- Network_map controller is single-threaded per account; cross-account is parallel. + +### Backward compatibility +- `proxy_clone_test.go` is the regression net; any new `ProxyMapping` field must be cloned or explicitly nulled in the test. +- `AccessLogEntry` adds indexed `AgentNetwork bool` — implicit AutoMigrate; deploy story must handle table-rewrite cost on high-volume access-log tables. +- `TargetOptions` gains seven `omitempty` JSON fields (`service.go:69-94`); on-wire shape stays compatible. `targetOptionsToProto` tests all fields when deciding nil (`service.go:551-556`). +- `NewAPIHandler` signature changes — every caller must pass `agentNetworkManager`; `nil` is supported. + +### Observability +- 13 new activity codes via `accountManager.StoreEvent` in the manager — confirm dashboard's audit-log UI maps them. +- `AccessLogEntry.AgentNetwork` is indexed for the dashboard's agent-network log filter. +- New RPCs log at error level on store/selector failures (`proxy.go:284, 327, 332, 348`). Snapshot synth failures degrade to warnings — stream is not aborted (`proxy.go:765`). + +## Test coverage + +| Test | Locks down | +| ---- | ---------- | +| `handlers_test.go::TestPolicyHandler_WindowSecondsRoundTrip` | GET carries `window_seconds`; legacy `window_hours`/`window_days` absent. | +| `handlers_test.go::TestPolicyHandler_RejectsSubMinuteWindow` | POST `<60s` returns 4xx. | +| `handlers_test.go::TestConsumptionHandler_EmptyAccountReturnsArray` | `/consumption` returns `[]` — never null. | +| `handlers_test.go::TestConsumptionHandler_PopulatedAccountListsRows` | RecordConsumption×2 surfaces both with correct tokens/cost/window. | +| `budget_handler_test.go::TestBudgetRuleHandler_RoundTrip` | Targets + PolicyLimits shape round-trip. | +| `budget_handler_test.go::TestBudgetRuleHandler_ListReturnsArray` | Empty-list shape. | +| `budget_handler_test.go::TestBudgetRuleHandler_{RejectsMissingName,RejectsSubMinuteWindow}` | Validation rejections are 4xx. | +| `budget_handler_test.go::TestSettingsHandler_GetExposesCollectionToggles` | All four toggles + computed `Endpoint`. | +| `proxy_clone_test.go::TestShallowCloneMapping_PreservesAllFieldsExceptAuthToken` | Future-proofs clone; every field round-trips, `AuthToken` dropped. | + +Handler tests use a real sqlite store + real manager + always-allow permissions mock (`handlers_test.go:53-75`). Create/update/delete success paths flow through `accountManager.StoreEvent` which the fixture doesn't wire — covered by manager-level no-mock tests outside this module. + +## Known limitations / explicit non-goals + +- No pagination on any list endpoint; no bulk endpoints. +- Synth result is not cached — every snapshot and every per-peer compute repeats the store walk. +- `getSettings` returning `200 + null` is a deliberate dashboard concession. +- No rate-limiting beyond the global `/api` rate limiter. + +## Cross-references + +- Upstream: [shared/api](10-shared-api.md), [management/agentnetwork](21-management-agentnetwork.md), [management/store](20-management-store.md) +- Downstream: [proxy/runtime](33-proxy-runtime.md) +- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md) +- Top-level: [../00-overview.md](../00-overview.md) diff --git a/docs/agent-networks/modules/30-proxy-middleware-framework.md b/docs/agent-networks/modules/30-proxy-middleware-framework.md new file mode 100644 index 000000000..39322fdce --- /dev/null +++ b/docs/agent-networks/modules/30-proxy-middleware-framework.md @@ -0,0 +1,215 @@ +# proxy/middleware-framework — generic plugin system + +> **Risk level:** **High** — every proxied request transits this chain. Budget exhaustion, panic recovery, or chain-close bugs hit the hot path for all targets, not just agent-network ones. +> **Backward-compat impact:** Additive at the proxy. The `middleware` and `bodytap` packages are new (`proxy/internal/middleware/middleware.go:1`, `proxy/internal/middleware/bodytap/request.go:13`); existing proxy targets keep working until a chain is bound to them via `Manager.Rebuild`. + +This module is the **framework only** — no LLM/agent-network domain knowledge is required, since every example built into it is generic. + +## Module boundary + +This module is the **framework only**: slots, chains, registry, dispatcher, accumulator, body-tap, output filters. No middleware *implementation* lives here — those land in `proxy/internal/middleware/builtin/*` (covered in module 31). The package contract is: + +1. The proxy hands a `Manager` to its config-apply path. The synth pushes per-path `PathTargetBinding` lists (`proxy/internal/middleware/manager.go:26`) into `Manager.Rebuild`, which resolves each spec via the `Registry`/`Resolver` (`proxy/internal/middleware/registry.go:81-121`) and produces an immutable `Chain` keyed by `serviceID|pathID` (`proxy/internal/middleware/manager.go:410-412`). +2. The reverse-proxy handler captures the request body via `bodytap.CaptureRequest`, calls `Chain.RunRequest`, applies returned mutations (already filtered by `chain.applyMutations`), forwards to the upstream behind a `bodytap.CapturingResponseWriter`, then calls `Chain.RunResponse` and `Chain.RunTerminal`. +3. Middlewares are inert plugins that receive a deep-cloned `Input` and return an `Output` whose decision/mutations are clamped by the dispatcher's `filterOutput` (`proxy/internal/middleware/dispatcher.go:149-172`). + +Everything that crosses the framework boundary in either direction is value-typed and deep-copied — middlewares cannot mutate the live request directly, and the framework cannot inadvertently leak middleware-owned slices into the request hot path. + +## Files + +| Path | Role | +| ---- | ---- | +| `proxy/internal/middleware/middleware.go` | `Middleware` + `Factory` interfaces. | +| `proxy/internal/middleware/types.go` | `Slot`, `FailMode`, `Decision`, all limit constants, `Input`/`Output`/`Mutations`/`UpstreamRewrite`/`AuthHeader` value types. | +| `proxy/internal/middleware/spec.go` | Apply-time `Spec` (validated wire shape + runtime-injected fields) and `Clone`. | +| `proxy/internal/middleware/registry.go` | `Registry` (factory map, RWMutex) and `Resolver` (Spec → bound `Middleware`). | +| `proxy/internal/middleware/manager.go` | `Manager`, `chainTable` reverse index, `Rebuild`/`Invalidate*`, async chain close. | +| `proxy/internal/middleware/chain.go` | `Chain.RunRequest`/`RunResponse`/`RunTerminal`, mutation gating, `cloneInputFor`. | +| `proxy/internal/middleware/chain_test.go` | Metadata threading, LIFO response order, rewrite gating, UserGroups propagation, terminal accumulation. | +| `proxy/internal/middleware/dispatcher.go` | Timeout/panic recovery, fail-mode, error classification, `filterOutput`. | +| `proxy/internal/middleware/decision.go` | `RenderDenyResponse`, deny-code regex, status clamp. | +| `proxy/internal/middleware/headerpolicy.go` | Compile-in header denylist + `FilterHeaderMutations`. | +| `proxy/internal/middleware/bodypolicy.go` | `ValidateBodyReplace` / `ApplyBodyReplace` smuggling guards. | +| `proxy/internal/middleware/keys.go` | Metadata key namespace constants. | +| `proxy/internal/middleware/metadata.go` | `Accumulator` — allowlist, per-mw/per-request byte caps, redaction. | +| `proxy/internal/middleware/metrics.go` | OTel instrument bundle (`proxy.middleware.*`). | +| `proxy/internal/middleware/redaction.go` | `Scan` — PEM/JWT/AWS/bearer/Luhn-validated CC patterns. | +| `proxy/internal/middleware/bodytap/request.go` | Capture + replay reader, `Budget` semaphore, bypass reason codes. | +| `proxy/internal/middleware/bodytap/response.go` | `CapturingResponseWriter` (tee with `PassthroughWriter` for Flusher/Hijacker preservation). | + +## Slot model + +Three slots, declared per-middleware exactly once (`proxy/internal/middleware/types.go:27-41`): + +- **`SlotOnRequest`** (`Slot=1`) — runs **before** the upstream call, in registration order. May `DecisionDeny`, may emit `Mutations` (header add/remove, body replace, `UpstreamRewrite`) when both `Spec.CanMutate` and `Middleware.MutationsSupported()` are true. May emit metadata. Each middleware in the slot sees metadata that earlier ones in the same slot just emitted (`proxy/internal/middleware/chain.go:144-178`) — this is how the framework gives middlewares an intra-slot side channel without a global bag. +- **`SlotOnResponse`** (`Slot=2`) — runs **after** the upstream returns, in **reverse** registration order. Cannot deny (clamped in `dispatcher.filterOutput`, `proxy/internal/middleware/dispatcher.go:153-157`). May still mutate response headers in principle, but the current chain only forwards `RewriteUpstream` from on_request, so on_response mutations are observe-only in practice. Threads the same per-slot metadata view as on_request. +- **`SlotTerminal`** (`Slot=3`) — runs **after** every on_response middleware has emitted, in registration order. Sees the full accumulated bag plus prior terminal emissions (`chain.go:221-245`). Cannot deny, cannot mutate (`dispatcher.go:168-170`). Designed for sinks (access log, metrics push, audit emitter). + +Splitting a feature across slots (e.g. "parse on the way out, ship on terminal") is the explicit architectural choice — `types.go:7-15` and `types.go:22-25` make it clear no middleware participates in more than one slot. + +## Architecture & flow + +### Chain dispatch + +```mermaid +sequenceDiagram + autonumber + participant H as proxy HTTP handler + participant BT as bodytap.CaptureRequest + participant CH as Chain + participant DI as Dispatcher + participant MW as Middleware (per slot) + participant US as Upstream + participant CW as CapturingResponseWriter + + H->>BT: CaptureRequest(r, cfg, budget) + BT-->>H: body[], truncated, release() + H->>CH: RunRequest(ctx, r, Input, Accumulator) + loop on_request, registration order + CH->>CH: cloneInputFor(in, OnRequest) + CH->>DI: Invoke(ctx, spec, mw, call) + DI->>MW: mw.Invoke(callCtx, in) + MW-->>DI: Output{decision, metadata, mutations?} + DI->>DI: filterOutput (clamp deny, gate mutations) + DI-->>CH: filtered Output + CH->>CH: Accumulator.Emit (allowlist + caps + redact) + alt DecisionDeny + CH-->>H: denied, merged, rewrite + else allow + CH->>CH: applyMutations(r, m) and capture rewrite + end + end + CH-->>H: nil, merged, rewrite + H->>US: ProxyRequest (with rewrite/mutations applied) + US-->>CW: bytes (streamed, tee'd into cap-bounded buf) + CW-->>H: passthrough complete + H->>CH: RunResponse(ctx, Input{RespBody:CW.Body(),...}, acc) + loop on_response, REVERSE order (LIFO) + CH->>DI: Invoke (same wrappers) + end + H->>CH: RunTerminal(ctx, Input{Metadata:full bag}, acc) + H->>BT: release() + CW.Release() +``` + +### Body-tap mechanics (request + response) + +```mermaid +flowchart LR + subgraph req[Request capture — bodytap.CaptureRequest] + R0[r.Body] --> R1{cfg.MaxRequestBytes > 0?\nUpgrade absent?\nContent-Type allowed?\nCL <= cap?} + R1 -- no --> R2[bypass = reason\nbody = nil\nr.Body untouched] + R1 -- yes --> R3[Budget.Acquire(cap)] + R3 -- denied --> R4[bypass=BypassBudget] + R3 -- ok --> R5[io.LimitReader(r.Body, cap+1)\nio.ReadAll] + R5 --> R6{len > cap?} + R6 -- truncated --> R7[viewable = buf[:cap]\nr.Body = replayReadCloser{buf, tail}] + R6 -- whole --> R8[r.Body = NopCloser(bytes.Reader(buf))\nclose original] + R7 --> R9[(release captured\nbudget on req end)] + R8 --> R9 + end + + subgraph resp[Response capture — CapturingResponseWriter] + W0[client] -.-> CW[Write(p)] + CW --> P1[PassthroughWriter.Write(p)\n— bytes leave to client first] + P1 --> P2{!stopped?} + P2 -- yes --> P3{remaining = cap - buf.Len()} + P3 --> P4[buf.Write(p[:take])\nset truncated if take P5[silent drop into the tee\n(client write already done)] + end +``` + +The body-tap is the highest-leak-risk surface in this module; three details matter: + +1. **Request capture is "read-and-replay", not "read-and-forward".** `CaptureRequest` always swaps `r.Body` for either a `bytes.Reader` (whole body fit) or a `replayReadCloser` that replays the captured prefix then drains the remaining stream from the original body (`bodytap/request.go:178-201`). This means the **upstream still sees the full body even when the tap truncates**. The original `r.Body` is **not** closed in the truncated branch — `replayReadCloser.Close()` only closes the tail (`bodytap/request.go:199-201`), which is the same reader, so close once on request end is correct, but reviewers should confirm the upstream proxy always reads to EOF (otherwise the tail is leaked). +2. **Response capture is a write-through tee.** `CapturingResponseWriter.Write` forwards to the underlying writer **first** (`bodytap/response.go:116-117`), then tees into `buf` under its own mutex. Client never blocks on the tee. `Flusher`/`Hijacker` are preserved via the embedded `responsewriter.PassthroughWriter`. SSE/chunked streams flow through untouched; middlewares only see the bounded prefix. +3. **Budget is a single shared semaphore.** `Manager` constructs one `bodytap.Budget` at startup (`manager.go:138-144`, default `256 MiB` from `bodytap/request.go:39`). Every capture pre-acquires its full `MaxRequestBytes` / `MaxResponseBytes` from the budget regardless of actual body size; that prevents a flood of small captures from collectively exceeding the cap, but it also means a misconfigured `MaxRequestBytes = 1 MiB` with 256 concurrent requests already exhausts the default budget. Reviewers should sanity-check the operator-facing defaults that ship with synth-service. + +The framework explicitly aborts capture (and increments `proxy.middleware.capture_bypass_total`) before reading the first byte when `Upgrade`/`Connection: upgrade` is set (`bodytap/request.go:120-125`), when the content-type isn't in the allowlist (`bodytap/request.go:126-128`), or when the advertised `Content-Length` already exceeds the cap (`bodytap/request.go:131-133`). This is the right place to make sure WebSocket upgrades and large file uploads never reach the buffer. + +## Public contracts + +- **`Middleware` interface** (`middleware.go:14-36`): `ID()`, `Version()`, `Slot()`, `AcceptedContentTypes()`, `MetadataKeys()`, `MutationsSupported()`, `Invoke(ctx, *Input) (*Output, error)`, `Close()`. `MetadataKeys()` is the **closed set** the middleware is allowed to emit — the accumulator drops anything outside it (`metadata.go:71-75`). `Close` must be idempotent (called even when `Invoke` was never reached). +- **`Factory` interface** (`middleware.go:44-47`): `ID()`, `New(rawConfig []byte) (Middleware, error)`. `RawConfig` is opaque JSON bytes on the wire (`spec.go:6-12`); each factory owns its own typed config. +- **`Decision` type** (`types.go:59-69`): `Allow=0`, `Deny=1`, `Passthrough=2`. Default-zero is permissive — important because every middleware that omits `Decision` gets `Allow`. Dispatcher clamps `Deny` to `Passthrough` outside `SlotOnRequest` (`dispatcher.go:153-157`). +- **`Mutations`** (`types.go:196-201`): `HeadersAdd`/`HeadersRemove` (filtered through `headerpolicy.go`), `BodyReplace` (gated through `bodypolicy.go`), and `RewriteUpstream`. `RewriteUpstream` is **last-write-wins** within the on_request slot (`chain.go:170-172`, locked down by `TestChain_RunRequest_LatestRewriteWins`). +- **Metadata propagation keys** (`keys.go`): all keys live in a single file and follow `^[a-z][a-z0-9_-]*(\.[a-z0-9_-]*)+$` (`metadata.go:8`). Framework-injected error tagging uses `mw..error_kind` (`keys.go:81`) so operators can distinguish framework-emitted entries from middleware-emitted ones. + +## Invariants + +- **Per-request context isolation.** `cloneInputFor` deep-copies every mutable field (`Headers`, `RespHeaders`, `Metadata`, `Body`, `RespBody`, `UserGroups`, `UserGroupNames`) before each invocation (`chain.go:286-308`). A misbehaving middleware that mutates `in.Headers` only corrupts its own copy. +- **Body-tap bounded by capture limit.** Request side uses `io.LimitReader(r.Body, limit+1)` (`bodytap/request.go:152`) — the `+1` is how the code detects truncation (`bodytap/request.go:160`); the surfaced buffer is sliced back down to `limit`. Response side stops teeing once `buf.Len() >= cap` (`bodytap/response.go:121-133`). Neither side can grow the buffer past the configured cap. +- **Headers/body redaction order.** Accumulator runs `Scan(value)` **before** counting cost (`metadata.go:81-82`), so the byte budgets are computed against post-redaction sizes. `Scan` order is PEM → JWT → AWS key → bearer → Luhn-validated CC (`redaction.go:25-51`) — the comment block in `redaction.go:8-13` is explicit that this is best-effort, not DLP. +- **No middleware can starve the chain.** Every invocation runs inside `context.WithTimeout(ctx, clampTimeout(spec.Timeout))` in a separate goroutine (`dispatcher.go:51-94`), with the deadline race-`select`ed against the result channel. A blocked middleware fires the timeout path, gets fail-mode'd, and `IncError(kind=timeout)`. Timeouts are clamped to `[10ms, 5s]` (`types.go:80-86`, `dispatcher.go:174-185`). +- **Panic recovery.** `recover()` captures the panic, logs only the type + a 4 KiB stack prefix (no panic value — avoids leaking secrets the middleware was processing), and produces a `panicError` that flows through fail-mode (`dispatcher.go:64-76`). +- **Chain immutability + atomic swap.** `chainTable` is cloned on every `Rebuild`/`Invalidate*` and swapped via `atomic.Pointer` (`manager.go:44-69`, `manager.go:221-300`). Readers (`ChainFor`) are lock-free; writers serialise on `writeMu`. The retired chain is `Close`-d in a background goroutine bounded by `chainCloseTimeout = 2 * MaxTimeout` (`manager.go:21-22`, `manager.go:326-346`), so in-flight invocations finish on the old chain after the swap. + +## Things to scrutinize + +### Correctness + +- **Chain ordering deterministic from synth output?** `Manager.buildChain` iterates `b.Specs` in slice order and appends to `bound` (`manager.go:366-391`); `NewChain` then partitions by slot but **preserves slice order within each slot** (`chain.go:50-60`). So order on the wire = order observed at runtime. Synth must therefore emit specs in the intended execution order — there is no per-spec `Priority` field. Worth flagging. +- **Decision short-circuit semantics.** `RunRequest` returns immediately on `DecisionDeny` (`chain.go:164-167`) **with the metadata accumulated so far** plus the `denied.Metadata`. Callers that ignore `merged` on deny will lose framework-injected `mw..error_kind` entries. The proxy runtime is the only caller; confirm it always feeds `merged` into the access log on the deny path as well. +- **`UpstreamRewrite` `AuthHeader` bypass** (`types.go:218-235`). The `AuthHeader`/`StripHeaders` fields *intentionally* bypass the header denylist on the basis that the proxy itself rewrites auth. The denylist still blocks middleware-emitted `HeadersAdd: Authorization=...`. This is a delicate carve-out — review the runtime consumer to confirm only the trusted upstream-build path unpacks `AuthHeader`, never the generic `applyMutations` loop. +- **`replayReadCloser.Close` only closes the tail** (`bodytap/request.go:199-201`). The replay buffer doesn't own a resource, so this is correct, but it conflates "replay finished" with "underlying body closed". If a caller `Close()`s without reading to EOF, the original body is closed but the captured prefix is lost; harmless for the proxy path (upstream always reads to EOF) but worth a doc-comment. + +### Security + +- **Body-tap memory bounds.** Discussed above — bounded by `MaxBodyCapBytes = 1 MiB` per direction (`types.go:77`) and the shared `Budget` (default 256 MiB). The concerning case is the **deep-copy in `cloneInputFor`** (`chain.go:300-306`): every middleware invocation gets its **own copy** of `Body` and `RespBody`. A chain of N middlewares with a 1 MiB body allocates N MiB of transient bytes per request. With `MaxMiddlewaresPerChain = 16` (`types.go:103`) that's up to 16 MiB extra per in-flight request. Worth pricing into the budget model. +- **Header redaction completeness.** `denyHeaders` (`headerpolicy.go:5-17`) covers the auth/forwarding family and framing (`Content-Length`, `Transfer-Encoding`, `Trailer`). `denyHeaderPrefixes` covers `X-Authenticated-*`, `X-Forwarded-*`, `X-Remote-*`, `X-NetBird-*`. Notably absent: `Range`, `If-Match`/`If-None-Match` (mutation could cause cache poisoning), `Origin`/`Referer`. Not necessarily wrong, but worth a deliberate decision. +- **Metadata key collisions across middlewares.** The accumulator has no cross-middleware uniqueness check; two middlewares with the same key in their allowlist can both emit it, and both copies land in `merged` (`metadata.go:51-99`). Downstream consumers must tolerate duplicates. Worth documenting. +- **Deny rendering.** `RenderDenyResponse` only allows codes matching `^[a-z][a-z0-9._-]{0,63}$` (`decision.go:9`), redacts/truncates message + detail values, caps `Details` at 8 entries (`decision.go:42-50`), clamps status to `[400,499]\{401}` (`decision.go:65-73`). The deny body type is fixed; middlewares cannot inject arbitrary JSON. + +### Concurrency + +- **Per-request state vs shared state in factories.** Each `Factory.New` is called once per chain build; the returned `Middleware` instance is **shared across all requests** for that chain. `Invoke` must be reentrant. The framework does not enforce this — a buggy middleware that holds per-call state on the struct will silently race. Suggest a `// Invoke must be safe for concurrent use` doc on the interface. +- **`chainTable` clone-on-write** is correct, but `addChain`/`removeChain` mutate the *cloned* table before the swap (`manager.go:71-108`), and they're called under `writeMu`. Readers only ever see the post-swap pointer. Good. +- **`Chain.inflight` WaitGroup**. `Run*` does `Add(1)`/`Done()` (`chain.go:142-143`, `chain.go:194-195`, `chain.go:225-226`); `Close` waits on it bounded by ctx (`chain.go:75-85`). One concern: a *new* `RunRequest` can `Add(1)` *after* `Close` started waiting if the caller still holds a stale chain pointer. `WaitGroup` does not panic on this if the count was already > 0 at `Wait` time, but it does panic if `Add` happens after `Wait` returns and another `Wait` runs. `Close` is documented one-shot, so single-`Wait` is fine, but callers must drop the chain reference before calling `Close`. Worth a code comment near `Close`. +- **Goroutine leaks.** `Dispatcher.Invoke` spawns one goroutine per call and *always* writes to a buffered (cap=1) channel (`dispatcher.go:62-76`), so even if the timeout fires the goroutine completes its send and exits. No leak. +- **`closeChainsAsync`** detaches retired chains into a goroutine (`manager.go:326-346`). If `Manager` is never GC'd this is fine, but there's no shutdown hook to wait on outstanding closes. Reviewers should confirm the proxy shutdown path explicitly drains in-flight requests before tearing down `Manager`, or accept that the last chain-close round may be cut short on exit. + +### Performance + +- **Allocations per request.** `cloneInputFor` allocates new slices for `Headers`, `RespHeaders`, `Metadata`, `Body`, `RespBody`, `UserGroups`, `UserGroupNames` — once per middleware per request. For a typical 5-middleware chain on a 1 KiB body that's ~10 small slice allocs plus one `Body` copy each. Not a hot-path crisis, but `sync.Pool` for the per-call `Input` would be a natural follow-up. +- **Accumulator allocates a fresh `allowSet` per `Emit` call** (`metadata.go:55-58`). One per middleware per slot pass = up to 48 per request. Cheap, but worth noting. +- **Regex cost.** `Scan` runs five regex passes on every accepted metadata value (`redaction.go:25-51`). Bounded by `MaxMetadataValueBytes = 4 KiB` so worst case is small. + +### Observability + +- **Per-middleware metrics.** `proxy.middleware.requests_total{middleware,target_id,outcome}` (`metrics.go:34-41`), `duration_ms`, `invocations_total`, `errors_total{kind}`, `metadata_rejected_total{reason}`, `header_mutation_blocked_total{header}`, `capture_bypass_total{reason}`. Comprehensive surface; operators can alert on `errors_total{kind=panic}` and `errors_total{kind=timeout}` separately. **Latency histogram is in milliseconds with default OTel buckets** — for a 10ms–5s timeout range default buckets cover OK, but a custom bucket set centred on 1–500ms would resolve the agent-network response-parser tail better. +- **Decision logs.** Panic logs (`dispatcher.go:69`) include `request_id`, type, and stack but not the panic value (safe). `Chain.Close` logs middleware-close errors at debug (`chain.go:91`). `applyMutations` logs body-replace rejections at warn (`chain.go:278`). No log on the deny path itself — by design, since the access-log terminal middleware is expected to record outcomes. + +## Test coverage + +| Test file | Locks down | +| --------- | ---------- | +| `proxy/internal/middleware/chain_test.go:77` | `RunRequest` threads metadata across on_request middlewares (regression for the "later mw can't see earlier mw's emissions" bug). | +| `chain_test.go:110` | `RunResponse` reverse-order threading. | +| `chain_test.go:142` | `cost_meter`-shaped scenario: response_parser registered after cost_meter still emits *before* cost_meter sees the bag (guards the `cost.skipped=missing_tokens` regression). | +| `chain_test.go:178` | `UpstreamRewrite` last-write-wins. | +| `chain_test.go:206` | No middleware emits → nil rewrite. | +| `chain_test.go:224` | Rewrite filtered when `CanMutate=false`. | +| `chain_test.go:245` | `Input.UserGroups` propagates verbatim through `cloneInputFor`. | +| `chain_test.go:304` | Terminal middlewares see the full accumulated bag + prior terminal emissions. | + +**Gaps** worth raising with the author: +- No direct test for `Dispatcher.Invoke` timeout / panic / fail-mode behaviour at the framework level (covered indirectly by built-in tests, but a unit test pinning `errors_total{kind=...}` labels would be cheap insurance). +- No test for `bodytap.CaptureRequest` truncated replay (the upstream-sees-full-body invariant is exactly the kind of thing a regression would silently break). +- No test for `Budget` exhaustion behaviour under concurrency. +- No test for `Manager.InvalidateMiddleware` + `LiveServiceCheck` race (the auth-revocation race the comment at `manager.go:33-38` calls out is the load-bearing reason for `LiveServiceCheck`). + +## Known limitations / explicit non-goals + +- **No middleware-to-middleware RPC.** Side-channel is metadata only. +- **No streaming body inspection.** Middlewares see a bounded prefix; SSE / chunked parsing happens against that prefix in the response middleware. +- **No per-spec priority.** Order is registration order in the spec slice. +- **No retry / circuit-breaker** on middleware errors. Fail-mode is binary (open/closed) and per-spec. +- **Mutations cannot rewrite the request URL path or query** — only `RewriteUpstream` can change scheme/host (+ optional path replacement, see `types.go:218-235`). +- **Redaction is best-effort.** Explicitly documented in `redaction.go:8-13`. Not a DLP solution. + +## Cross-references + +- Upstream wire shape: [../modules/10-shared-api.md](10-shared-api.md) (Spec/RawConfig encoding from management). +- Built-in middlewares using this framework: [../modules/31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md). +- Runtime wiring (where `Manager`, `Chain`, and `bodytap` are consumed by the HTTP handler): [../modules/33-proxy-runtime.md](33-proxy-runtime.md). +- End-to-end request flow including capture + chain dispatch: [../01-end-to-end-flows.md](../01-end-to-end-flows.md). +- Top-level architecture: [../00-overview.md](../00-overview.md). diff --git a/docs/agent-networks/modules/31-proxy-middleware-builtin.md b/docs/agent-networks/modules/31-proxy-middleware-builtin.md new file mode 100644 index 000000000..904de6424 --- /dev/null +++ b/docs/agent-networks/modules/31-proxy-middleware-builtin.md @@ -0,0 +1,365 @@ +# proxy/middleware-builtin — the LLM chain + +The registry-mounted middleware set the proxy executes on every agent-network +LLM request. The two highest-blast-radius areas are the **capture-pointer +semantics** and the **limit_check ⇒ limit_record** record-once invariant. + +Sibling module: [32-proxy-llm-parsers.md](./32-proxy-llm-parsers.md) — the SDK +adapters + pricing catalog this chain delegates to. + +--- + +## Module boundary + +This module is the registry-mounted middleware set the proxy executes on +every agent-network LLM request. Each sub-package registers itself via +`init()` +([builtin.go:32–34](../../../proxy/internal/middleware/builtin/builtin.go)); +the proxy server anonymous-imports the set +([all_test.go:11–19](../../../proxy/internal/middleware/builtin/all_test.go)) +so the registry is populated at boot. The chain is wired by the management +synthesiser and executed by the framework +(`proxy/internal/middleware/{chain,dispatcher,accumulator}.go` — both out +of scope). Everything here reads from / writes to one envelope: the +`middleware.KV` metadata bag plus `middleware.Mutations` for header/body +rewrites. + +## The 8 middlewares + +| Name | Slot | Inputs (metadata read) | Outputs (metadata written) | Side effects | +|---|---|---|---|---| +| `llm_request_parser` | OnRequest | `Input.{URL,Body,BodyTruncated}` | `llm.{provider,model,stream,request_prompt_raw,capture_truncated}` | none | +| `llm_router` | OnRequest | `llm.model`, `Input.{URL,UserGroups}` | `llm.{resolved_provider_id,authorising_groups}`, `llm_policy.{decision,reason}` | upstream rewrite + auth strip/inject | +| `llm_limit_check` | OnRequest | `llm.{resolved_provider_id,model}`, `Input.{AccountID,UserID,UserGroups}` | `llm.{selected_policy_id,attribution_group_id,attribution_window_seconds}`, `llm_policy.{decision,reason}` | gRPC `CheckLLMPolicyLimits` | +| `llm_identity_inject` | OnRequest | `llm.{resolved_provider_id,authorising_groups}`, `Input.{UserEmail,UserID,UserGroups,UserGroupNames}` | none | header strip/inject + optional body rewrite | +| `llm_guardrail` | OnRequest | `llm.{model,request_prompt_raw}` | `llm_policy.{decision,reason}`, `llm.request_prompt` | none (model allowlist deny) | +| `llm_response_parser` | OnResponse | `llm.provider`, `Input.{RespHeaders,RespBody,Status}` | `llm.{input,output,total,cached_input,cache_creation}_tokens`, `llm.response_completion` | none | +| `cost_meter` | OnResponse | `llm.{provider,model}`, token buckets | `cost.usd_total` or `cost.skipped` | pricing lookup | +| `llm_limit_record` | OnResponse | `llm.{attribution_group_id,attribution_window_seconds,input_tokens,output_tokens}`, `cost.usd_total` | none | gRPC `RecordLLMUsage` | + +[all_test.go:26–40](../../../proxy/internal/middleware/builtin/all_test.go) +locks the ID set; adding or removing one is a conscious extension. + +## Files + +| File | LOC | Notes | +|---|---:|---| +| `builtin.go` | 86 | Registry + `FactoryContext` (ctx, data dir, meter, logger, mgmt client) | +| `all_test.go` | 41 | Locks the 8-ID registry surface | +| `agentnetwork_chain_integration_test.go` | 319 | Live sqlite + real gRPC bufconn; gate→recorder wire path | +| `llm_request_parser/*` | 162 / 66 / 356 | Provider detection, body parse, prompt extraction with capture-pointer gating | +| `llm_router/*` | 385 / 84 / 586 | Three-pass route selection (model → groups → path-prefix) | +| `llm_limit_check/*` | 196 / 38 / 182 | Pre-flight `CheckLLMPolicyLimits` (2s, fail-open) | +| `llm_identity_inject/*` | 440 / 108 / 666 | HeaderPair (LiteLLM) + JSONMetadata (Portkey) + ExtraHeaders | +| `llm_guardrail/*` | 176 / 82 / 75 / 219 / 217 | Model allowlist + optional prompt capture with PII redaction | +| `llm_response_parser/*` | 258 / 222 / 43 / 433 / 169 / 111 | Buffered + SSE accumulation; AWS event-stream accumulator (`streaming_bedrock.go`) for Bedrock; capture-pointer gates completion emit | +| `cost_meter/*` | 181 / 84 / 439 | Token → USD via `proxy/internal/llm/pricing` | +| `llm_limit_record/*` | 144 / 35 / 191 | Post-flight `RecordLLMUsage` (5s, debug-on-error) | + +## Per-middleware + +### llm_request_parser + +Detects the LLM provider via `llm.DetectParser` (URL sniff) or by name via +`llm.ParserByName` when synthesiser stamps `provider_id` +([middleware.go:96–99](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)). +**Path-routed providers short-circuit first:** `parseVertexPath` and +`parseBedrockPath` ([middleware.go:85–94](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)) +pull the model + vendor out of the URL before parser selection runs — Vertex +from `/v1/projects/.../publishers/{pub}/models/{model}:{action}` (publisher → +vendor via `vertexPublisherVendor`), Bedrock from `/model/{id}/{action}` with +`normalizeBedrockModel` stripping the region prefix + version suffix. See +[50-path-routed-providers.md](./50-path-routed-providers.md) for the full path +grammar. For body-routed providers it decodes the body into `RequestFacts` +(model + stream) and extracts the prompt. On +`capture_prompt=true` (or absent — see capture-pointer semantics below) the +prompt is run through `llm_guardrail.RedactPII` when `redact_pii=true` and +truncated rune-safely to 3500 bytes +([middleware.go:109–122](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)). +**Key invariant:** redaction is parser-side, not guardrail-side — access-log +reads `llm.request_prompt_raw` directly. + +### llm_router + +Three-pass route selection in `matchRoute` +([middleware.go:241–300](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)): +filter by `Models` claim → vendor-pin (a vendor-tagged request never crosses to +another vendor's route) → filter by `AllowedGroupIDs` intersection → model +precedence over path → tie-break by longest `UpstreamPath` prefix match. +Model-miss returns `llm_policy.model_not_routable`; known-but-unauthorised +returns `llm_policy.no_authorised_provider`. **Key invariant:** auth-header +strip+inject rides on `UpstreamRewrite.{StripHeaders,AuthHeader}` +([middleware.go:606–646](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)) +— NOT `HeadersAdd/HeadersRemove` — because the framework's mutation gate +blocks `Authorization` on the generic header path. + +**Path-routed providers route before the model table.** `Invoke` checks +`isVertexPath` / `isBedrockPath` +([middleware.go:138–216](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)) +ahead of the model lookup, so a path-carried model can't be claimed by a +same-vendor body-routed provider. `matchPathRoute` enforces the route's `Models` +allowlist (empty = catch-all) even though the model came from the URL. +Two path-only behaviours: +- **Vertex unmeterable publisher** — when `llm_request_parser` emits no + `llm.provider` (e.g. Gemini/`google`), the router denies with + `llm_policy.unmeterable_publisher` (403) rather than forward it uncounted. +- **GCP token minting** — when the route carries `GCPServiceAccountKeyB64` + (set from a `keyfile::` api_key), `gcpBearer` mints + caches a short-lived + OAuth2 token per request instead of injecting a static value; a bad key or + unreachable token endpoint denies with `llm_policy.upstream_auth_failed` + (502). Bedrock uses its static bearer token directly (no minting). +- **`/bedrock` prefix** — an optional `/bedrock` gateway-namespace prefix is + accepted and stripped via `RewriteUpstream.StripPathPrefix` so the native + `/model/...` path reaches the upstream. + +Full treatment in [50-path-routed-providers.md](./50-path-routed-providers.md). + +### llm_limit_check + +Pre-flight gate. Reads `llm.resolved_provider_id`, calls +`CheckLLMPolicyLimits` with a 2s context timeout +([middleware.go:24, 97–106](../../../proxy/internal/middleware/builtin/llm_limit_check/middleware.go)), +on allow stamps `llm.selected_policy_id`, `llm.attribution_group_id`, +`llm.attribution_window_seconds`. **Key invariant:** fail-open. Nil +`MgmtClient`, empty provider id, or RPC error returns `allowNoAttribution()` +— management outage doesn't take down every LLM request. Operators audit via +the access-log; a future flag may switch this to fail-closed. + +### llm_identity_inject + +Dispatches per-rule between LiteLLM-shaped `HeaderPair` +([middleware.go:169](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go)) +and Portkey-shaped `JSONMetadata` +([middleware.go:292](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go)). +Identity is the peer's email (or `UserID` fallback); tags are the +**authorising-groups intersection** emitted by `llm_router`, not the full +`UserGroups` — a peer in 5 groups authorised under 1 only tags as that 1. +**Anti-spoof:** every `HeadersAdd` is preceded by a `HeadersRemove` of the +same name; the framework runs `Remove` before `Add` so client-supplied +identity never reaches the upstream. Body-level inject (`tags_in_body`, +`end_user_id_in_body`) is skipped on empty / truncated / non-JSON bodies so +header attribution stays intact. + +### llm_guardrail + +Model allowlist deny + optional prompt-capture-with-redaction. Allowlist +match is case-insensitive via `normaliseModel`; empty allowlist disables the +check. Prompt capture reads `llm.request_prompt_raw` and emits +`llm.request_prompt` only when `prompt_capture.enabled` +([middleware.go:149–165](../../../proxy/internal/middleware/builtin/llm_guardrail/middleware.go)). +**Key invariant:** `RedactPII` is the exported function the parsers call — +single PII contract across all three keys. + +### llm_response_parser + +Buffered and SSE paths share one `Invoke` +([middleware.go:102–127](../../../proxy/internal/middleware/builtin/llm_response_parser/middleware.go)): +content-type sniffing dispatches to `invokeBuffered` (JSON, status<400) or +`invokeStreaming` (text/event-stream, partial bodies tolerated). Streaming +delegates to `accumulateStream` +([streaming.go:21–30](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go)) +using `llm.NewScanner`. A third path, `accumulateBedrockStream` +([streaming_bedrock.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go)), +decodes the AWS binary event-stream (`application/vnd.amazon.eventstream`) +returned by Bedrock's `-stream` actions — InvokeModel `chunk` frames wrap a +base64 Anthropic event, Converse frames carry text + a trailing usage block. +Cached / cache-creation buckets emit only when non-zero, preserving the existing +token schema. + +### cost_meter + +Reads `llm.provider` + `llm.model` + token buckets, looks up per-1k rate via +`pricing.Loader`, emits `cost.usd_total` or a closed-set `cost.skipped` +reason (`missing_provider/model/tokens`, `unparseable_tokens`, `zero_tokens`, +`unknown_model`). Loader's hot-reload goroutine is bound to proxy-lifetime +context via `startReloader`. **Key invariant:** provider-shape switch lives +in `pricing.Table.Cost` (sibling doc) — `cost_meter` stays provider-agnostic. + +### llm_limit_record + +Post-flight write. Always returns `DecisionAllow`; response has already been +served so RPC errors mustn't surface (logged at `Debugf`). Skip-on-no-signal +at line 81 (zero tokens + zero cost). **Key invariant:** the +skip-on-missing-attribution guard at line 98 is a safety net independent of +the framework's deny short-circuit — if the gate denied and the framework +still runs the recorder, the recorder skips on absent +`UserID`+`groupID`+`UserGroups` and no phantom counter materialises. + +## Full-chain diagram (canonical order) + +```mermaid +flowchart TD + A[HTTP request] --> B[llm_request_parser
OnRequest] + B -->|llm.provider, llm.model,
llm.stream, llm.request_prompt_raw| C[llm_router
OnRequest] + C -->|llm.resolved_provider_id,
llm.authorising_groups,
upstream rewrite + auth| D[llm_limit_check
OnRequest] + D -->|deny path| Z1[403 llm_policy.*] + D -->|allow + llm.selected_policy_id,
llm.attribution_group_id,
llm.attribution_window_seconds| E[llm_identity_inject
OnRequest] + E -->|header strip+inject
+ optional body rewrite| F[llm_guardrail
OnRequest] + F -->|deny: model_blocked| Z2[403 llm_policy.model_blocked] + F -->|allow + llm.request_prompt| G[upstream LLM call] + G --> H[llm_response_parser
OnResponse] + H -->|llm.{input,output,total,cached_input,cache_creation}_tokens,
llm.response_completion| I[cost_meter
OnResponse] + I -->|cost.usd_total or cost.skipped| J[llm_limit_record
OnResponse] + J --> K[response to client] +``` + +## limit_check ⇒ limit_record record-once invariant + +```mermaid +sequenceDiagram + participant LC as llm_limit_check + participant M as management gRPC + participant U as upstream LLM + participant LR as llm_limit_record + participant DB as sqlite consumption table + + LC->>M: CheckLLMPolicyLimits (2s) + alt allow + M-->>LC: selected_policy_id, attribution_group_id, window_s + LC->>U: stamps attribution metadata + U-->>LR: response + tokens (via llm_response_parser + cost_meter) + LR->>M: RecordLLMUsage (5s, debug-on-error) + M->>DB: increment (user, group, window) row + else deny + M-->>LC: llm_policy.token_cap_exceeded + Note over LR: framework short-circuits; even if invoked,
recorder skips on absent UserID+groupID+UserGroups + else mgmt nil / rpc error + LC-->>LC: allowNoAttribution() — fail open + Note over LR: no window_s ⇒ recorder books only account-level
budget rules (which run independently) + end +``` + +The integration test +[agentnetwork_chain_integration_test.go](../../../proxy/internal/middleware/builtin/agentnetwork_chain_integration_test.go) +exercises all three branches against a real sqlite store + bufconn gRPC — +no mocks. Tests: `TestChain_AllowPath_StampsAttributionAndRecordsCounter` +(line 130), `TestChain_DenyPath_GateRejectsAndNoConsumptionWritten` (line +207), `TestChain_CapExhaustTransition` (line 265). + +## Public contracts (per-middleware JSON config) + +| Middleware | Config shape | +|---|---| +| `llm_request_parser` | `{provider_id?, redact_pii?, capture_prompt?: *bool}` ([factory.go:19–37](../../../proxy/internal/middleware/builtin/llm_request_parser/factory.go)) | +| `llm_router` | `{providers: [{id, models, upstream_scheme, upstream_host, upstream_path?, auth_header_name, auth_header_value, allowed_group_ids}]}` | +| `llm_limit_check` | `{}` — pulls `MgmtClient` from `FactoryContext` | +| `llm_identity_inject` | `{providers: [{provider_id, header_pair?|json_metadata?, extra_headers?}]}` | +| `llm_guardrail` | `{model_allowlist: []string, prompt_capture: {enabled, redact_pii}}` | +| `llm_response_parser` | `{redact_pii?, capture_completion?: *bool}` | +| `cost_meter` | `{pricing_path?}` (basename inside data-dir; defaults `pricing.yaml`) | +| `llm_limit_record` | `{}` — same pattern as `llm_limit_check` | + +All factories accept empty / null / `{}` / whitespace as zero-value config; +only structurally invalid JSON is rejected so misconfig surfaces at chain +build time. + +## Invariants + +1. **limit_check ↔ limit_record paired.** They MUST appear together. Gate + stamps attribution metadata on the request leg; recorder reads it on the + response leg. If a chain contains only the recorder, the + skip-on-missing-attribution guard at + [llm_limit_record/middleware.go:81–87, 98–103](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go) + keeps counters consistent but no enforcement runs. Only-gate means + counters never tick and headroom appears infinite. + +2. **`capture_prompt` / `capture_completion` pointer semantics.** Both are + `*bool`. `nil` = "preserve legacy emit" (back-compat default for + non-agent-network callers and pre-toggle tests). `false` = suppress the + key entirely (access-log row carries zero prompt / completion content). + `true` = emit. The synthesiser sets the pointer explicitly to the + account's `EnablePromptCollection` toggle. The handling lives + in [llm_request_parser/factory.go:55–61](../../../proxy/internal/middleware/builtin/llm_request_parser/factory.go) + and the symmetric [llm_response_parser/middleware.go:62–68](../../../proxy/internal/middleware/builtin/llm_response_parser/middleware.go); + a missing pointer must not be treated as `false` (that would suppress + capture for legacy non-agent-network callers). + `redact_pii` is an orthogonal `bool` controlling **form** of emitted + content, not whether it's emitted. + +3. **`redact_pii` is parser-side.** Both parsers import + `llm_guardrail.RedactPII` and run it BEFORE stamping the metadata bag. + Load-bearing because the access-log sink reads `llm.request_prompt_raw` + and `llm.response_completion` directly — by the time `llm_guardrail` + runs its own pass on `llm.request_prompt`, the raw key has already been + stamped. Tests: `TestInvoke_RedactPii_RedactsBeforeEmittingRawPrompt`, + `TestInvoke_RedactPii_RedactsCompletionBeforeEmit`. + +4. **Metadata allowlist enforcement.** Every middleware declares + `MetadataKeys()`. The framework accumulator drops any KV outside that + allowlist. When adding a new key, also extend the docstring in + `middleware/keys.go`. + +5. **Closed deny-code set.** All deny paths emit one of: + `llm_policy.model_not_routable`, `llm_policy.no_authorised_provider`, + `llm_policy.model_blocked`, `llm_policy.token_cap_exceeded`, + `llm_policy.unmeterable_publisher` (path-routed Vertex publisher with no + parser → 403), `llm_policy.upstream_auth_failed` (GCP token mint failure → + 502), or the management-supplied code on `llm_limit_check`. These surface + verbatim; arbitrary middleware text never reaches the wire. + +## Things to scrutinise + +**Correctness.** `llm_router` model match treats an empty `Models` slice as +"claim every model" +([middleware.go:238–248](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)) +for gateway-style providers — confirm no real provider record ships with an +empty `Models` by accident. Path-prefix tie-break falls back to declaration +order when no candidate prefix-matches, so the synthesiser must emit a +deterministic order. `llm_limit_record` discards `strconv.ParseInt` errors +([middleware.go:78–80](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go)) +— relies on `llm_response_parser` always emitting parseable values; spot-check +the streaming partial path on truncated bodies. + +**Security.** Auth headers must NEVER appear on `Mutations.HeadersAdd/Remove` +for the router — a direct headers path would bypass the framework gate. The +capture-pointer handling is the kind of place a bug ships PII to logs +silently; every synthesiser config path must set the pointer explicitly. +`llm_identity_inject` body inject silently skips on a +non-object `metadata` field +([middleware.go:262–270](../../../proxy/internal/middleware/builtin/llm_identity_inject/middleware.go)) +— header path still attributes, but body-level tag-budget enforcement +doesn't run for that request. + +**Concurrency.** `cost_meter` shares a `pricing.Loader` via +`atomic.Pointer[Table]`; readers always see a consistent table. Every +middleware is a stateless value receiver. Integration test uses real bufconn +gRPC — race detector is the meaningful bar. + +**Perf.** Hot path is `lookupKV` linear scan over <10 KVs; `cost_meter.Cost` +is O(1); SSE accumulation is single-pass. No map allocation per call. + +**Observability.** Every deny stamps `llm_policy.decision=deny` and a +matching `llm_policy.reason` — access-log can pivot on either. +`llm_limit_record` only logs at `Debugf` on RPC failure +([middleware.go:125–130](../../../proxy/internal/middleware/builtin/llm_limit_record/middleware.go)); +operators need an alternate signal (metric on `RecordLLMUsage` failures) for +counter accuracy. + +## Test coverage + +| File | Tests | Notes | +|---|---:|---| +| `all_test.go` | 1 | Registry surface lock | +| `agentnetwork_chain_integration_test.go` | 3 | Allow/deny/cap-exhaust vs live sqlite + bufconn gRPC | +| `llm_request_parser/middleware_test.go` | 18 | `provider_id` bypass, redaction, capture-pointer, rune-safe truncation | +| `llm_router/middleware_test.go` | 19 | Three-pass match, deny codes, path-prefix tie-break, header strip+inject | +| `llm_limit_check/middleware_test.go` | 6 | Allow/deny, fail-open on nil mgmt / RPC error, attribution stamping | +| `llm_identity_inject/middleware_test.go` | 28 | HeaderPair, JSONMetadata, ExtraHeaders, body inject, anti-spoof | +| `llm_guardrail/middleware_test.go` | 15 | Allowlist case-insensitivity, prompt capture toggle, deny shape | +| `llm_guardrail/redact_test.go` | 15 | Email, SSN, phone (E.164 + NA), bearer, IPv4; fixture-driven | +| `llm_response_parser/middleware_test.go` | 18 | Buffered OAI+Anthro, capture-pointer, redact, truncation | +| `llm_response_parser/streaming_test.go` | 7 | OAI usage frame, Anthro message_delta, truncated body best-effort | +| `cost_meter/middleware_test.go` | 17 | Each skip reason, provider-shape, pricing loader integration | +| `llm_limit_record/middleware_test.go` | 7 | Skip-on-no-signal, skip-on-missing-attribution, RPC failure swallowed | + +## Cross-references + +- Sibling: [32-proxy-llm-parsers.md](./32-proxy-llm-parsers.md) — SDK adapters + + SSE framer + pricing loader. +- Path-routed providers (Vertex AI + Bedrock), `keyfile::` credential, GCP + token minting, `/bedrock` prefix: + [50-path-routed-providers.md](./50-path-routed-providers.md). +- Upstream config: `management/server/agentnetwork/synthesizer` (out of scope). +- Framework: `proxy/internal/middleware/{chain,dispatcher,accumulator,registry}.go`. +- Metadata key registry: `proxy/internal/middleware/keys.go`. +- gRPC surface: `proto.ProxyServiceClient.{CheckLLMPolicyLimits,RecordLLMUsage}`. diff --git a/docs/agent-networks/modules/32-proxy-llm-parsers.md b/docs/agent-networks/modules/32-proxy-llm-parsers.md new file mode 100644 index 000000000..0376bc988 --- /dev/null +++ b/docs/agent-networks/modules/32-proxy-llm-parsers.md @@ -0,0 +1,392 @@ +# proxy/llm-parsers — SDK adapters + pricing + SSE + +The runtime-agnostic LLM library: the OpenAI Responses API (`/v1/responses`) +and the older Chat Completions API (`/v1/chat/completions`), the Anthropic +Messages API (`/v1/messages`), the SSE wire format (`event:` / `data:` lines, +`\n\n` framing, CRLF tolerance), and per-provider token accounting (OpenAI's +cached-prompt **subset** vs Anthropic's cache_read **additive** model). The +pricing table's per-provider cost formula is the highest-leverage place a +small bug would silently mis-bill operators. + +Sibling module: [31-proxy-middleware-builtin.md](./31-proxy-middleware-builtin.md) +— the 8 middlewares that consume this package's parsers + pricing loader. + +--- + +## Module boundary + +`proxy/internal/llm` is the runtime-agnostic LLM library shared by every +middleware that needs to understand provider-specific shapes. Zero +proxy-framework dependencies: + +- `parser.go` — `Parser` interface, `Provider` enum, public factories + (`Parsers`, `DetectParser`, `ParserByName`). +- `openai.go` / `anthropic.go` / `bedrock.go` — per-provider `Parser` impls. +- `sse.go` — SSE scanner (`Scanner`, `Event`, `NewScanner`). +- `errors.go` — sentinels callers branch on with `errors.Is`. +- `pricing/` — embedded-default + hot-reload override table with + symlink-safe Unix loader (build-tagged stub elsewhere). +- `fixtures/` — captured request/response/stream bodies the tests replay. + +The package carries zero proxy-framework dependencies so the same parsers can +be reused later by a WASM adapter +([parser.go:1–6](../../../proxy/internal/llm/parser.go)). + +## Files + +| File | LOC | Notes | +|---|---:|---| +| `parser.go` | 104 | Interface + factories + `Provider{Unknown,OpenAI,Anthropic}` enum | +| `openai.go` | 347 | Chat Completions + Completions + Responses API; cached_tokens subset | +| `openai_test.go` | 222 | 11 tests; fixture replay + cached/Responses-API matrix | +| `anthropic.go` | 172 | Messages + legacy `/v1/complete`; cache_read + cache_creation additive | +| `anthropic_test.go` | 154 | 7 tests including streaming-extraction-skipped contract | +| `bedrock.go` | 190 | AWS Bedrock InvokeModel (snake_case) + Converse (camelCase) response shapes; model lives in URL path | +| `bedrock_test.go` | — | InvokeModel + Converse usage shapes; AWS event-stream content-type → `ErrStreamingUnsupported` on buffered `ParseResponse` | +| `sse.go` | 117 | `bufio`-backed scanner; CRLF normalised; trailing-event handling | +| `sse_test.go` | 175 | 12 tests; fixture replay + multiline + size limits | +| `parser_test.go` | 53 | `Parsers()`, `DetectParser`, provider enum values | +| `errors.go` | 31 | 6 sentinels: `Err{Unknown,Unsupported}Provider/Model`, `Err{NotLLM,Malformed}Response`, `ErrStreamingUnsupported`, `ErrMalformedRequest` | +| `pricing/pricing.go` | 421 | `Loader`, `Table`, `Entry`; embedded defaults + atomic swap + mtime reload | +| `pricing/pricing_unix.go` | 69 | `O_NOFOLLOW` + fstat-from-FD + 1 MiB cap | +| `pricing/pricing_other.go` | 21 | Stub returning "not supported on this platform" | +| `pricing/pricing_test.go` | 432 | 21 tests — symlink rejection, reload race, path traversal, oversize | +| `pricing/defaults_pricing.yaml` | 85 | go:embed source of truth | +| `fixtures/*` | 21–59 | OAI chat/responses/stream + Anthro messages/stream + pricing starter | + +## Request body → parser dispatch + +```mermaid +flowchart TD + A[HTTP request
URL + JSON body] --> B{ParserByName?
provider_id config set} + B -- yes --> P[matched Parser] + B -- no --> C[DetectParser] + C --> D{loop Parsers
OpenAIParser, AnthropicParser} + D -- DetectFromURL match --> P + D -- no match --> X[ok=false
middleware skips] + P --> E[ParseRequest body] + E -->|err: ErrMalformedRequest| Y[middleware emits provider only] + E --> F[RequestFacts
model + stream] + P --> G[ExtractPrompt body] + G --> H[joinMessages
extractContentParts
decodeStringOrJoin] + H --> I[prompt text
or empty] + F --> J[stamps llm.model + llm.stream] + I --> K[stamps llm.request_prompt_raw
subject to capture_prompt gate] +``` + +OpenAI's URL hints +([openai.go:27–33](../../../proxy/internal/llm/openai.go)) include +both `/v1/chat/completions` and the bare `/chat/completions` — the latter +covers Cloudflare AI Gateway, which rewrites the canonical version segment. +Anthropic's hints are `/v1/messages` and `/v1/complete` +([anthropic.go:14–17](../../../proxy/internal/llm/anthropic.go)). +Both implementations use case-insensitive substring matching so a proxy prefix +strip / rewrite doesn't defeat detection. + +`ParserByName` ([parser.go:93–103](../../../proxy/internal/llm/parser.go)) +is the **agent-network bypass**: the synthesiser knows which parser to use +because it built the synth service from the catalog, so it stamps +`provider_id` on the parser config and the middleware skips URL sniffing +entirely. This is what makes the same parser set work whether the request +flows to OpenAI direct, to LiteLLM, to Portkey, or to any gateway with a +non-canonical URL shape. + +**Path-routed providers (Vertex AI, Bedrock) bypass both `ParserByName` and +`DetectParser`.** The model and the parser surface live in the URL path, so the +request middleware extracts them directly (`parseVertexPath` / +`parseBedrockPath`) before the parser-selection step. For Vertex the publisher +segment picks the parser (`anthropic` → Anthropic parser; `google`/Gemini → +none, request denied as unmeterable). For Bedrock the dedicated `BedrockParser` +handles the response. Full treatment in +[50-path-routed-providers.md](./50-path-routed-providers.md). + +## Streaming response → SSE chunker → response parser → completion + token count + +```mermaid +sequenceDiagram + participant U as upstream LLM + participant LR as llm_response_parser
(OnResponse) + participant S as llm.NewScanner
(SSE framer) + participant P as Parser-specific accumulator
(accumulateOpenAIStream
or accumulateAnthropicStream) + + U-->>LR: text/event-stream
(buffered prefix in RespBody) + LR->>S: NewScanner(bytes.NewReader(body)) + loop until EOF or [DONE] + S-->>LR: Event{Type, Data} + LR->>P: dispatch per event.Type
(OpenAI: data-only
Anthropic: named events) + P-->>P: accumulate completion text
track usage from final frame + end + P-->>LR: llm.Usage + completion string + LR->>LR: appendUsage stamps
llm.{input,output,total,cached_input,cache_creation}_tokens + LR->>LR: truncateCompletion(3500 bytes, rune-safe) + LR->>LR: redactPII if redact_pii && captureCompletion +``` + +`Scanner.Next` +([sse.go:44–87](../../../proxy/internal/llm/sse.go)) returns one +event per `\n\n` boundary; multiple `data:` lines join with `\n`; comment lines +(starting with `:`) are skipped per the SSE spec; a trailing event without a +closing blank line is still returned before `io.EOF` so a server that closes +the connection cleanly doesn't lose the last frame +([sse.go:55–58](../../../proxy/internal/llm/sse.go)). CRLF is +normalised in `trimEOL` so fixtures captured from live servers replay +unchanged. + +## Per-provider + +### OpenAI + +[openai.go:54–67](../../../proxy/internal/llm/openai.go) defines +`openAIRequest` with three prompt fields: `messages` (Chat Completions), +`prompt` (legacy), `input` (Responses API). The decoder uses +`json.RawMessage` so each shape is parsed lazily. + +`ParseResponse` +([openai.go:117–146](../../../proxy/internal/llm/openai.go)) +accepts both naming conventions: Chat Completions returns +`prompt_tokens`/`completion_tokens`, Responses API returns +`input_tokens`/`output_tokens`. `pickInt64` prefers Responses-API names and +falls back — same parser handles both endpoints without per-route config. +`openAICachedTokens` mirrors the fallback for +`input_tokens_details.cached_tokens` vs `prompt_tokens_details.cached_tokens`. + +**Key invariant:** `CachedInputTokens` for OpenAI is a SUBSET of +`InputTokens`. The cost meter clamps to guard against malformed upstream +responses where `cached > total`. + +### Anthropic + +[anthropic.go:37–49](../../../proxy/internal/llm/anthropic.go) +defines `anthropicRequest` covering Messages API (`system` + `messages[]`) +and legacy `/v1/complete` (`prompt` string). `ExtractPrompt` emits +`system: ` first when present, then per-message `role: content`. + +`ParseResponse` +([anthropic.go:82–104](../../../proxy/internal/llm/anthropic.go)) +fills three independent token buckets: `InputTokens`, `CacheReadInputTokens`, +`CacheCreationInputTokens`. Latter two are **additive** (not subset). +`TotalTokens` sums all four so downstream dashboards render one "tokens" +number without double-counting. + +`ExtractCompletion` walks `content[]` `{type, text}` parts and concatenates +non-empty text with newlines, falling back to legacy `completion`. + +### Bedrock + +[bedrock.go](../../../proxy/internal/llm/bedrock.go) implements the +`Parser` interface for the AWS Bedrock runtime. Bedrock is **path-routed**: the +model lives in the URL (`/model/{id}/{action}`), so the request middleware +extracts it (see [50-path-routed-providers.md](./50-path-routed-providers.md)) +and `ParseRequest` is a deliberate no-op. The parser's real work is on the +response leg, covering both Bedrock body shapes: + +- **InvokeModel** — vendor-native. Anthropic-on-Bedrock returns snake_case usage + (`input_tokens`, `output_tokens`, `cache_read_input_tokens`, + `cache_creation_input_tokens`) with the same additive cache buckets as + first-party Anthropic. +- **Converse** — unified camelCase (`inputTokens`, `outputTokens`, + `totalTokens`). `firstNonZero` folds the two naming conventions into one + `Usage`; when Converse omits `totalTokens` the parser sums the buckets. + +`ProviderName()` returns `"bedrock"` — its own `defaults_pricing.yaml` block, +keyed by the **normalised** model id (region prefix + version suffix stripped by +the request parser). `ParseResponse` returns `ErrStreamingUnsupported` for an +AWS binary event-stream content-type (`application/vnd.amazon.eventstream`, +`isAWSEventStream`) so the caller routes to the streaming accumulator instead. + +### SSE framing + +`Scanner` is `bufio`-backed, 64 KiB read buffer, 1 MiB max line so a +malicious upstream can't blow process memory +([sse.go:33–38, 97–100](../../../proxy/internal/llm/sse.go)). +`splitField` strips one space after the `:` per the SSE spec. Documented +`not safe for concurrent use`; every consumer creates a fresh scanner per +response body. Streaming accumulators live in the middleware package +([llm_response_parser/streaming.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go)) +but use `llm.NewScanner` so the framing contract stays here. + +### Pricing catalog + +`Table.Cost` +([pricing.go:129–174](../../../proxy/internal/llm/pricing/pricing.go)) +is the cost formula — most security-relevant math in this module: + +| Provider | Formula | +|---|---| +| `openai` | `(inTokens − clamped) × InputPer1K + clamped × CachedInputPer1K + outTokens × OutputPer1K` where `clamped = min(cachedInput, inTokens)` | +| `anthropic`, `bedrock` | `inTokens × InputPer1K + cachedInput × CacheReadPer1K + cacheCreation × CacheCreationPer1K + outTokens × OutputPer1K` | +| default | `inTokens × InputPer1K + outTokens × OutputPer1K` | + +`bedrock` shares the Anthropic additive-cache formula +([pricing.go:172-174](../../../proxy/internal/llm/pricing/pricing.go)): +Anthropic-on-Bedrock reports the same additive cache buckets, while non-Anthropic +Bedrock models (Nova, Llama) simply report zero in those buckets so cost reduces +to `input + output`. + +Each per-bucket rate falls back to `InputPer1K` when zero — operators opt in +to discounts by setting the field. + +`Loader` +([pricing.go:212–268](../../../proxy/internal/llm/pricing/pricing.go)) +overlays an optional `pricing.yaml` from data-dir on top of the go:embed +defaults. Atomic pointer swap means readers never observe a partial update. +The mtime-poll reloader (30s default cadence) keeps the previous table on +parse failure so cost annotation never goes blank during a botched edit. + +`defaults_pricing.yaml` is the source of truth for built-in pricing. +Operator overrides only carry the entries they want to change. + +## Public contracts + +**`Parser` interface** +([parser.go:50–66](../../../proxy/internal/llm/parser.go)): + +```go +type Parser interface { + Provider() Provider + ProviderName() string + DetectFromURL(path string) bool + ParseRequest(body []byte) (RequestFacts, error) + ParseResponse(status int, contentType string, body []byte) (Usage, error) + ExtractPrompt(body []byte) string + ExtractCompletion(status int, contentType string, body []byte) string +} +``` + +Adding a provider means implementing this interface and appending to the +slice returned by `Parsers()` ([parser.go:78–84](../../../proxy/internal/llm/parser.go)). +Order matters: `DetectFromURL` ties resolve by registration order. +`Parsers()` today returns `{OpenAIParser, AnthropicParser, BedrockParser}`. + +**`Provider` enum** +([parser.go:8–18](../../../proxy/internal/llm/parser.go)): +`ProviderUnknown = 0`, `ProviderOpenAI = 1`, `ProviderAnthropic = 2`, +`ProviderBedrock = 3`. Numeric values are persisted in nothing today but treat +them as wire-stable — new providers must take fresh numbers. + +**`Pricing` lookup** +([pricing.go:129](../../../proxy/internal/llm/pricing/pricing.go)): + +```go +func (t *Table) Cost(provider, model string, inTokens, outTokens, cachedInput, cacheCreation int64) (float64, bool) +``` + +Nil-safe: `t.Cost` on a nil receiver returns `(0, false)` +([pricing.go:130–132](../../../proxy/internal/llm/pricing/pricing.go)). +`ok=false` means provider or model is absent from the loaded table; the caller +emits `cost.skipped=unknown_model`. + +## Invariants + +1. **Cross-platform pricing build.** `pricing_unix.go` carries the only + functional `loadPricing` (uses `syscall.O_NOFOLLOW` and `f.Stat()` on an + open descriptor — both Unix-only). `pricing_other.go` is a build-tag + fallback that returns `"not supported on this platform"` + ([pricing_other.go:14–16](../../../proxy/internal/llm/pricing/pricing_other.go)). + The proxy is Linux-only in production today; a Windows port needs an + equivalent path-as-handle implementation. Reviewers building on Windows + should expect this surface to return an error at startup if an override + file is configured. + +2. **SSE scanner handles partial chunks.** A buffered prefix that doesn't end + in `\n\n` still yields its accumulated event before `io.EOF` + ([sse.go:55–58](../../../proxy/internal/llm/sse.go)). Tests: + `TestSSEScanner_OpenAIFixture`, `TestSSEScanner_AnthropicFixture`, + `TestSSEScanner_MultilineData`, `TestSSEScanner_CRLF`. The streaming + accumulators ride on this: `accumulateAnthropicStream` and + `accumulateOpenAIStream` `break` on any scanner error to return partial + usage rather than aborting + ([streaming.go:68–73, 144–150](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming.go)). + +3. **`defaults_pricing.yaml` is the source of truth.** Compiled into the + binary via `//go:embed` + ([pricing.go:29–30](../../../proxy/internal/llm/pricing/pricing.go)). + `DefaultTable()` parses once and panics on parse failure + ([pricing.go:42–49](../../../proxy/internal/llm/pricing/pricing.go)) + — by design: a broken embedded YAML must not ship to production. + +4. **Loader path validation.** `resolveMiddlewareDataPath` + ([pricing.go:370–394](../../../proxy/internal/llm/pricing/pricing.go)) + rejects absolute paths, traversal segments, and basenames that fail + `basenameRegex = ^[a-zA-Z0-9._-]+$`. The resolved path must remain + inside `baseDir` even after `filepath.Clean`. Tests: + `TestNewLoader_PathValidation`, `TestNewLoader_PathValidation_Extended`, + `TestNewLoader_SymlinkOutsideBaseDirRejected`, `TestNewLoader_SymlinkRejected`. + +5. **Unix loader symlink safety.** `O_NOFOLLOW` on open, `f.Stat()` on the + open descriptor (never re-stat by path), `info.Mode().IsRegular()` check, + `io.LimitReader(f, maxPricingBytes+1)` with a final size assertion + ([pricing_unix.go:25–57](../../../proxy/internal/llm/pricing/pricing_unix.go)). + A mid-read symlink swap is detected because the fstat is on the original + fd. Test: `TestNewLoader_RejectsOversizedFile_FixesM4`. + +6. **`yaml.NewDecoder(...).KnownFields(true)`** + ([pricing.go:397–398](../../../proxy/internal/llm/pricing/pricing.go)) + rejects YAML files that carry fields not in the schema. A typo in an + operator override file fails loud instead of silently zeroing rates. + +## Things to scrutinise + +**Correctness.** Verify OpenAI cached-prompt clamp at +[pricing.go:147–149](../../../proxy/internal/llm/pricing/pricing.go) +short-circuits before subtraction. `Anthropic.TotalTokens` sums all four +buckets (in + out + cache_read + cache_creation) — downstream dashboards +need to know this differs from `input + output`. +`OpenAIParser.ExtractPrompt` falls through `messages → input → prompt`; a +request sending all three reports only `messages` (uncommon but worth +noting). + +**Security.** `Scanner.maxLine = 1 MiB`; a 2 MiB single-line `data:` event +errors from `Scanner.Next` and both accumulators stop with partial usage. +Pricing file 1 MiB cap is orders of magnitude larger than realistic. Confirm +new schema additions are mirrored in both `pricingFile` and `Entry`; +`KnownFields(true)` will reject silently-typo'd operator overrides +otherwise. + +**Concurrency.** `Loader.table` is `atomic.Pointer[Table]`; readers never +block or see a torn table. `Loader.Reload` is one goroutine, cancelled via +context (`TestLoader_ReloadBackgroundLoopCancellation`). `DefaultTable()` +uses `sync.Once`. Per-call `Scanner` instances mean no shared state across +concurrent response-parser calls. + +**Perf.** `Table.Cost` is two map lookups + multiplications, O(1). +`Scanner.Next` is one `ReadString('\n')` per line. Pricing reload poll 30s. + +**Observability.** Reload failures count via `metric.Int64Counter` keyed +`plugin`; warning log rate-limited at 5 min so a broken file doesn't flood. +Parser errors return sentinels — middleware uses `errors.Is` to map to the +right `cost.skipped` reason. + +## Test coverage + +| File | Tests | Coverage highlights | +|---|---:|---| +| `parser_test.go` | 3 | `Parsers()` shape lock, `DetectParser` URL matrix, provider enum stability | +| `openai_test.go` | 11 | Chat Completions + Responses API + legacy `prompt`; cached-tokens subset for both naming conventions; fixture replays | +| `anthropic_test.go` | 7 | Messages + legacy `/v1/complete`; streaming REJECTED on `ParseResponse` (must use scanner); fixture replays | +| `sse_test.go` | 12 | Fixture replay both providers; multiline `data:`; CRLF; comment skip; trailing-event-without-blank-line; oversize rejection | +| `pricing/pricing_test.go` | 21 | Provider-shape switch; cached-rate fallback; cached-clamp; symlink rejection (target outside basedir + symlink to file); path validation matrix; oversize rejection; reload-keeps-previous-on-parse-error; mtime change detection; goroutine cancellation | + +**Fixtures** ([proxy/internal/llm/fixtures/](../../../proxy/internal/llm/fixtures/)): +`openai_chat_completion.json` (chat.completions with usage), +`openai_responses.json` (Responses API shape), +`openai_stream.txt` (3 deltas + usage + `[DONE]`), +`anthropic_messages.json` (Messages API non-streaming), +`anthropic_stream.txt` (full 7-event sequence: message_start → +content_block_{start,delta×2,stop} → message_delta (usage) → message_stop), +`pricing.yaml` (realistic-pricing starter for operator overrides). + +## Cross-references + +- Sibling: [31-proxy-middleware-builtin.md](./31-proxy-middleware-builtin.md) + — the chain that calls `llm.Parsers()`, `llm.ParserByName`, + `llm.NewScanner`, `pricing.NewLoader`. +- Path-routed providers (Vertex AI + Bedrock), credential syntax, and the + Bedrock AWS event-stream accumulator: + [50-path-routed-providers.md](./50-path-routed-providers.md). +- Direct callers: `llm_request_parser/middleware.go:82–94`, + `llm_response_parser/middleware.go:113–123`, + `llm_response_parser/streaming.go:65, 142`, `cost_meter/factory.go:49–57`. +- Related elsewhere: the agent-network synthesiser stamping `provider_id` + is covered in the management-side module guide; proxy server boot + + `FactoryContext` construction is covered in the proxy-framework guide. diff --git a/docs/agent-networks/modules/33-proxy-runtime.md b/docs/agent-networks/modules/33-proxy-runtime.md new file mode 100644 index 000000000..f553473f8 --- /dev/null +++ b/docs/agent-networks/modules/33-proxy-runtime.md @@ -0,0 +1,194 @@ +# proxy/runtime — translate + serve + log + +> **Risk level:** High — every config push from management is translated here, and the chain runs on every HTTP request to a synth target. +> **Backward-compat impact:** Additive at the wire (`PathTargetOptions.middlewares`, `agent_network`, `disable_access_log`, capture caps) and on the proxy `Server` struct (`MiddlewareDataDir`, `MiddlewareCaptureBudgetBytes`). Non-agent-network targets stay on the no-middleware fast path. + +## Module boundary + +Turns the synth-service wire format from `ProxyService.SyncMappings`/`GetMappingUpdate` into in-process middleware chains and runs them on top of the existing `httputil.ReverseProxy`. Four concerns: (a) **translate** — `proto.MiddlewareConfig` → validated `middleware.Spec` (proxy/middleware_translate.go) + self-register the eight built-ins (proxy/middleware_register.go); (b) **boot + rebuild** — construct the `middleware.Manager`, share the OTel meter, install the live-service check, rebuild per-path chains on every `addMapping`/`modifyMapping` (proxy/server.go); (c) **serve** — resolve chain at request time, capture bodies under a global budget, invoke `RunRequest`/`RunResponse`/`RunTerminal`, render deny responses, apply `UpstreamRewrite` (proxy/internal/proxy/reverseproxy.go); (d) **log + tag** — emit access-log entries with the new `agent_network` flag, gate emission on `EnableLogCollection` via `DisableAccessLog` (proxy/internal/accesslog). + +**Inert for non-agent-network targets**: nil or empty chain → existing fast path (reverseproxy.go:127-139); `SuppressAccessLog` defaults false so the access-log middleware emits unchanged. + +## Files + +| Path | Role | +| ---- | ---- | +| proxy/middleware_translate.go | proto→Spec translation; slot/failmode/timeout mapping; caps | +| proxy/middleware_translate_test.go | translator unit tests | +| proxy/middleware_register.go | blank-imports the eight builtins for `init()` registration | +| proxy/server.go | `initMiddlewareManager`, `rebuildMiddlewareChains`, `isLiveService`, `buildMiddlewareBindings`, new Server fields, `protoToMapping` stamps AgentNetwork/DisableAccessLog/CaptureConfig/Middlewares | +| proxy/internal/proxy/reverseproxy.go | `WithMiddlewareManager`, chain dispatch, body capture, `applyUpstreamRewrite`/`Headers`, `buildRequestInput`, response-leg respInput identity fields | +| proxy/internal/proxy/reverseproxy_test.go | `TestBuildRequestInput_PropagatesIdentityAndGroups` | +| proxy/internal/proxy/context.go | `agentNetwork`, `suppressAccessLog`, `userGroupNames` on `CapturedData` | +| proxy/internal/proxy/servicemapping.go | new `PathTarget` fields | +| proxy/internal/proxy/agent_network_chain_realstack_test.go | end-to-end self-contained chain test | +| proxy/internal/accesslog/logger.go | `logEntry.AgentNetwork` → `proto.AccessLog` | +| proxy/internal/accesslog/middleware.go | reads `GetAgentNetwork()`; gates `l.log` on `!GetSuppressAccessLog()` | +| proxy/internal/accesslog/middleware_test.go | suppress/default/preserves-usage assertions | +| proxy/internal/auth/middleware_test.go | tunnel-peer group propagation contract | +| proxy/internal/metrics/metrics.go | `Meter()` getter for the middleware manager | + +## Architecture & flow + +### Synth-service ingestion → translate → register → serve + +```mermaid +flowchart TD + A[Management SyncMappings/GetMappingUpdate] --> B["processMappings\nserver.go:1492"] + B --> C{Mapping type} + C -->|CREATED| D["addMapping → setupHTTPMapping → updateMapping"] + C -->|MODIFIED| E["modifyMapping → cleanupMappingRoutes → setupHTTPMapping → updateMapping"] + C -->|REMOVED| F["removeMapping → cleanupMappingRoutes → invalidateMiddlewareChains"] + D --> G["protoToMapping\nserver.go:2181"] + E --> G + G --> H["translateMiddlewareConfigs\nmiddleware_translate.go:55"] + G --> I["translateMiddlewareCaptureConfig\nmiddleware_translate.go:18"] + H --> J["[]middleware.Spec on PathTarget"] + I --> K["*bodytap.Config on PathTarget"] + J --> L["proxy.AddMapping\nservicemapping.go:118"] + K --> L + L --> M["rebuildMiddlewareChains\nserver.go:2017 → Manager.Rebuild"] + F --> N["Manager.Invalidate(serviceID)"] +``` + +### Per-request lifecycle through the chain + accesslog + +```mermaid +sequenceDiagram + autonumber + participant C as Client + participant M as accesslog.Middleware + participant A as auth.Middleware (Protect) + participant RP as ReverseProxy.ServeHTTP + participant CH as middleware.Chain + participant U as Upstream + C->>M: HTTP request + M->>M: NewCapturedData(requestID), WithCapturedData(ctx) + M->>A: next.ServeHTTP + A->>A: Private → ValidateTunnelPeer → stamp UserID/Email/Groups/GroupNames/AuthMethod + A->>RP: next.ServeHTTP + RP->>RP: findTargetForRequest → targetResult + RP->>RP: stamp ServiceID/AccountID/AgentNetwork/SuppressAccessLog on CapturedData + RP->>RP: resolveChain via Manager.ChainFor + alt chain == nil or Empty + RP->>U: httputil.ReverseProxy.ServeHTTP (fast path) + else chain non-empty + RP->>RP: bodytap.CaptureRequest (global budget) + RP->>CH: RunRequest + CH-->>RP: denyOutput? requestMeta + upstreamRewrite + alt deny + RP->>C: RenderDenyResponse + else allow + RP->>RP: capturingWriter + applyUpstreamRewrite/Headers + RP->>U: httputil.ReverseProxy.ServeHTTP(respWriter) + U-->>RP: response + RP->>CH: RunResponse (respInput carries UserGroups) + RP->>CH: RunTerminal (merged request+response metadata) + end + end + RP-->>M: handler returns + M->>M: build logEntry incl. AgentNetwork + alt SuppressAccessLog == true + M->>M: skip l.log; still trackUsage + else default + M->>M: l.log → goroutine SendAccessLog + end +``` + +### EnableLogCollection suppression path + +```mermaid +flowchart LR + S["agentnetwork.Settings.EnableLogCollection"] --> B["synthesizer: target.DisableAccessLog = !EnableLogCollection"] + B --> P["proto PathTargetOptions.disable_access_log (field 13)"] + P --> T["protoToMapping reads GetDisableAccessLog()\nserver.go:2211"] + T --> M["PathTarget.DisableAccessLog\nservicemapping.go:47"] + M --> R["ServeHTTP: cd.SetSuppressAccessLog\nreverseproxy.go:106"] + R --> G["accesslog middleware: if !GetSuppressAccessLog l.log\nmiddleware.go:95"] + R --> U["trackUsage unconditional — bandwidth telemetry preserved"] +``` + +**Ingestion** lands as a `ProxyMapping` batch on `handleSyncMappingsStream`/`handleMappingStream`. `processMappings` dispatches to `addMapping`/`modifyMapping`/`removeMapping`; HTTP goes `setupHTTPMapping → updateMapping → protoToMapping`. `protoToMapping` (server.go:2181) is the single translation surface that materialises `[]middleware.Spec`, `*bodytap.Config`, `AgentNetwork`, `DisableAccessLog` onto each `PathTarget`; `updateMapping` finishes with `s.proxy.AddMapping(m)` (atomic swap under `mappingsMux`) and `s.rebuildMiddlewareChains(svcID, m)`. + +At **request time** the access-log middleware stamps `CapturedData`; the auth chain runs (Private services lift `peer_group_ids` from `ValidateTunnelPeer` — auth/middleware_test.go:322). `ReverseProxy.ServeHTTP` resolves the chain; nil or empty → original `httputil.ReverseProxy`, no body capture. When a chain matches, body is captured under the global budget, `RunRequest` produces an `UpstreamRewrite` (`llm_router` selects a provider, rewrites scheme/host/path, injects `Authorization`), and `RunResponse`+`RunTerminal` run after the upstream returns. The terminal slot sees the merged metadata bag — that's how `llm_limit_record` ships the consumption sample. The **access-log** addition: `logEntry.AgentNetwork` from `GetAgentNetwork()` onto `proto.AccessLog.AgentNetwork`; the gate at middleware.go:95 honors `EnableLogCollection`, skipping `l.log` but keeping `trackUsage` so bandwidth telemetry survives. + +## Public contracts touched + +- `proxy.Server.MiddlewareDataDir` (string) — base dir for file-backed middleware config (server.go:238-241). +- `proxy.Server.MiddlewareCaptureBudgetBytes` (int64) — process-wide capture cap; defaults to 256 MiB (server.go:248-250). +- `proxy/internal/proxy.WithMiddlewareManager(*middleware.Manager) Option` — new option on `NewReverseProxy`; nil keeps the fast path (reverseproxy.go:48-56). +- `proxy/internal/proxy.PathTarget` adds `Middlewares`, `CaptureConfig`, `AgentNetwork`, `DisableAccessLog` (servicemapping.go:27-51), all zero-default. +- `proxy/internal/proxy.CapturedData` adds `agentNetwork`, `suppressAccessLog`, `userGroupNames` behind `sync.RWMutex`; slices deep-copied (context.go:47-66, 183-258). +- `accesslog.logEntry.AgentNetwork` + `proto.AccessLog.AgentNetwork` (logger.go:131, 268). +- `metrics.Metrics.Meter()` exposes the OTel meter for the middleware manager (metrics.go:53-58). + +## Invariants + +- **Synth-service updates are live (no proxy restart).** Every `MODIFIED` flows through `modifyMapping → cleanupMappingRoutes` (invalidates chains) `→ setupHTTPMapping → updateMapping → rebuildMiddlewareChains`. **ProxyMapping.Private preservation:** the relevant logic lives in `management/internals/shared/grpc/proxy.go:shallowCloneMapping`, not this module, but it surfaces here — if a `MODIFIED` synth service arrives `private=false`, auth skips `ValidateTunnelPeer`, `CapturedData.UserGroups` stays empty, and `llm_router` denies with `llm_policy.no_authorised_provider` until a management restart re-pushes the snapshot. This module assumes `mapping.GetPrivate()` is correct on every batch. +- **`EnableLogCollection=false` suppresses access-log writes but middleware still runs.** Gate is one `if !cd.GetSuppressAccessLog()` immediately around `l.log(entry)` (middleware.go:95); `trackUsage` runs below the gate. Locked by `TestMiddleware_SuppressAccessLog_PreservesUsageTracking` (middleware_test.go:139). +- **`agent_network` flag on access-log entries is set when the chain processed the request.** Source `target.AgentNetwork`, stamped at reverseproxy.go:105, read at accesslog/middleware.go:86. +- **auth → builtin group propagation.** `Protect` writes `UserGroups`/`UserGroupNames`; `buildRequestInput` (reverseproxy.go:333) copies them into `middleware.Input`. The response-leg `respInput` (reverseproxy.go:196-223) also carries `UserEmail`/`UserGroups`/`UserGroupNames` — `llm_limit_record` needs `UserGroups` to ship `group_ids` so management's group-targeted budget rules match (comment at reverseproxy.go:211-215). +- **Empty chains stay on the fast path.** `ServeHTTP` skips body capture and the run sequence when `chain == nil || chain.Empty()` (reverseproxy.go:127). +- **Self-registration is the only way a builtin reaches the registry.** `middleware_register.go` blank-imports each builtin; `init()` adds the factory to `mwbuiltin.DefaultRegistry()`. Missing it → translator drops the entry with a warn (translate.go:97). + +## Things to scrutinize + +### Correctness +- **Translate edge cases** — drops on nil cfg, empty ID, unknown ID, UNSPECIFIED slot; each logs one warn; volume bounded by `MaxMiddlewaresPerChain`. +- **Re-translate without dropping in-flight requests** — `Manager.Rebuild` is the only call from `rebuildMiddlewareChains`. Reverse proxy reads `ChainFor` once per request (reverseproxy.go:327) and runs the captured `*Chain` for the whole request. Verify in module 30 that `Rebuild` swaps atomically. +- **ProxyMapping.Private preservation** — enforced management-side in `shallowCloneMapping`. Proxy-side regression catches: `TestProtect_PrivateService_TunnelPeerGroupsPropagate` + the integration test. +- **Body-capture cleanup** — `defer releaseBudget()` (reverseproxy.go:145) and `defer capturingWriter.Release()` (reverseproxy.go:180) must run on every return; confirm no future `return` lands between acquisition and defer. +- **`applyUpstreamRewrite` clones the URL** — `cloned := *orig` value-copies `*url.URL`; safe because overwritten fields are strings, not slices/maps (reverseproxy.go:285-292). + +### Security +- **Translate validates every config** — registry membership rejects unknown IDs; UNSPECIFIED slot drops; ID-less drops; raw config copied (not aliased) at translate.go:109. +- **`AuthHeader`/`StripHeaders` only reachable via `UpstreamRewrite`** — regular mutation surface goes through the framework denylist (`Authorization`/`Cookie` blocked); only the router middleware can replace `Authorization` (reverseproxy.go:296-304). Confirm in module 30 nothing outside the proxy-trusted path populates `UpstreamRewrite.AuthHeader`. +- **`stampNetBirdIdentity` strips client-sent values first** (reverseproxy.go:742-743) — anti-spoof for `X-NetBird-User`/`X-NetBird-Groups`; control chars filtered; comma-bearing labels dropped (reverseproxy_test.go:1217/:1243/:1193). +- **Auth → group propagation** — `auth/middleware_test.go:322` and `:366` cover the contract. If auth ever stops calling `ValidateTunnelPeer` for Private services, every agent-network request silently denies. + +### Concurrency +- **Chain replacement under in-flight requests** — `findTargetForRequest` takes `mappingsMux.RLock`; `AddMapping` writes. `resolveChain` calls `ChainFor` once; even if `Rebuild` swaps mid-request, in-flight requests keep running on the captured pointer. +- **`CapturedData` mutation across slots** — accessors take `sync.RWMutex`; slices deep-copied on both Set and Get. Verify no caller mutates the returned slice expecting it to land back. +- **`Manager.Invalidate` race** — `removeMapping` invalidates after `cleanupMappingRoutes`; mapping read happens before chain resolution, so requests before invalidate run captured chains; later ones fail `findTargetForRequest`. +- **`Logger.log` goroutine** — `logSem` caps at `maxLogWorkers = 4096`; overflow → `dropped.Add(1)` + debug log. Middleware test uses a buffered channel and 150ms negative-assertion window — review whether 150ms holds on slow CI. + +### Backward compatibility +- **Non-agent-network services unaffected** — `protoToMapping` reads new fields only when `opts != nil`; defaults leave `Middlewares`/`CaptureConfig` nil → chain resolves nil → fast path. Existing `reverseproxy_test.go` (non-chain) still passes. +- **`disable_access_log` is proto field 13, default false** — every existing target unset; gate is no-op. Locked by `TestMiddleware_SuppressAccessLog_DefaultEmitsLog` (middleware_test.go:104). +- **`Server` additions optional** — 256 MiB default when `MiddlewareCaptureBudgetBytes ≤ 0` (server.go:1997-2000). + +### Performance +- **Translate cost per push** — O(n) with per-entry registry lookup and `config_json` copy; negligible vs. the upstream gRPC unmarshal. +- **Empty-chain hot path** — one `ChainFor` map lookup + one `chain.Empty()` check; no allocation delta vs. pre-PR. +- **Body capture buffer churn** — `bodytap.CaptureRequest` allocates `MaxRequestBytes` per chain-hitting request; `releaseBudget` ties allocation to the 256 MiB proxy-wide budget. Confirm in module 30 the budget is a hard cap. + +### Observability +- **Metrics** — `Metrics.Meter()` shared with `middleware.NewMetrics` (server.go:1990-1993) so middleware instruments land in the same prometheus exporter. No new metrics defined here. +- **Access-log accuracy** — every entry carries `AgentNetwork`; terminal-slot metadata merged into `CapturedData.Metadata` (reverseproxy.go:238-241). +- **Deny logs at `Infof`** (reverseproxy.go:170) — review whether `Info` is too noisy at high deny rates; consider Debug or rate-limit. + +## Test coverage + +| Test file | Locks down | +| --------- | ---------- | +| proxy/middleware_translate_test.go | Empty/nil → nil; field preservation; unknown ID skip; nil registry permissive; timeout clamping; fail-mode + slot incl. UNSPECIFIED-drop; empty-ID drop; truncation above + at `MaxMiddlewaresPerChain` | +| proxy/internal/proxy/reverseproxy_test.go | Rewrite host/headers/cookies/query; trusted proxy; path forwarding; classifyProxyError; X-NetBird-User/Groups anti-spoof + CSV-join + control-char/comma rejection + fallback-to-ID; `TestBuildRequestInput_PropagatesIdentityAndGroups` (UserGroups/Email/GroupNames/AgentNetwork reach `middleware.Input`) | +| proxy/internal/proxy/agent_network_chain_realstack_test.go | **The end-to-end integration test.** Drives a real agent-network request through `ReverseProxy.ServeHTTP` with the chain the synthesizer produces, against an in-process management gRPC (bufconn) backed by a real sqlite store + real `agentnetwork.Manager`, plus an `httptest` upstream — no external infrastructure or real LLM. Guarantees: (1) response-leg `respInput` carries `UserGroups` so `llm_limit_record` ships non-empty `group_ids` and the admin-group consumption row increments; (2) `RedactPii=true` redacts both prompt and completion on captured metadata; (3) the full chain runs against a real management stack. **Line 189-211 inlines the proto→Spec mapping** instead of calling the proxy's private `translateMiddlewareConfig` — keep that inline mirror in sync with `proxy/middleware_translate.go` or the test silently diverges from production. | +| proxy/internal/accesslog/middleware_test.go | `SuppressAccessLog=true` skips `SendAccessLog` (150ms negative wait); default emits one send (2s positive); usage tracking runs under suppression | +| proxy/internal/auth/middleware_test.go | `TestProtect_PrivateService_TunnelPeerGroupsPropagate` proves `peer_group_ids` reach `CapturedData.UserGroups`; `TestProtect_PrivateService_TunnelPeerDenied` proves rejected peers 403 without reaching the handler | + +The integration test runs in a few seconds with no external infrastructure — exercising the real synthesizer, `Manager.Rebuild`, `ServeHTTP` dispatch, and `llm_limit_record` writing a real consumption row through the real `agentnetwork.Manager` over real gRPC. + +## Known limitations / explicit non-goals + +- **Translator does not validate `RawConfig` JSON** — factory's job at `New([]byte)`. Confirm in module 30 that a per-binding factory failure doesn't poison the rest of the chain. +- **No throttle on management push rate** — every `MODIFIED` triggers `Manager.Rebuild`. Mitigation upstream. +- **Streaming responses (SSE)** — body capture is streaming-aware, but response-leg middleware runs only after the response completes; long SSE streams delay `llm_limit_record` until close. +- **OIDC-only path doesn't carry tunnel-peer groups** — agent-network synth services rely on the Private tunnel-peer path; JWT groups claim is the only carrier for non-Private OIDC. +- **`agent_network` flag on L4 entries** not added; HTTP-only. +- **`mw.capture.bypass_reason` metadata key** documented at reverseproxy.go:151,184; namespace this in module 30/31 to avoid collisions. + +## Cross-references +- Upstream: [shared/api](10-shared-api.md), [proxy/middleware-framework](30-proxy-middleware-framework.md), [proxy/middleware-builtin](31-proxy-middleware-builtin.md), [proxy/llm-parsers](32-proxy-llm-parsers.md) +- End-to-end flow: [../01-end-to-end-flows.md](../01-end-to-end-flows.md) +- Top-level: [../00-overview.md](../00-overview.md) diff --git a/docs/agent-networks/modules/40-dashboard.md b/docs/agent-networks/modules/40-dashboard.md new file mode 100644 index 000000000..4ed9021bb --- /dev/null +++ b/docs/agent-networks/modules/40-dashboard.md @@ -0,0 +1,228 @@ +# dashboard — UI for agent-networks + +This module documents code that lives in the **dashboard repo** (under +`src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`), not +in this repo. It is co-located here so backend readers see the full picture. + +> **Risk level:** Medium. The new surface is isolated under `src/modules/agent-network/` and `src/app/(dashboard)/agent-network/`, but it also reshapes the sidebar, splits `/peers`, renames `reverse-proxy/clusters` → `self-hosted-proxies`, and overlays the Control Center graph. Regressions here would be cross-cutting. +> **Backward-compat impact:** Additive on the API side. Breaking on URL/navigation: `/peers` redirects to `/peers/devices` (src/app/(dashboard)/peers/page.tsx:7-15), `/reverse-proxy/clusters` was renamed to `/reverse-proxy/self-hosted-proxies`, the sidebar lost Access Control / Networks / Reverse Proxy / DNS / standalone Guardrails / Consumption / Activity (Navigation.tsx:165-171 — routes still resolve via URL), and the standalone `/agent-network/{access-log,consumption,global-controls}` routes are gone in favor of `/agent-network/observability`. + +## Module boundary + +The dashboard is the only place an operator interacts with agent-networks: provider catalog, configured providers, policies, guardrails, account-level budget rules, account settings (collection / redaction toggles), per-request access log, and consumption rollups all render, paginate, and edit here. Data flows in via SWR (`useFetchApi`) keyed by REST URL. One big context provider (`src/modules/agent-network/AIProvidersProvider.tsx`) aggregates five resources (providers, policies, guardrails, budget rules, settings) plus the proxy access-log stream filtered to `agent_network=true`, and exposes `add* / update* / toggle* / delete*` mutators that call through `useApiCall` and re-`mutate()` SWR. Pages mount the provider once at the top and compose presentational tables and modals beneath. The control-center page additionally fetches `/agent-network/{providers,policies}` directly (control-center/page.tsx:123-130) to overlay graph nodes. + +## What the UI delivers + +- **AI Observability** page with four tabs: Access Logs, Budget Dashboard, + Budget Settings, Log Settings (replaces the standalone access-log, + consumption, and global-controls routes). +- **Providers** page: provider catalog + connect/edit wizard with per-vendor + copy (LiteLLM, Portkey, Bifrost, Cloudflare, Vercel, OpenRouter, custom). +- **Policies** page: group → provider authorization with per-policy Limits + (minute-granular windows) + guardrail attach. +- **Guardrails** page: reusable model-allowlist + prompt-capture sets. +- **Account controls**: Log Collection / Prompt Collection / Redact PII toggles. +- **Budget rules**: account-level rules reusing the policy Limits UI. +- **Control Center overlay**: provider + agent-policy nodes on the graph. +- **Navigation + peers reshaping**: peers split into Devices / Agents, + `reverse-proxy/clusters` renamed to `self-hosted-proxies`, sidebar + repackaged for agent-network focus. + +## Surface added + +### New pages + +| Route | Purpose | Backing module(s) | +| ----- | ------- | ----------------- | +| `/agent-network` | Redirect to `/agent-network/providers` | page.tsx:7-15 | +| `/agent-network/providers` | List + connect providers; header surfaces per-account base URL | providers/page.tsx + AgentProvidersTable + AIProviderModal | +| `/agent-network/policies` | Group → Provider authorization with per-policy Limits + Guardrail attach | policies/page.tsx + AgentPoliciesTable + AgentPolicyModal | +| `/agent-network/guardrails` | Reusable guardrail sets (model allowlist + prompt capture) | guardrails/page.tsx + AgentGuardrailsTable + AgentGuardrailModal | +| `/agent-network/observability` | Tabs: Access Logs / Budget Dashboard / Budget Settings / Log Settings | observability/page.tsx | +| `/peers/devices`, `/peers/agents` | Split of `/peers`, shared via `PeersListView` keyed by `kind` | peers/{devices,agents}/page.tsx | +| `/reverse-proxy/self-hosted-proxies` | Renamed from `clusters` | self-hosted-proxies/page.tsx | + +Removed in favor of `/agent-network/observability`: `/agent-network/access-log`, `/agent-network/consumption`, `/agent-network/global-controls`. + +### New modules under src/modules/agent-network + +| File | Role | +| ---- | ---- | +| AIProvidersProvider.tsx (~1158 LOC) | Aggregates every agent-network resource via SWR; normalises snake↔camel; exposes mutators; holds wizard-open state | +| AIProviderModal.tsx (~1268 LOC) | Connect / edit provider wizard with per-vendor copy (Bifrost, Portkey, LiteLLM, Cloudflare, Vercel, OpenRouter, custom) | +| AIProviderLogo + useProviderCatalog | Catalog-driven brand swatch + SWR hook over `/agent-network/catalog/providers` | +| AgentPoliciesTable + AgentPolicyModal + AgentPolicyGuardrailsTab + AgentPolicyLimitsTab | Policies; modal has 3 tabs (Rule, Limits, Guardrails) | +| AgentGuardrailsTable + AgentGuardrailModal + AgentGuardrailBrowseModal + AgentGuardrailChecksCell | Guardrails CRUD + attach-from-policy | +| AgentBudgetRulesTable + AgentBudgetRuleModal | Account-level budget rules; modal reuses AgentPolicyLimitsTab verbatim | +| AgentAccountControlsCard | Three account-wide toggles (Log Collection / Prompt Collection / Redact PII) | +| AgentAccessLogTable + AgentAccessLogExpandedRow | Access log on `/events/proxy?agent_network=true` | +| AgentConsumptionPanel + AgentConsumptionTable | Token + cost panel: charts + counter table | +| table/AgentProvidersTable + AgentProviderActionCell | Providers table + per-row actions | +| data/mockData.ts | Domain types and a few residual `MOCK_*` constants (see scrutinize) | + +### Touched non-agent-network areas + +- **control-center**: agent-network overlay (provider + agent-policy nodes); removed the All Networks dropdown; hid the Networks tab in FlowSelector (FlowSelector.tsx:9-14 — enum value kept so `?tab=networks` still type-checks); wrapped `ControlCenterView` in `AIProvidersProvider` (page.tsx:73-83); `agentPolicyNode` clicks routed to a separate state slot (page.tsx:1871-1874). New node renderers: nodes/ProviderNode.tsx, nodes/AgentPolicyNode.tsx (registered at utils/nodes.ts:21-22). +- **peers**: Split into Devices and Agents sub-routes; shared via `PeersListView` keyed by `kind` (PeersListView.tsx:24-95). New compact-toolbar `UserFilterSelector` (users/UserFilterSelector.tsx). +- **reverse-proxy**: Folder rename `clusters/` → `self-hosted-proxies/`; deleted `ClustersFeaturesCell.tsx`, `ClusterTypeIndicator.tsx`; new ReverseProxyClusterTargetSelector for cluster target type; Private toggle on target modal; body-capture knobs removed; new ReverseProxyEventExpandedRow. +- **events**: `ReverseProxyEventsUserCell` rewritten with user + peer fallback (ReverseProxyEventsUserCell.tsx:14-21), shared with the access-log table. +- **navigation**: Full repackaging in Navigation.tsx — Agent Network items flattened (no collapsible parent), distinct icons per item; Access Control, Networks, Reverse Proxy, DNS, standalone Guardrails, Consumption, Activity removed (still URL-reachable, per lines 165-171). + +## Architecture & flow + +### Page → Provider → Table/Modal hierarchy + +```mermaid +graph TD + Nav[Navigation.tsx] + Nav --> ProvidersPage[/agent-network/providers/] + Nav --> PoliciesPage[/agent-network/policies/] + Nav --> GuardrailsPage[/agent-network/guardrails/] + Nav --> ObsPage[/agent-network/observability/] + + ProvidersPage --> AIPP1[AIProvidersProvider] + PoliciesPage --> AIPP2[AIProvidersProvider] + GuardrailsPage --> AIPP3[AIProvidersProvider] + ObsPage --> AIPP4[AIProvidersProvider] + ObsPage -.wraps.-> GroupsProvider + ObsPage -.wraps.-> PeersProvider + + AIPP1 --> ProvTable[AgentProvidersTable] + ProvTable --> ProvModal[AIProviderModal] + AIPP2 --> PolTable[AgentPoliciesTable] + PolTable --> PolModal[AgentPolicyModal] + PolModal --> PolGuardTab[AgentPolicyGuardrailsTab] + PolModal --> PolLimitsTab[AgentPolicyLimitsTab] + PolGuardTab --> GuardBrowse[AgentGuardrailBrowseModal] + PolGuardTab --> GuardModal[AgentGuardrailModal] + AIPP3 --> GuardTable[AgentGuardrailsTable] + GuardTable --> GuardModal + AIPP4 --> Tabs[Tabs] + Tabs --> AccessLog[AgentAccessLogTable] + Tabs --> Consumption[AgentConsumptionPanel] + Tabs --> BudgetRules[AgentBudgetRulesTable] + Tabs --> AccountCtl[AgentAccountControlsCard] + BudgetRules --> BudgetModal[AgentBudgetRuleModal] + BudgetModal -.reuses.-> PolLimitsTab +``` + +### AI Observability tab page + +```mermaid +graph LR + Page[AIObservabilityPage] --> RA[RestrictedAccess
permission.services.read] + RA --> GP[GroupsProvider] + GP --> PP[PeersProvider] + PP --> AIP[AIProvidersProvider] + AIP --> Tabs[Tabs / TabsList] + Tabs --> T1[Access Logs
AgentAccessLogTable] + Tabs --> T2[Budget Dashboard
AgentConsumptionPanel] + Tabs --> T3[Budget Settings
AgentBudgetRulesTable] + Tabs --> T4[Log Settings
AgentAccountControlsCard] + T1 -.GET.-> EP[/events/proxy?agent_network=true/] + T2 -.GET poll 5s.-> CONS[/agent-network/consumption/] + T3 -.GET/PUT.-> BR[/agent-network/budget-rules/] + T4 -.GET/PUT.-> ST[/agent-network/settings/] +``` + +### Data fetch path + +```mermaid +graph TD + Page[Page component] --> Prov[AIProvidersProvider] + Prov -->|useFetchApi| SWR[(SWR cache
key = URL)] + SWR -.GET.-> P[/agent-network/providers/] + SWR -.GET.-> POL[/agent-network/policies/] + SWR -.GET.-> G[/agent-network/guardrails/] + SWR -.GET.-> BR[/agent-network/budget-rules/] + SWR -.GET ignoreError.-> ST[/agent-network/settings/] + SWR -.GET.-> CAT[/agent-network/catalog/providers/] + SWR -.GET pageSize=100.-> EVT[/events/proxy agent_network=true/] + Prov --> Mut[useApiCall.post/put/del] + Mut -.on success.-> MutateSWR[SWR mutate keys] + Prov --> Children[Tables / Modals via useAIProviders] +``` + +Every list view reaches management through SWR over `/api/agent-network/*`. The provider context maps snake-case payloads to camelCase domain types (`fromAPI`, `policyFromAPI`, `guardrailFromAPI`, `budgetRuleFromAPI`, `settingsFromAPI`, `accessLogFromAPI` — AIProvidersProvider.tsx:138-562) and back via matching `*ToRequest` adaptors. The access log piggy-backs on `/events/proxy` with `agent_network=true&page_size=100` (line 707-709) and decodes LLM-specific fields from per-event `metadata`. Group IDs on events are resolved to current names through the surrounding GroupsProvider catalog (lines 515-521, 717-731) — no extra round trip. Mutators run `*ToRequest`, await `useApiCall.post/put/del`, call SWR `mutate()`, then `notify`. Errors caught and surfaced via `notify` — no exceptions escape into render. The Connect Provider modal's open state lives in the provider itself (`isWizardOpen` at lines 732-735) so the providers-page empty-state CTA and the table's + button share one modal. Control-center re-fetches `/agent-network/{providers,policies}` directly on top of `AIProvidersProvider` — SWR de-dupes but the code path is harder to reason about. + +## Public contracts consumed + +- `GET/POST /api/agent-network/providers`, `PUT/DELETE /:id` +- `GET/POST /api/agent-network/policies`, `PUT/DELETE /:id` +- `GET/POST /api/agent-network/guardrails`, `PUT/DELETE /:id` +- `GET/POST /api/agent-network/budget-rules`, `PUT/DELETE /:id` +- `GET/PUT /api/agent-network/settings` (ignoreError-tolerant; 404 = not yet bootstrapped — auto-bootstrap on first provider create via `bootstrap_cluster` field — AIProvidersProvider.tsx:737-760) +- `GET /api/agent-network/catalog/providers` (read-only declarative; backend owns vendor list, IDs, brand colors, models, extra_headers, identity_injection — useProviderCatalog.ts:6-95) +- `GET /api/agent-network/consumption` (polled every 5s on Budget Dashboard — ConsumptionPanel.tsx:53,65-71) +- `GET /api/events/proxy?agent_network=true&page_size=100` (shared with Proxy Events) +- `permission?.services?.read` gates every agent-network route via RestrictedAccess. + +`AIProviderId` is a closed union in dashboard types (data/mockData.ts:8-21) but the converter tolerates anything the backend ships — unknown ids fall through to `"custom"` (AIProvidersProvider.tsx:497-506). Catalog values are pure read-through: anything declared in `extra_headers` renders in the modal automatically, copy keyed by header name (`EXTRA_HEADER_UI` in AIProviderModal.tsx:61-89), labeled-fallback for unknown ones. + +## Invariants + +- Provider context wrap order on user-attribution pages: `GroupsProvider > PeersProvider > AIProvidersProvider` (observability/page.tsx:87-89). Reverse it and access-log group resolution silently drops names. +- Every agent-network route checks `permission?.services?.read` via `RestrictedAccess` (observability/page.tsx:85, providers/page.tsx:184, policies/page.tsx:53, guardrails/page.tsx:55). +- Modal `key={open ? 1 : 0}` pattern is used to force unmount/remount on close so internal `useState` resets between edits (AgentBudgetRuleModal.tsx:60, AgentPolicyModal.tsx:66). Removing this would leak prior-row state into a new-row session. +- `mockData.ts` is the canonical home for ALL agent-network domain types; `MOCK_*` constants must never reach a production code path. One leak remains (below). + +## Things to scrutinize + +### Correctness + +- **Tab-state URL hand-off is one-way.** observability/page.tsx:53-58 reads `?tab=` on mount (despite the file comment at line 28 saying URL hand-off is future) but `setTab` does NOT push back, so reload preserves the chosen tab only if it came in via the link. Inconsistent with control-center (page.tsx:1817-1831). +- **Provider overlay runs only in `applySingleGroupView` / `applyPeerView`** (control-center/page.tsx:557, 1159-1166). User view does NOT show providers — if agent-network is a primary lens, that's a gap. +- **Two useEffects race to invalidate the control-center layout.** page.tsx:1655-1657 drops `layoutInitialized` when `agentPolicies` / `agentProviders` arrive; the main effect (1786-1799) also lists them as deps. Functional but fragile — watch for flash-of-empty-graph. +- **`updateProvider` / `updatePolicy` / `updateBudgetRule` use `??` on `enabled`** (AIProvidersProvider.tsx:784, 859, 1018). Toggle paths are safe; any caller sending `enabled: false` thinking "leave it off" gets `existing.enabled` instead. Audit modal callers. +- **Form validation in modals is minimal.** Window-seconds picker — mockData.ts:209-215 documents "minimum 60 — one minute" but there is no matching UI guard in PolicyLimitsTab; the backend validator is the enforcement point. + +### Security + +- **No client-side enforcement claims** — every cap, allowlist, and toggle is display + edit; proxy is the source of truth for deny decisions (AccessLogTable.tsx:177-191 renders backend-emitted `denyReason` as-is). +- **Prompt display is gated by what the backend stamps.** When `enable_prompt_collection` is OFF the proxy must not put prompt/completion into event metadata; the dashboard renders whatever it gets verbatim (AccessLogTable lines 532-534, AccessLogExpandedRow.tsx:42-57). No UI filter on top of backend collection switches. +- Account Controls disables `Redact PII` when `Prompt Collection` is off (AgentAccountControlsCard.tsx:122) and clears it on off-transition (line 100), but relies on backend to enforce the same gate at write — confirm PUT handler rejects `redact_pii=true && enable_prompt_collection=false`. +- **Bifrost identity-header overrides**: empty-string vs nil semantics documented in AIProvidersProvider.tsx:772-781 ("omitted = preserve, empty = explicit clear"). Mishandling could leak group attribution to a header the operator thought disabled. Focused read of Bifrost code path in AIProviderModal.tsx recommended. + +### Accessibility + +- Observability TabsList (observability/page.tsx:96-113) uses the shared Tabs component — should inherit Radix roving-tabindex. All four TabsTriggers carry only icon + text, no `aria-label`; fine because text is visible. +- Modal focus traps are inherited from the shared Modal; agent-network modals don't override them. Quick keyboard pass recommended. +- `EndpointBadge` Copy button (providers/page.tsx:66-76) has an `aria-label`, good. + +### Performance + +- `AgentConsumptionPanel` polls `/agent-network/consumption` every 5s (ConsumptionPanel.tsx:53,70). Tab switches unmount the panel, so the poll stops — verify in network panel. +- `AgentAccessLogTable` is hard-capped at 100 rows via `page_size=100` (AIProvidersProvider.tsx:707-709). Server-side pagination is future work; high-traffic tenants miss everything past row 100 — known limitation. +- Observability page mounts providers ONCE at page level (observability/page.tsx:87-89); tab switches keep SWR cache hot. Moving the provider mount inside `TabsContent` would re-fetch the access log on every switch. + +### Visual consistency + +- The observability tab style mirrors peers/page.tsx. Outer Tabs `pt-4 pb-0 mb-0`, TabsList `px-8` (observability/page.tsx:94-96) — confirm chrome height matches so the page doesn't visually jump. +- Sidebar: `Boxes` for Providers, `AccessControlIcon` for Policies, `TelescopeIcon` for AI Observability (Navigation.tsx:113,120,133). Reusing `AccessControlIcon` makes Policies look identical to the (now hidden) Access Control item — if Access Control ever comes back, they collide. +- `AgentNetworkIcon` is used in breadcrumbs on every agent-network page but NOT in the sidebar (per-page icons instead). Deliberate departure — record so it doesn't get reverted. + +## Test coverage + +- **Cypress**: One file (`cypress/e2e/test.cy.ts`) covering only the install-page copy-to-clipboard flow. NOTHING covers agent-network UI. +- **Component / unit tests**: `src/utils/version.test.ts` is the only `.test.*` file in the repo. The agent-network modules ship without component tests. +- Data-cy hooks exist on key controls: `save-account-controls` (AgentAccountControlsCard.tsx:71), `enable-log-collection`, `enable-prompt-collection`, `redact-pii`, plus existing `data-cy={policy.name}` / `data-cy={provider.name}` on ActiveInactiveRow. Sufficient hooks for Cypress flows; none written yet. +- **Tooling gap (pre-existing):** `npm run lint` (`next lint`) is broken in Next 16 — the `lint` subcommand was removed from the Next CLI in 16.x, so the dashboard effectively has no working lint gate. The fix is to add either a flat-config `eslint .` script or wire ESLint via an explicit `eslint-config-next` invocation. + +## Known limitations / explicit non-goals + +- **`data/mockData.ts` still contains `MOCK_GROUPS`, `MOCK_PROVIDERS`, `MOCK_PEERS`.** Only `MOCK_GROUPS` is referenced from production — AgentPoliciesTable.tsx:45,76 uses it as a name-lookup fallback when a policy references a group ID the real GroupsProvider doesn't know about. `MOCK_PROVIDERS` / `MOCK_PEERS` are unreferenced; safe to delete. The file is `/* eslint-disable */` so dead-code warnings don't flag them. +- **Tab-state URL hand-off on observability page is one-way** (read-only). +- **Access log hard-capped at 100 rows**; no server-side pagination. +- **No optimistic updates.** All mutations are round-trip; failures rollback via SWR revalidation. +- **`FlowView.NETWORKS` retained but hidden** from FlowSelector (FlowSelector.tsx:9-14). Old `?tab=networks` links still route to the hidden view because `applyNetworksView` still runs. +- **Redirects are not query-preserving** — `router.replace("/peers/devices")` (peers/page.tsx:13) strips any incoming filter params. +- **Control-center cross-fetches** `/agent-network/{providers,policies}` directly on top of `AIProvidersProvider`. Could be collapsed. +- **Sidebar permanently hides Access Control, Networks, Reverse Proxy, standalone Guardrails, DNS, Activity, Consumption.** Routes still resolve via URL (Navigation.tsx:165-171); intentional. + +## Cross-references + +- Upstream API contracts: [shared/api](10-shared-api.md) +- Backend persistence: [management/store](20-management-store.md) +- Backend handler wiring: [management/handlers + wiring](22-management-handlers-wiring.md) +- End-to-end flow narrative: [../01-end-to-end-flows.md](../01-end-to-end-flows.md) +- Top-level overview: [../00-overview.md](../00-overview.md) diff --git a/docs/agent-networks/modules/50-path-routed-providers.md b/docs/agent-networks/modules/50-path-routed-providers.md new file mode 100644 index 000000000..b7cda3a97 --- /dev/null +++ b/docs/agent-networks/modules/50-path-routed-providers.md @@ -0,0 +1,251 @@ +# path-routed providers — Vertex AI + Bedrock + +This guide pulls the **path-routed** provider story together in one place +because it crosses the catalog, the synthesiser, the request parser, and the +router. The relevant building blocks are the `llm_router` / +`llm_request_parser` middlewares +([31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md)), the +per-provider parser surface ([32-proxy-llm-parsers.md](32-proxy-llm-parsers.md)), +and the synthesiser's catalog → `ProviderRoute` mapping +([21-management-agentnetwork.md](21-management-agentnetwork.md)). + +Sibling modules: [31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md) +(router + request parser) and [32-proxy-llm-parsers.md](32-proxy-llm-parsers.md) +(Bedrock parser + pricing). + +--- + +## What "path-routed" means + +Most catalog providers carry the model in the request **body** (`{"model": …}`), +so `llm_router` selects an upstream by matching the model name against each +provider's `Models` claim. Two providers instead carry the model in the **URL +path**, so they are routed by path before the model/vendor table is consulted: + +| Catalog id | Style flag | Request path shape | +|---|---|---| +| `vertex_ai_api` | `IsVertexPathStyle` → `ProviderRoute.Vertex` | `/v1/projects/{project}/locations/{region}/publishers/{publisher}/models/{model}:{action}` | +| `bedrock_api` | `IsBedrockPathStyle` → `ProviderRoute.Bedrock` | `/model/{modelId}/{action}` (optionally behind `/bedrock`) | + +The catalog declares the style with +[`catalog.IsVertexPathStyle` / `catalog.IsBedrockPathStyle`](../../../management/server/agentnetwork/catalog/catalog.go) +and the synthesiser copies the result onto the router route as the `Vertex` / +`Bedrock` booleans +([synthesizer.go:450-451](../../../management/server/agentnetwork/synthesizer.go)). +On the request leg `llm_router.Invoke` dispatches `isVertexPath` / `isBedrockPath` +**before** the model lookup +([llm_router/middleware.go:138-216](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)) +so a model the parser extracted from the path can't be claimed by a same-vendor +*body-routed* provider (e.g. `claude-*` on `api.anthropic.com`). + +## Google Vertex AI (`vertex_ai_api`) + +### Catalog entry + +`KindProvider`, parser surface left unset on the catalog entry — the request +parser picks the parser from the URL **publisher** segment, not from +`ParserID`. Upstream host is `-aiplatform.googleapis.com` +(`https://aiplatform.googleapis.com` for the `global` location). The catalog +lists the Claude-on-Vertex lineup (`claude-opus-4-*`, `claude-sonnet-4-*`, +`claude-haiku-4-5`, `claude-fable-5`) at the same per-token rates as the +first-party Anthropic entry +([catalog.go:333-363](../../../management/server/agentnetwork/catalog/catalog.go)). + +### Credential — service-account OAuth (`keyfile::`) + +Vertex does **not** accept a static API key. The operator sets the provider +`api_key` to: + +``` +keyfile:: +``` + +The synthesiser recognises the `keyfile::` prefix in `providerAuthHeader` +([synthesizer.go:897-903](../../../management/server/agentnetwork/synthesizer.go)), +emits **no** static auth value, and carries the base64 key material on the +route as `GCPServiceAccountKeyB64` +([factory.go:56-61](../../../proxy/internal/middleware/builtin/llm_router/factory.go)). +At request time the router mints a short-lived OAuth2 access token from the key +(cloud-platform scope) and injects `Authorization: Bearer ` — +never the key itself +([llm_router/middleware.go:621-692](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)): + +- One auto-refreshing `oauth2.TokenSource` is cached per key (keyed by a + SHA-256 of the base64 material), so token minting happens once and refreshes + amortise across requests. +- Mint / refresh is bounded by a 10s timeout HTTP client (`gcpTokenTimeout`) so + a slow Google token endpoint can't hang the request. +- A malformed key or an unreachable token endpoint fails the request with + `llm_policy.upstream_auth_failed` at HTTP **502** (an upstream problem, not a + policy denial) — see `denyUpstreamAuth`. + +### Metering — Anthropic-on-Vertex only + +The request parser extracts `{publisher, model, action}` from the path +(`parseVertexPath`, [llm_request_parser/middleware.go:237-263](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)), +strips the `@version` suffix from the model, and maps the publisher to a parser +surface via `vertexPublisherVendor`: + +- `anthropic` → `llm.provider="anthropic"` → metered through the Anthropic + parser, priced under the **`anthropic`** block in `defaults_pricing.yaml` + (the parser emits the standard Anthropic provider label, so Vertex Claude + reuses first-party Anthropic prices). +- `openai` → `llm.provider="openai"` (reserved; not in the catalog lineup + today). +- anything else (notably `google` / Gemini) → empty vendor → **no parser**. + +**Gemini is intentionally denied as unmeterable.** When the parser emits no +`llm.provider` for a Vertex publisher, `llm_router` returns +`llm_policy.unmeterable_publisher` (403) rather than forwarding the request +uncounted — serving it would bypass token / budget metering +([llm_router/middleware.go:144-162, 712-728](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)). +A Gemini parser would lift this restriction; until then the `google` publisher +is omitted from the catalog. + +> Caveat: cross-region inference profiles in `eu` / `apac` carry a ~10% price +> premium that the base per-token rates do **not** model — cost annotations for +> those regions read low. Operators who need exact regional billing override +> the affected entries in `pricing.yaml`. + +## AWS Bedrock (`bedrock_api`) + +### Catalog entry + +`KindProvider`, upstream host `bedrock-runtime..amazonaws.com`. Metered +models are the Anthropic-on-Bedrock lineup (`anthropic.claude-*`) plus Amazon +Nova and Llama 3.3 entries +([catalog.go:300-332](../../../management/server/agentnetwork/catalog/catalog.go)). +Anthropic-on-Bedrock reuses the first-party Claude prices (with additive cache +buckets); Nova / Llama report no cache, so cost is `input + output`. + +### Credential — static bearer token + +Bedrock uses the **AWS Bedrock API key** as a static bearer. The operator sets +the provider `api_key` directly (no `keyfile::` prefix); the catalog template +is `Authorization: Bearer ${API_KEY}` +([catalog.go:306-307](../../../management/server/agentnetwork/catalog/catalog.go)). +No token minting — the synthesiser substitutes the key into the template and +the router injects the resulting `Authorization` header after stripping inbound +vendor auth (including client-supplied AWS SigV4 material: `X-Amz-Date`, +`X-Amz-Security-Token`, `X-Amz-Content-Sha256`, see `strippedAuthHeaders`). + +### Model id form — cross-region inference profiles + +Bedrock model ids in the request path must be the cross-region +**inference-profile** form, e.g. +`eu.anthropic.claude-sonnet-4-5-20250929-v1:0`. The bare +`anthropic.claude-…` id is rejected by AWS. `normalizeBedrockModel` +([llm_request_parser/middleware.go:398-414](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)) +strips the region prefix (`us.` / `eu.` / `apac.` / `global.`), an optional ARN +wrapper, and the `-YYYYMMDD-vN[:N]` version/throughput suffix so the normalised +id (`anthropic.claude-sonnet-4-5`) matches the catalog/pricing key. + +### Supported endpoints + actions + +`/model/{modelId}/{action}` where action ∈ `invoke`, +`invoke-with-response-stream`, `converse`, `converse-stream` +([llm_request_parser/middleware.go:363-390](../../../proxy/internal/middleware/builtin/llm_request_parser/middleware.go)). +`invoke` / `converse` are non-streaming; the `-stream` actions set the streaming +flag. + +- **InvokeModel** body uses the vendor-native shape — for Anthropic that means + `"anthropic_version":"bedrock-2023-05-31"` and snake_case usage with additive + cache buckets. +- **Converse** uses the unified camelCase shape with a precomputed `totalTokens`. +- The `BedrockParser` reads both shapes on the response leg + ([bedrock.go](../../../proxy/internal/llm/bedrock.go)); the request parser + doesn't need to distinguish them (`ParseRequest` is a no-op — model + stream + come from the path). + +### Streaming — AWS binary event-stream + +The `-stream` actions return `application/vnd.amazon.eventstream` (the AWS +binary event-stream framing), and streaming **is metered**. +`accumulateBedrockStream` +([llm_response_parser/streaming_bedrock.go](../../../proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go)) +decodes the frames with `aws-sdk-go-v2/aws/protocol/eventstream`: + +- InvokeModel `chunk` frames wrap a base64 `{"bytes":…}` payload carrying a + vendor-native (Anthropic) stream event — folded through the shared Anthropic + stream accumulator. +- Converse `contentBlockDelta` frames carry text; the trailing `metadata` frame + carries the final usage block. +- A truncated stream (cut at the body-tap capture cap) decodes best-effort: + frames up to the cut are applied and partial usage is returned. + +### Optional `/bedrock` gateway-namespace prefix + +Clients may place an optional `/bedrock` prefix before the native path +(`/bedrock/model/{modelId}/{action}`) to disambiguate Bedrock from other +providers that also use `/model/...`. Both the request parser +(`trimBedrockNamespace`) and the router (`splitBedrockNamespace`) accept it. +When the prefix is present, the router sets +`RewriteUpstream.StripPathPrefix = "/bedrock"` so the **native** path +(`/model/...`) is what reaches `bedrock-runtime..amazonaws.com` +([llm_router/middleware.go:168-184, 320-348](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)). + +## Model allowlist on path-routed providers + +Because the model lives in the URL rather than the body, a path-routed provider +credential could otherwise be used for any model the upstream supports. The +router still enforces the route's `Models` allowlist via `matchPathRoute` +([llm_router/middleware.go:370-416](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)): + +1. Filter to routes of the matching style (`Vertex` / `Bedrock`). +2. Filter to routes whose `AllowedGroupIDs` authorise the caller's groups + (else `no_authorised_provider`). +3. Filter to routes that **claim the requested model**. As with body-routed + providers, an **empty `Models` list = catch-all** (serve any model); + a non-empty list serves only the listed models (else `model_not_routable`). +4. Multiple survivors disambiguate by longest `UpstreamPath` prefix match. + +So an operator who lists explicit models on a Vertex/Bedrock provider gets a +hard allowlist; an operator who leaves `Models` empty accepts every model the +upstream serves (still subject to the unmeterable-publisher gate on Vertex). + +Model-less OpenAI endpoints (`GET /v1/models`) are **never** routed to a +Vertex/Bedrock provider — `matchModelless` skips path-routed routes +([llm_router/middleware.go:427-462](../../../proxy/internal/middleware/builtin/llm_router/middleware.go)) +so a model-listing call can't be rewritten onto an upstream that would 404 it. + +## Catalog ↔ pricing cross-check + +Catalog prices and context windows are cross-checked against LiteLLM's +`model_prices_and_context_window.json`. The proxy's embedded +`defaults_pricing.yaml` covers **every metered first-party model** the catalog +enumerates — guarded by +`TestDefaultTable_FirstPartyModelCoverage` +([pricing/defaults_coverage_test.go](../../../proxy/internal/llm/pricing/defaults_coverage_test.go)), +which fails if a catalog model has no embedded price. Bedrock entries are keyed +by the **normalised** id the request parser emits (region prefix + version +suffix stripped). Vertex Claude carries no Bedrock-style prefix, so it prices +straight off the `anthropic` block. + +## Things to scrutinise + +**Security.** The Vertex service-account key is never forwarded — only a minted +short-lived bearer. Confirm the key material stays out of access logs (it lives +on `ProviderRoute.GCPServiceAccountKeyB64`, not in any emitted metadata key). +The unmeterable-publisher deny is the only thing standing between an +operator-misconfigured Vertex provider and unmetered Gemini traffic; verify +`vertexPublisherVendor` stays conservative (deny by default for unknown +publishers). + +**Correctness.** `normalizeBedrockModel` is the join between the wire id and the +pricing key — a model that normalises to something not in `defaults_pricing.yaml` +meters at `cost.skipped=unknown_model` rather than failing the request. The +`/bedrock` prefix strip must run on both the parser side (so the model is +extracted) and the router side (so the upstream path is native); a regression in +either silently breaks the other. + +**Metering caveats.** eu/apac cross-region Bedrock + Vertex profiles carry a +~10% premium not modelled by base pricing — flagged in both the catalog comment +and `defaults_pricing.yaml`. Operators needing exact regional billing override +the relevant entries. + +## Cross-references + +- Router + request-parser detail: [31-proxy-middleware-builtin.md](31-proxy-middleware-builtin.md) +- Bedrock parser + pricing + SSE / event-stream: [32-proxy-llm-parsers.md](32-proxy-llm-parsers.md) +- Catalog → route synthesis + `keyfile::` handling: [21-management-agentnetwork.md](21-management-agentnetwork.md) +- Overview: [../00-overview.md](../00-overview.md) diff --git a/e2e/agentnetwork/bootstrap_test.go b/e2e/agentnetwork/bootstrap_test.go new file mode 100644 index 000000000..a73d55117 --- /dev/null +++ b/e2e/agentnetwork/bootstrap_test.go @@ -0,0 +1,30 @@ +//go:build e2e + +package agentnetwork + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestCombinedBootstrap proves Pillar 1: the shared combined server came up and +// the /api/setup-minted PAT authenticates a real management API call through +// the typed REST client (the bootstrap itself ran in TestMain). +func TestCombinedBootstrap(t *testing.T) { + ctx := context.Background() + + require.NotEmpty(t, srv.PAT, "TestMain must have minted an admin PAT") + + users, err := srv.API().Users.List(ctx) + require.NoError(t, err, "authenticated Users.List must round-trip") + require.NotEmpty(t, users, "the bootstrapped account must have at least one user") + + var emails []string + for _, u := range users { + emails = append(emails, u.Email) + } + assert.Contains(t, emails, "admin@netbird.test", "the bootstrapped owner should appear in the users list") +} diff --git a/e2e/agentnetwork/chat_test.go b/e2e/agentnetwork/chat_test.go new file mode 100644 index 000000000..5e3a79273 --- /dev/null +++ b/e2e/agentnetwork/chat_test.go @@ -0,0 +1,281 @@ +//go:build e2e + +package agentnetwork + +import ( + "context" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/e2e/harness" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// providerCase is one entry in the live provider matrix. The same scenario runs +// for every available provider; availability is keyed off env vars so the suite +// covers whatever credentials are present (source ~/.llm-keys locally / set the +// Actions secrets in CI). +type providerCase struct { + name string + catalogID string + upstream string + apiKey string + model string // body model (chat/messages) or path model@version (vertex) + kind string // harness.WireChat, harness.WireMessages, or harness.WireVertex + project string // vertex only: GCP project for the rawPredict path + region string // vertex only: GCP region for the rawPredict path +} + +// availableProviders builds the matrix from the provider env vars that are set. +func availableProviders() []providerCase { + var ps []providerCase + if k := os.Getenv("OPENAI_TOKEN"); k != "" { + ps = append(ps, providerCase{name: "openai", catalogID: "openai_api", upstream: "https://api.openai.com", apiKey: k, model: "gpt-4o-mini", kind: harness.WireChat}) + } + if k := os.Getenv("ANTHROPIC_TOKEN"); k != "" { + ps = append(ps, providerCase{name: "anthropic", catalogID: "anthropic_api", upstream: "https://api.anthropic.com", apiKey: k, model: "claude-haiku-4-5", kind: harness.WireMessages}) + } + if k, u := os.Getenv("VERCEL_TOKEN"), os.Getenv("VERCEL_URL"); k != "" && u != "" { + ps = append(ps, providerCase{name: "vercel", catalogID: "vercel_ai_gateway", upstream: u, apiKey: k, model: "openai/gpt-4o-mini", kind: harness.WireChat}) + } + if k, u := os.Getenv("OPENROUTER_TOKEN"), os.Getenv("OPENROUTER_URL"); k != "" && u != "" { + // Distinct model string from Vercel so each provider routes unambiguously + // while all are enabled together. + ps = append(ps, providerCase{name: "openrouter", catalogID: "openrouter", upstream: u, apiKey: k, model: "openai/gpt-4o", kind: harness.WireChat}) + } + if k, u := os.Getenv("CLOUDFLARE_TOKEN"), os.Getenv("CLOUDFLARE_URL"); k != "" && u != "" { + // Cloudflare AI Gateway routes by a provider segment in the URL path; + // append the openai provider unless the gateway URL already carries one. + if !strings.Contains(u, "/openai") { + u = strings.TrimRight(u, "/") + "/openai" + } + // Raw model (distinct string from OpenAI's gpt-4o-mini). + ps = append(ps, providerCase{name: "cloudflare", catalogID: "cloudflare_ai_gateway", upstream: u, apiKey: k, model: "gpt-4o", kind: harness.WireChat}) + } + // Vertex (vertex_ai_api): Anthropic-on-Vertex, path-routed, SA-OAuth + // (api_key = keyfile::). The model travels in the rawPredict path rather + // than the body, so the provider is created without a models array. Region + // defaults to "global" (host aiplatform.googleapis.com); a real region uses + // -aiplatform.googleapis.com. + if sa := os.Getenv("GOOGLE_VERTEX_SA_BASE64"); sa != "" { + project := os.Getenv("GOOGLE_VERTEX_PROJECT") + if project != "" { + region := os.Getenv("GOOGLE_VERTEX_REGION") + if region == "" { + region = "global" + } + host := "aiplatform.googleapis.com" + if region != "global" { + host = region + "-aiplatform.googleapis.com" + } + model := os.Getenv("GOOGLE_VERTEX_MODEL") + if model == "" { + model = "claude-sonnet-4-5@20250929" + } + ps = append(ps, providerCase{ + name: "vertex", catalogID: "vertex_ai_api", upstream: "https://" + host, + apiKey: "keyfile::" + sa, model: model, kind: harness.WireVertex, + project: project, region: region, + }) + } + } + + // Bedrock: path-routed, bearer auth. Model is a cross-region inference + // profile id (distinct string from the first-party Anthropic case). + if k := os.Getenv("AWS_BEARER_TOKEN_BEDROCK"); k != "" { + region := os.Getenv("AWS_REGION") + if region == "" { + region = "us-east-1" + } + ps = append(ps, providerCase{name: "bedrock", catalogID: "bedrock_api", upstream: "https://bedrock-runtime." + region + ".amazonaws.com", apiKey: k, model: "us.anthropic.claude-haiku-4-5", kind: harness.WireMessages}) + } + return ps +} + +// providerRequest builds a create request for a matrix provider: enabled, with +// a uniquely-priced model for body-routed providers and none for the +// path-routed Vertex (whose model lives in the request path). +func providerRequest(pc providerCase) api.AgentNetworkProviderRequest { + req := api.AgentNetworkProviderRequest{ + Name: pc.name, + ProviderId: pc.catalogID, + UpstreamUrl: pc.upstream, + ApiKey: &pc.apiKey, + Enabled: ptr(true), + } + if pc.kind != harness.WireVertex { + req.Models = &[]api.AgentNetworkProviderModel{ + {Id: pc.model, InputPer1k: 0.001, OutputPer1k: 0.002}, + } + } + return req +} + +// TestProvidersMatrix is Pillar 3: it provisions every available provider (all +// enabled, each with a unique model so routing stays unambiguous), runs proxy + +// client once, and drives the same live chat-completion scenario through each +// provider over the WireGuard tunnel. Each provider must return 200 and produce +// an ingested access-log row. +func TestProvidersMatrix(t *testing.T) { + matrix := availableProviders() + if len(matrix) == 0 { + t.Skip("no provider keys set; source ~/.llm-keys to run the provider matrix") + } + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute) + defer cancel() + + // Group + setup key the client joins into; the policy authorizes it. + grp, err := srv.API().Groups.Create(ctx, api.PostApiGroupsJSONRequestBody{Name: "e2e-agents"}) + require.NoError(t, err, "create agents group") + t.Cleanup(func() { _ = srv.API().Groups.Delete(context.Background(), grp.Id) }) + + ephemeral := false + sk, err := srv.API().SetupKeys.Create(ctx, api.PostApiSetupKeysJSONRequestBody{ + Name: "e2e-client", + Type: "reusable", + ExpiresIn: 86400, + UsageLimit: 0, + AutoGroups: []string{grp.Id}, + Ephemeral: &ephemeral, + }) + require.NoError(t, err, "mint setup key") + require.NotEmpty(t, sk.Key, "setup key plaintext") + + // Create every provider, all enabled, each with a unique model string so the + // proxy's connect-time snapshot carries them all and model→provider routing + // is unambiguous (provider toggles after connect don't reconcile to the + // proxy, so we enable everything up front). The first create bootstraps the + // cluster. + ids := make([]string, 0, len(matrix)) + for i, pc := range matrix { + req := providerRequest(pc) + if i == 0 { + req.BootstrapCluster = ptr(harness.AgentNetworkCluster) + } + prov, perr := srv.CreateProvider(ctx, req) + require.NoError(t, perr, "create provider %s", pc.name) + ids = append(ids, prov.Id) + id := prov.Id + t.Cleanup(func() { _ = srv.DeleteProvider(context.Background(), id) }) + } + + enabled := true + pol, err := srv.CreatePolicy(ctx, api.AgentNetworkPolicyRequest{ + Name: "e2e-allow", + Enabled: &enabled, + SourceGroups: []string{grp.Id}, + DestinationProviderIds: ids, + // Token limit at the 60s window floor with caps far above the few hundred + // tokens this suite drives, so it never blocks traffic but switches on + // usage metering, which is what makes consumption rows get recorded. + Limits: &api.AgentNetworkPolicyLimits{ + TokenLimit: api.AgentNetworkPolicyTokenLimit{ + Enabled: true, + GroupCap: 10_000_000, + UserCap: 10_000_000, + WindowSeconds: 60, + }, + }, + }) + require.NoError(t, err, "create policy") + t.Cleanup(func() { _ = srv.DeletePolicy(context.Background(), pol.Id) }) + + settings, err := srv.GetSettings(ctx) + require.NoError(t, err, "read settings for endpoint") + require.NotEmpty(t, settings.Endpoint, "agent-network endpoint must be assigned") + + // Proxy (global CLI token) + client, brought up once. + proxyToken, err := srv.CreateProxyTokenCLI(ctx, "e2e-proxy") + require.NoError(t, err, "mint proxy token via CLI") + px, err := harness.StartProxy(ctx, srv, proxyToken) + require.NoError(t, err, "start proxy") + t.Cleanup(func() { _ = px.Terminate(context.Background()) }) + + cl, err := harness.StartClient(ctx, srv, sk.Key) + require.NoError(t, err, "start client") + t.Cleanup(func() { _ = cl.Terminate(context.Background()) }) + + require.NoError(t, cl.WaitConnected(ctx, 90*time.Second), "client must connect to management") + if err := cl.WaitProxyPeer(ctx, 180*time.Second); err != nil { + t.Fatalf("client did not see the proxy peer: %v\n=== proxy logs ===\n%s", err, px.Logs(context.Background())) + } + proxyIP, err := cl.ResolveProxyIP(ctx, settings.Endpoint) + require.NoError(t, err, "resolve agent-network endpoint to proxy IP") + + for _, pc := range matrix { + pc := pc + t.Run(pc.name, func(t *testing.T) { + before, _ := srv.ListAccessLogs(ctx) + + // Unique per provider so we can find this provider's row by its + // session id and confirm the marker propagated end-to-end. + sessionID := "e2e-session-" + pc.name + + // Retry briefly to absorb tunnel/DNS jitter on the first call. + var code int + var body string + deadline := time.Now().Add(90 * time.Second) + for time.Now().Before(deadline) { + var c int + var b string + var cerr error + if pc.kind == harness.WireVertex { + c, b, cerr = cl.Vertex(ctx, settings.Endpoint, proxyIP, pc.project, pc.region, pc.model, "Reply with exactly: pong", sessionID) + } else { + c, b, cerr = cl.Chat(ctx, settings.Endpoint, proxyIP, pc.kind, pc.model, "Reply with exactly: pong", sessionID) + } + if cerr == nil { + code, body = c, b + if code == 200 { + break + } + } + time.Sleep(5 * time.Second) + } + require.Equal(t, 200, code, "chat through %s (%s %s) should return 200; body: %s", pc.name, pc.kind, pc.model, body) + + require.Eventually(t, func() bool { + logs, lerr := srv.ListAccessLogs(ctx) + return lerr == nil && logs.TotalRecords > before.TotalRecords + }, 30*time.Second, 2*time.Second, "an access-log row should be ingested for %s", pc.name) + + // The session id sent as x-session-id must round-trip into the + // access-log row for this provider. + require.Eventually(t, func() bool { + logs, lerr := srv.ListAccessLogs(ctx) + if lerr != nil { + return false + } + for _, r := range logs.Data { + if r.SessionId != nil && *r.SessionId == sessionID { + return true + } + } + return false + }, 30*time.Second, 2*time.Second, "session id %q must be recorded in an access-log row for %s", sessionID, pc.name) + }) + } + + // Metering: the policy's uncapped token limit switches on usage recording, + // so the live traffic just driven must surface as consumption rows with + // positive token counts. Consumption is account-scoped (keyed by source + // group / user and time window, not per provider), and ingest is async, so + // poll for any row that has booked tokens. + require.Eventually(t, func() bool { + rows, lerr := srv.ListConsumption(ctx) + if lerr != nil { + return false + } + for _, r := range rows { + if r.TokensInput > 0 && r.TokensOutput > 0 { + return true + } + } + return false + }, 60*time.Second, 3*time.Second, "consumption must be recorded with positive token counts after live traffic") +} diff --git a/e2e/agentnetwork/main_test.go b/e2e/agentnetwork/main_test.go new file mode 100644 index 000000000..17c5e00be --- /dev/null +++ b/e2e/agentnetwork/main_test.go @@ -0,0 +1,46 @@ +//go:build e2e + +// Package agentnetwork holds the container-based agent-network e2e suite. A +// single combined server is built and bootstrapped once per package run +// (TestMain) and shared across tests via srv; each test creates and cleans up +// its own resources so order doesn't matter. +package agentnetwork + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/netbirdio/netbird/e2e/harness" +) + +// srv is the shared combined server for the package, ready (PAT-authenticated) +// by the time any Test runs. +var srv *harness.Combined + +func TestMain(m *testing.M) { + os.Exit(run(m)) +} + +func run(m *testing.M) int { + // Generous timeout to cover a cold image build on first run. + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + var err error + srv, err = harness.StartCombined(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "e2e: start combined server: %v\n", err) + return 1 + } + defer func() { _ = srv.Terminate(context.Background()) }() + + if _, err := srv.Bootstrap(ctx); err != nil { + fmt.Fprintf(os.Stderr, "e2e: bootstrap admin PAT: %v\n", err) + return 1 + } + + return m.Run() +} diff --git a/e2e/agentnetwork/management_test.go b/e2e/agentnetwork/management_test.go new file mode 100644 index 000000000..cfd03f63c --- /dev/null +++ b/e2e/agentnetwork/management_test.go @@ -0,0 +1,221 @@ +//go:build e2e + +package agentnetwork + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/e2e/harness" + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func ptr[T any](v T) *T { return &v } + +// newProvider creates an OpenAI-catalog provider with a dummy key (these tests +// never call the upstream) and registers cleanup. +func newProvider(t *testing.T, ctx context.Context, name string) api.AgentNetworkProvider { + t.Helper() + prov, err := srv.CreateProvider(ctx, api.AgentNetworkProviderRequest{ + Name: name, + ProviderId: "openai_api", + UpstreamUrl: "https://api.openai.com", + ApiKey: ptr("sk-dummy-e2e-key"), + BootstrapCluster: ptr("eu.proxy.netbird.test"), + }) + require.NoError(t, err, "create provider %q", name) + t.Cleanup(func() { _ = srv.DeleteProvider(context.Background(), prov.Id) }) + return prov +} + +// requireClientError asserts err is a REST APIError with a 4xx status. +func requireClientError(t *testing.T, err error) { + t.Helper() + var apiErr *rest.APIError + require.ErrorAs(t, err, &apiErr, "expected a REST APIError") + assert.GreaterOrEqual(t, apiErr.StatusCode, 400, "expected a 4xx status") + assert.Less(t, apiErr.StatusCode, 500, "expected a 4xx status") +} + +// TestProviderLifecycle covers create → get → list → delete → 404 for every +// available real provider catalog (and a synthetic OpenAI provider when no +// provider keys are set), so each catalog's create and field round-trip is +// exercised. Create is offline — no upstream call — so this stays fast and +// burns no provider quota. +func TestProviderLifecycle(t *testing.T) { + ctx := context.Background() + + cases := availableProviders() + if len(cases) == 0 { + cases = []providerCase{{ + name: "openai", catalogID: "openai_api", upstream: "https://api.openai.com", + apiKey: "sk-dummy-e2e-key", model: "gpt-4o-mini", kind: harness.WireChat, + }} + } + + for i, pc := range cases { + i, pc := i, pc + t.Run(pc.name, func(t *testing.T) { + req := providerRequest(pc) + req.Name = "lc-" + pc.name + // Bootstrap the cluster on the first create in case the matrix has + // not run (e.g. no provider keys → settings not yet bootstrapped). + if i == 0 { + req.BootstrapCluster = ptr(harness.AgentNetworkCluster) + } + + prov, err := srv.CreateProvider(ctx, req) + require.NoError(t, err, "create %s provider", pc.name) + t.Cleanup(func() { _ = srv.DeleteProvider(context.Background(), prov.Id) }) + + assert.NotEmpty(t, prov.Id, "created provider must have an id") + assert.Equal(t, pc.catalogID, prov.ProviderId, "catalog id must round-trip") + assert.Equal(t, req.Name, prov.Name, "name must round-trip") + assert.Equal(t, pc.upstream, prov.UpstreamUrl, "upstream must round-trip") + + got, err := srv.GetProvider(ctx, prov.Id) + require.NoError(t, err, "get provider") + assert.Equal(t, prov.Id, got.Id) + + list, err := srv.ListProviders(ctx) + require.NoError(t, err, "list providers") + var ids []string + for _, p := range list { + ids = append(ids, p.Id) + } + assert.Contains(t, ids, prov.Id, "created provider must appear in the list") + + require.NoError(t, srv.DeleteProvider(ctx, prov.Id), "delete provider") + _, err = srv.GetProvider(ctx, prov.Id) + requireClientError(t, err) + }) + } +} + +// TestProviderValidation exercises the create-time validation rules. These are +// uniform across catalogs (no per-provider required-field rules exist: a +// catalog-specific malformed value such as a Vertex key without the keyfile:: +// prefix is accepted at create and only fails at the proxy), so the cases here +// are catalog-agnostic: missing API key, unknown catalog id, an invalid upstream +// URL, and a blank name. +func TestProviderValidation(t *testing.T) { + ctx := context.Background() + + _, err := srv.CreateProvider(ctx, api.AgentNetworkProviderRequest{ + Name: "No Key", + ProviderId: "openai_api", + UpstreamUrl: "https://api.openai.com", + }) + requireClientError(t, err) + + _, err = srv.CreateProvider(ctx, api.AgentNetworkProviderRequest{ + Name: "Unknown Catalog", + ProviderId: "totally_unknown_provider", + UpstreamUrl: "https://example.com", + ApiKey: ptr("sk-dummy"), + }) + requireClientError(t, err) + + _, err = srv.CreateProvider(ctx, api.AgentNetworkProviderRequest{ + Name: "Bad Upstream", + ProviderId: "openai_api", + UpstreamUrl: "not-a-url", + ApiKey: ptr("sk-dummy"), + }) + requireClientError(t, err) + + _, err = srv.CreateProvider(ctx, api.AgentNetworkProviderRequest{ + Name: " ", + ProviderId: "openai_api", + UpstreamUrl: "https://api.openai.com", + ApiKey: ptr("sk-dummy"), + }) + requireClientError(t, err) +} + +// TestSettingsRoundTrip flips the collection toggles and confirms cluster / +// subdomain stay immutable, then restores the original state. +func TestSettingsRoundTrip(t *testing.T) { + ctx := context.Background() + + // Settings are bootstrapped on first provider create. + newProvider(t, ctx, "Settings Bootstrap") + + before, err := srv.GetSettings(ctx) + require.NoError(t, err, "get settings") + require.NotEmpty(t, before.Cluster, "settings must carry an assigned cluster") + + flipped, err := srv.UpdateSettings(ctx, api.AgentNetworkSettingsRequest{ + EnableLogCollection: !before.EnableLogCollection, + EnablePromptCollection: !before.EnablePromptCollection, + RedactPii: !before.RedactPii, + }) + require.NoError(t, err, "update settings") + assert.Equal(t, !before.EnableLogCollection, flipped.EnableLogCollection, "log collection toggle must flip") + assert.Equal(t, !before.EnablePromptCollection, flipped.EnablePromptCollection, "prompt collection toggle must flip") + assert.Equal(t, before.Cluster, flipped.Cluster, "cluster must be immutable across updates") + assert.Equal(t, before.Subdomain, flipped.Subdomain, "subdomain must be immutable across updates") + + // Restore the original toggles. + _, err = srv.UpdateSettings(ctx, api.AgentNetworkSettingsRequest{ + EnableLogCollection: before.EnableLogCollection, + EnablePromptCollection: before.EnablePromptCollection, + RedactPii: before.RedactPii, + }) + require.NoError(t, err, "restore settings") +} + +// TestPolicyWindowFloor rejects an enabled limit below the 60s window floor and +// accepts one at the floor. +func TestPolicyWindowFloor(t *testing.T) { + ctx := context.Background() + + grp, err := srv.API().Groups.Create(ctx, api.PostApiGroupsJSONRequestBody{Name: "e2e-policy-grp"}) + require.NoError(t, err, "create source group") + t.Cleanup(func() { _ = srv.API().Groups.Delete(context.Background(), grp.Id) }) + + prov := newProvider(t, ctx, "Policy Provider") + + limits := func(window int64) *api.AgentNetworkPolicyLimits { + return &api.AgentNetworkPolicyLimits{ + TokenLimit: api.AgentNetworkPolicyTokenLimit{ + Enabled: true, + GroupCap: 1000, + UserCap: 1000, + WindowSeconds: window, + }, + } + } + + _, err = srv.CreatePolicy(ctx, api.AgentNetworkPolicyRequest{ + Name: "e2e-below-floor", + SourceGroups: []string{grp.Id}, + DestinationProviderIds: []string{prov.Id}, + Limits: limits(30), + }) + requireClientError(t, err) + + pol, err := srv.CreatePolicy(ctx, api.AgentNetworkPolicyRequest{ + Name: "e2e-at-floor", + SourceGroups: []string{grp.Id}, + DestinationProviderIds: []string{prov.Id}, + Limits: limits(60), + }) + require.NoError(t, err, "policy at the 60s floor must be accepted") + assert.NotEmpty(t, pol.Id, "created policy must have an id") + t.Cleanup(func() { _ = srv.DeletePolicy(context.Background(), pol.Id) }) +} + +// TestConsumptionList confirms the read endpoint always returns an array, never +// a 404/500. +func TestConsumptionList(t *testing.T) { + ctx := context.Background() + + rows, err := srv.ListConsumption(ctx) + require.NoError(t, err, "consumption list must not error") + assert.NotNil(t, rows, "consumption must be a JSON array (possibly empty)") +} diff --git a/e2e/harness/Dockerfile.client b/e2e/harness/Dockerfile.client new file mode 100644 index 000000000..114577d60 --- /dev/null +++ b/e2e/harness/Dockerfile.client @@ -0,0 +1,24 @@ +# Multistage build for the NetBird client used in e2e tests. The repo has no +# source-building client Dockerfile (client/Dockerfile packages a goreleaser +# artifact), so this mirrors its alpine runtime + entrypoint while compiling the +# CGO-free client inline. BuildKit cache mounts keep rebuilds incremental. + +FROM golang:1.25-bookworm AS builder +WORKDIR /src +COPY go.mod go.sum ./ +RUN --mount=type=cache,target=/go/pkg/mod go mod download +COPY . . +RUN --mount=type=cache,target=/go/pkg/mod \ + --mount=type=cache,target=/root/.cache/go-build \ + CGO_ENABLED=0 GOOS=linux go build -o /out/netbird ./client + +FROM alpine:3.24 +RUN apk add --no-cache bash ca-certificates ip6tables iproute2 iptables +ENV NETBIRD_BIN="/usr/local/bin/netbird" \ + NB_LOG_FILE="console,/var/log/netbird/client.log" \ + NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ + NB_ENABLE_CAPTURE="false" \ + NB_ENTRYPOINT_SERVICE_TIMEOUT="30" +ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] +COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh +COPY --from=builder /out/netbird /usr/local/bin/netbird diff --git a/e2e/harness/agentnetwork.go b/e2e/harness/agentnetwork.go new file mode 100644 index 000000000..192385ab1 --- /dev/null +++ b/e2e/harness/agentnetwork.go @@ -0,0 +1,130 @@ +//go:build e2e + +package harness + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// The shared REST client doesn't (yet) expose typed agent-network methods, so +// these helpers drive the /api/agent-network/* endpoints through the client's +// NewRequest primitive — reusing its auth, error handling (rest.APIError on +// non-2xx), and transport — while still speaking the generated api types. + +// anRequest issues an agent-network API call and decodes the JSON response into +// T. A non-2xx response surfaces as a *rest.APIError from the client, which +// tests inspect for negative-path status assertions. +func anRequest[T any](ctx context.Context, c *Combined, method, path string, body any) (T, error) { + var out T + var reader io.Reader + if body != nil { + bs, err := json.Marshal(body) + if err != nil { + return out, fmt.Errorf("marshal %s %s: %w", method, path, err) + } + reader = bytes.NewReader(bs) + } + + resp, err := c.api.NewRequest(ctx, method, path, reader, nil) + if err != nil { + return out, err + } + defer resp.Body.Close() + + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return out, fmt.Errorf("decode %s %s response: %w", method, path, err) + } + return out, nil +} + +// anDelete issues a DELETE and discards the (empty-object) body. +func anDelete(ctx context.Context, c *Combined, path string) error { + resp, err := c.api.NewRequest(ctx, http.MethodDelete, path, nil, nil) + if err != nil { + return err + } + resp.Body.Close() + return nil +} + +// CreateProvider creates an agent-network provider. +func (c *Combined) CreateProvider(ctx context.Context, req api.AgentNetworkProviderRequest) (api.AgentNetworkProvider, error) { + return anRequest[api.AgentNetworkProvider](ctx, c, http.MethodPost, "/api/agent-network/providers", req) +} + +// GetProvider fetches a provider by id. +func (c *Combined) GetProvider(ctx context.Context, id string) (api.AgentNetworkProvider, error) { + return anRequest[api.AgentNetworkProvider](ctx, c, http.MethodGet, "/api/agent-network/providers/"+id, nil) +} + +// ListProviders returns all providers for the account. +func (c *Combined) ListProviders(ctx context.Context) ([]api.AgentNetworkProvider, error) { + return anRequest[[]api.AgentNetworkProvider](ctx, c, http.MethodGet, "/api/agent-network/providers", nil) +} + +// DeleteProvider removes a provider by id. +func (c *Combined) DeleteProvider(ctx context.Context, id string) error { + return anDelete(ctx, c, "/api/agent-network/providers/"+id) +} + +// SetProviderEnabled toggles a provider's enabled flag, preserving its other +// fields (the API key is omitted, which keeps the stored one). Used to run one +// provider at a time so model→provider routing is unambiguous. +func (c *Combined) SetProviderEnabled(ctx context.Context, id string, enabled bool) error { + p, err := c.GetProvider(ctx, id) + if err != nil { + return err + } + _, err = anRequest[api.AgentNetworkProvider](ctx, c, http.MethodPut, "/api/agent-network/providers/"+id, api.AgentNetworkProviderRequest{ + Name: p.Name, + ProviderId: p.ProviderId, + UpstreamUrl: p.UpstreamUrl, + Enabled: &enabled, + Models: &p.Models, + }) + return err +} + +// CreatePolicy creates an agent-network policy. +func (c *Combined) CreatePolicy(ctx context.Context, req api.AgentNetworkPolicyRequest) (api.AgentNetworkPolicy, error) { + return anRequest[api.AgentNetworkPolicy](ctx, c, http.MethodPost, "/api/agent-network/policies", req) +} + +// UpdatePolicy replaces a policy by id. +func (c *Combined) UpdatePolicy(ctx context.Context, id string, req api.AgentNetworkPolicyRequest) (api.AgentNetworkPolicy, error) { + return anRequest[api.AgentNetworkPolicy](ctx, c, http.MethodPut, "/api/agent-network/policies/"+id, req) +} + +// DeletePolicy removes a policy by id. +func (c *Combined) DeletePolicy(ctx context.Context, id string) error { + return anDelete(ctx, c, "/api/agent-network/policies/"+id) +} + +// GetSettings returns the account's agent-network settings row. It exists only +// after the first provider create bootstraps it. +func (c *Combined) GetSettings(ctx context.Context) (api.AgentNetworkSettings, error) { + return anRequest[api.AgentNetworkSettings](ctx, c, http.MethodGet, "/api/agent-network/settings", nil) +} + +// UpdateSettings applies the mutable collection toggles. +func (c *Combined) UpdateSettings(ctx context.Context, req api.AgentNetworkSettingsRequest) (api.AgentNetworkSettings, error) { + return anRequest[api.AgentNetworkSettings](ctx, c, http.MethodPut, "/api/agent-network/settings", req) +} + +// ListConsumption returns the account's consumption rows (possibly empty). +func (c *Combined) ListConsumption(ctx context.Context) ([]api.AgentNetworkConsumption, error) { + return anRequest[[]api.AgentNetworkConsumption](ctx, c, http.MethodGet, "/api/agent-network/consumption", nil) +} + +// ListAccessLogs returns the account's agent-network access-log page (the +// flattened per-request rows the proxy ships and management ingests). +func (c *Combined) ListAccessLogs(ctx context.Context) (api.AgentNetworkAccessLogsResponse, error) { + return anRequest[api.AgentNetworkAccessLogsResponse](ctx, c, http.MethodGet, "/api/agent-network/access-logs", nil) +} diff --git a/e2e/harness/bootstrap.go b/e2e/harness/bootstrap.go new file mode 100644 index 000000000..defa03c14 --- /dev/null +++ b/e2e/harness/bootstrap.go @@ -0,0 +1,47 @@ +//go:build e2e + +package harness + +import ( + "context" + "fmt" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// Bootstrap creates the initial admin owner through the unauthenticated +// /api/setup endpoint and returns the plaintext admin PAT. It also wires an +// authenticated REST client on the Combined (see API). create_pat requires the +// server to run with NB_SETUP_PAT_ENABLED=true, which the harness sets. A +// second call returns an error (the server reports setup already completed). +func (c *Combined) Bootstrap(ctx context.Context) (string, error) { + // The setup endpoint is unauthenticated; use a tokenless client. + setupClient := rest.NewWithOptions(rest.WithManagementURL(c.BaseURL)) + + createPAT := true + expireDays := 1 + resp, err := setupClient.Instance.Setup(ctx, api.PostApiSetupJSONRequestBody{ //nolint:gosec // static throwaway test credentials + Email: "admin@netbird.test", + Password: "Netbird-e2e-Passw0rd!", + Name: "E2E Admin", + CreatePat: &createPAT, + PatExpireIn: &expireDays, + }) + if err != nil { + return "", fmt.Errorf("instance setup: %w", err) + } + if resp.PersonalAccessToken == nil || *resp.PersonalAccessToken == "" { + return "", fmt.Errorf("setup succeeded but no PAT returned (is NB_SETUP_PAT_ENABLED set?)") + } + + c.PAT = *resp.PersonalAccessToken + c.api = rest.New(c.BaseURL, c.PAT) + return c.PAT, nil +} + +// API returns the PAT-authenticated management REST client. It is nil until +// Bootstrap runs. +func (c *Combined) API() *rest.Client { + return c.api +} diff --git a/e2e/harness/cert.go b/e2e/harness/cert.go new file mode 100644 index 000000000..c8a28e470 --- /dev/null +++ b/e2e/harness/cert.go @@ -0,0 +1,66 @@ +//go:build e2e + +package harness + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "path/filepath" + "time" +) + +// writeSelfSignedCert generates a self-signed TLS cert/key pair covering the +// given DNS names and writes them as tls.crt / tls.key in dir. The proxy serves +// this for the agent-network endpoint; the client curls with -k, so validity +// chains don't matter — the proxy just needs a usable cert to present. +func writeSelfSignedCert(dir string, dnsNames []string) error { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return fmt.Errorf("generate key: %w", err) + } + + serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return fmt.Errorf("generate serial: %w", err) + } + + tmpl := x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: dnsNames[0]}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + DNSNames: dnsNames, + BasicConstraintsValid: true, + } + + der, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) + if err != nil { + return fmt.Errorf("create certificate: %w", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + if err := os.WriteFile(filepath.Join(dir, "tls.crt"), certPEM, 0o644); err != nil { //nolint:gosec // public cert, bind-mounted and read by the proxy container + return fmt.Errorf("write cert: %w", err) + } + + keyDER, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return fmt.Errorf("marshal key: %w", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + // World-readable so the (non-root) proxy container can read the bind-mounted + // key on Linux CI runners; this is a throwaway self-signed e2e key. + if err := os.WriteFile(filepath.Join(dir, "tls.key"), keyPEM, 0o644); err != nil { //nolint:gosec // throwaway self-signed e2e key, must be readable by the proxy container uid + return fmt.Errorf("write key: %w", err) + } + return nil +} diff --git a/e2e/harness/client.go b/e2e/harness/client.go new file mode 100644 index 000000000..cf7ef8945 --- /dev/null +++ b/e2e/harness/client.go @@ -0,0 +1,256 @@ +//go:build e2e + +package harness + +import ( + "context" + "fmt" + "io" + "os/exec" + "strings" + "time" + + "github.com/docker/docker/api/types/container" + "github.com/testcontainers/testcontainers-go" + tcexec "github.com/testcontainers/testcontainers-go/exec" +) + +const ( + clientDockerfile = "e2e/harness/Dockerfile.client" + // defaultClientImage is the local tag the client is built under from + // clientDockerfile. Override with NB_E2E_CLIENT_IMAGE: a value with a "/" is + // pulled as a published image; a bare tag is built under that name. + defaultClientImage = "netbird-client:e2e" + clientAlias = "client" + curlImage = "curlimages/curl:latest" +) + +// Client is a running NetBird client container joined to the combined server. +type Client struct { + container testcontainers.Container +} + +// StartClient builds the client image and runs it on the combined server's +// network, joining via the given setup key. The image entrypoint brings the +// daemon up automatically; callers wait for connectivity with WaitConnected / +// WaitProxyPeer. +func StartClient(ctx context.Context, c *Combined, setupKey string) (*Client, error) { + root, err := repoRoot() + if err != nil { + return nil, err + } + clientImage, err := resolveImage(ctx, root, "NB_E2E_CLIENT_IMAGE", defaultClientImage, clientDockerfile) + if err != nil { + return nil, err + } + + req := testcontainers.ContainerRequest{ + Image: clientImage, + Networks: []string{c.network.Name}, + NetworkAliases: map[string][]string{c.network.Name: {clientAlias}}, + Env: map[string]string{ + "NB_MANAGEMENT_URL": combinedExposedURL, + "NB_SETUP_KEY": setupKey, + "NB_LOG_LEVEL": "info", + // Match the proxy: the combined relay is WebSocket-only, so the + // client must use WS transport to keep a stable relay link to it. + "NB_RELAY_TRANSPORT": "ws", + }, + HostConfigModifier: func(hc *container.HostConfig) { + hc.CapAdd = append(hc.CapAdd, "NET_ADMIN", "SYS_ADMIN", "SYS_RESOURCE") + }, + } + + ctr, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + if err != nil { + return nil, fmt.Errorf("start client container: %w", err) + } + return &Client{container: ctr}, nil +} + +// Restart bounces the client connection (netbird down/up) so it pulls a fresh +// network map — the documented workaround for a freshly-joined client not yet +// seeing a synthesized agent-network service. +func (cl *Client) Restart(ctx context.Context) error { + if _, _, err := cl.container.Exec(ctx, []string{"netbird", "down"}, tcexec.Multiplexed()); err != nil { + return fmt.Errorf("netbird down: %w", err) + } + time.Sleep(2 * time.Second) + code, reader, err := cl.container.Exec(ctx, []string{"netbird", "up"}, tcexec.Multiplexed()) + if err != nil { + return fmt.Errorf("netbird up: %w", err) + } + if code != 0 { + out, _ := io.ReadAll(reader) + return fmt.Errorf("netbird up exited %d: %s", code, string(out)) + } + return nil +} + +// Status returns `netbird status` output from inside the client. +func (cl *Client) Status(ctx context.Context) (string, error) { + code, reader, err := cl.container.Exec(ctx, []string{"netbird", "status"}, tcexec.Multiplexed()) + if err != nil { + return "", err + } + out, _ := io.ReadAll(reader) + if code != 0 { + return string(out), fmt.Errorf("netbird status exited %d", code) + } + return string(out), nil +} + +// WaitConnected polls until the client reports Management: Connected. +func (cl *Client) WaitConnected(ctx context.Context, timeout time.Duration) error { + return cl.pollStatus(ctx, timeout, "Management: Connected") +} + +// WaitProxyPeer polls until the client sees the proxy peer connected (1/1). +func (cl *Client) WaitProxyPeer(ctx context.Context, timeout time.Duration) error { + return cl.pollStatus(ctx, timeout, "1/1 Connected") +} + +func (cl *Client) pollStatus(ctx context.Context, timeout time.Duration, want string) error { + deadline := time.Now().Add(timeout) + var last string + for time.Now().Before(deadline) { + out, _ := cl.Status(ctx) + last = out + if strings.Contains(out, want) { + return nil + } + time.Sleep(3 * time.Second) + } + return fmt.Errorf("timed out waiting for %q; last status:\n%s", want, last) +} + +// ResolveProxyIP resolves the agent-network endpoint to the proxy peer's +// NetBird IP from inside the client (via magic DNS). +func (cl *Client) ResolveProxyIP(ctx context.Context, endpoint string) (string, error) { + code, reader, err := cl.container.Exec(ctx, []string{"getent", "hosts", endpoint}, tcexec.Multiplexed()) + if err != nil { + return "", err + } + out, _ := io.ReadAll(reader) + if code != 0 { + return "", fmt.Errorf("getent hosts %s exited %d", endpoint, code) + } + fields := strings.Fields(string(out)) + if len(fields) == 0 { + return "", fmt.Errorf("no address for %s", endpoint) + } + return fields[0], nil +} + +// Wire shapes for Chat. +const ( + // WireChat is the OpenAI-compatible /v1/chat/completions shape. + WireChat = "chat" + // WireMessages is the Anthropic /v1/messages shape. + WireMessages = "messages" + // WireVertex is the Anthropic-on-Vertex rawPredict shape: the client posts + // the full Vertex model path and the proxy mints the SA OAuth token. + WireVertex = "vertex" +) + +// Chat issues a chat-completion POST to the agent-network endpoint over the +// client's tunnel, returning the HTTP status and response body. kind selects +// the wire shape: WireChat (OpenAI) or WireMessages (Anthropic). A non-empty +// sessionID is sent as the universal x-session-id header the proxy records. +func (cl *Client) Chat(ctx context.Context, endpoint, proxyIP, kind, model, prompt, sessionID string) (int, string, error) { + var path, body string + var headers []string + switch kind { + case WireMessages: + path = "/v1/messages" + headers = []string{"anthropic-version: 2023-06-01"} + body = fmt.Sprintf(`{"model":%q,"max_tokens":64,"messages":[{"role":"user","content":%q}]}`, model, prompt) + default: + path = "/v1/chat/completions" + body = fmt.Sprintf(`{"model":%q,"messages":[{"role":"user","content":%q}]}`, model, prompt) + } + return cl.post(ctx, endpoint, proxyIP, path, body, withSessionID(headers, sessionID)) +} + +// Vertex issues an Anthropic-on-Vertex rawPredict POST over the tunnel. Unlike +// Chat, the model is carried in the request path (project/region/model), so the +// proxy routes by path and mints the service-account OAuth token; the body uses +// the Vertex anthropic_version rather than a model field. A non-empty sessionID +// is sent as the universal x-session-id header the proxy records. +func (cl *Client) Vertex(ctx context.Context, endpoint, proxyIP, project, region, model, prompt, sessionID string) (int, string, error) { + path := fmt.Sprintf("/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", project, region, model) + body := fmt.Sprintf(`{"anthropic_version":"vertex-2023-10-16","max_tokens":64,"messages":[{"role":"user","content":%q}]}`, prompt) + return cl.post(ctx, endpoint, proxyIP, path, body, withSessionID(nil, sessionID)) +} + +// withSessionID appends the x-session-id header when sessionID is non-empty. +func withSessionID(headers []string, sessionID string) []string { + if sessionID == "" { + return headers + } + return append(headers, "x-session-id: "+sessionID) +} + +// post runs curl in a throwaway container sharing the client's network +// namespace so the request traverses the WireGuard tunnel, pinning the endpoint +// to the proxy IP. It returns the HTTP status and response body. +func (cl *Client) post(ctx context.Context, endpoint, proxyIP, path, body string, extraHeaders []string) (int, string, error) { + url := "https://" + endpoint + path + args := []string{ + "run", "--rm", + "--network", "container:" + cl.container.GetContainerID(), + curlImage, + "-sk", "--connect-timeout", "5", "--max-time", "90", + "--resolve", endpoint + ":443:" + proxyIP, + "-o", "/dev/stderr", "-w", "%{http_code}", + "-X", "POST", url, + "-H", "Content-Type: application/json", + } + for _, h := range extraHeaders { + args = append(args, "-H", h) + } + args = append(args, "--data", body) + cmd := exec.CommandContext(ctx, "docker", args...) + // -w writes the status code to stdout; -o /dev/stderr writes the body to + // stderr so we can capture both separately. + var stdout, stderr strings.Builder + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return 0, stderr.String(), fmt.Errorf("curl through tunnel: %w", err) + } + + code := 0 + _, _ = fmt.Sscanf(strings.TrimSpace(stdout.String()), "%d", &code) + return code, stderr.String(), nil +} + +// Logs returns the client container logs, for diagnostics on failure. +func (cl *Client) Logs(ctx context.Context) string { + return containerLogs(ctx, cl.container) +} + +// Terminate stops the client container. +func (cl *Client) Terminate(ctx context.Context) error { + if cl.container == nil { + return nil + } + return cl.container.Terminate(ctx) +} + +// containerLogs reads up to 256 KiB of a container's logs for diagnostics. +func containerLogs(ctx context.Context, c testcontainers.Container) string { + if c == nil { + return "" + } + r, err := c.Logs(ctx) + if err != nil { + return fmt.Sprintf("", err) + } + defer r.Close() + b, _ := io.ReadAll(io.LimitReader(r, 256<<10)) + return string(b) +} diff --git a/e2e/harness/combined.go b/e2e/harness/combined.go new file mode 100644 index 000000000..a6f43a139 --- /dev/null +++ b/e2e/harness/combined.go @@ -0,0 +1,243 @@ +//go:build e2e + +package harness + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/docker/docker/api/types/container" + "github.com/docker/go-connections/nat" + "github.com/testcontainers/testcontainers-go" + tcexec "github.com/testcontainers/testcontainers-go/exec" + "github.com/testcontainers/testcontainers-go/network" + "github.com/testcontainers/testcontainers-go/wait" + + "github.com/netbirdio/netbird/shared/management/client/rest" +) + +const ( + combinedDockerfile = "combined/Dockerfile.multistage" + // defaultCombinedImage is the local tag the combined server is built under + // from combinedDockerfile, so the e2e exercises this branch's code. Override + // with NB_E2E_COMBINED_IMAGE: a value containing a "/" is pulled as a + // published image; a bare tag is built under that name instead. + defaultCombinedImage = "netbird-combined:e2e" + combinedHTTPPort = "8080/tcp" + + // combinedAlias is the combined server's network alias AND the deployment + // domain. The working manual setup uses a single NETBIRD_DOMAIN for the + // management exposed address, the proxy domain, and the agent-network + // cluster — so we mirror that: peers reach management/signal/relay at this + // name, the proxy registers this as its cluster, and the agent-network + // endpoint is .. + combinedAlias = "netbird.local" + combinedExposedURL = "http://" + combinedAlias + ":8080" + + // containerIssuer is the embedded IdP issuer, used only for internal JWT + // validation (peers authenticate with setup keys / proxy tokens, not OIDC), + // so the in-container localhost address is fine. + containerIssuer = "http://localhost:8080/oauth2" +) + +// Combined is a running combined NetBird server (management + signal + relay + +// STUN + embedded IdP) plus the connection details tests need. It owns the +// shared docker network that the proxy and client containers join. +type Combined struct { + container testcontainers.Container + network *testcontainers.DockerNetwork + // BaseURL is the host-reachable management API root, e.g. http://127.0.0.1:51234. + BaseURL string + // PAT is the admin Personal Access Token minted via Bootstrap. + PAT string + + api *rest.Client + workDir string +} + +// StartCombined builds the combined server from its multistage Dockerfile and +// boots it with setup-PAT enabled on a fresh shared network, returning once the +// API is serving. The caller still owns minting the admin PAT via Bootstrap. +func StartCombined(ctx context.Context) (*Combined, error) { + root, err := repoRoot() + if err != nil { + return nil, err + } + + combinedImage, err := resolveImage(ctx, root, "NB_E2E_COMBINED_IMAGE", defaultCombinedImage, combinedDockerfile) + if err != nil { + return nil, err + } + + net, err := network.New(ctx) + if err != nil { + return nil, fmt.Errorf("create shared network: %w", err) + } + + // Work dir under /tmp so Docker Desktop file sharing (which excludes + // macOS's /var/folders TMPDIR) can bind-mount it. + workDir, err := os.MkdirTemp("/tmp", "nb-e2e-combined-*") + if err != nil { + _ = net.Remove(ctx) + return nil, fmt.Errorf("create work dir: %w", err) + } + + cfg := fmt.Sprintf(combinedConfigYAML, combinedExposedURL, containerIssuer) + if err := os.WriteFile(filepath.Join(workDir, "config.yaml"), []byte(cfg), 0o644); err != nil { //nolint:gosec // non-secret config, bind-mounted and read by the container + _ = net.Remove(ctx) + return nil, fmt.Errorf("write combined config: %w", err) + } + if err := os.MkdirAll(filepath.Join(workDir, "data"), 0o755); err != nil { + _ = net.Remove(ctx) + return nil, fmt.Errorf("create datadir: %w", err) + } + + req := testcontainers.ContainerRequest{ + Image: combinedImage, + ExposedPorts: []string{combinedHTTPPort}, + Networks: []string{net.Name}, + NetworkAliases: map[string][]string{net.Name: {combinedAlias}}, + Env: map[string]string{ + "NB_SETUP_PAT_ENABLED": "true", + // Skip the GeoLite DB download — it blocks startup and agent-network + // ingest doesn't use geolocation. + "NB_DISABLE_GEOLOCATION": "true", + }, + Cmd: []string{"--config", "/nb/config.yaml"}, + HostConfigModifier: func(hc *container.HostConfig) { + hc.Binds = append(hc.Binds, workDir+":/nb") + }, + WaitingFor: wait.ForHTTP("/api/instance"). + WithPort(combinedHTTPPort). + WithStatusCodeMatcher(func(status int) bool { return status == 200 }). + WithStartupTimeout(120 * time.Second), + } + + c, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + if err != nil { + _ = net.Remove(ctx) + return nil, fmt.Errorf("start combined container: %w", err) + } + + host, err := c.Host(ctx) + if err != nil { + _ = c.Terminate(ctx) + _ = net.Remove(ctx) + return nil, fmt.Errorf("container host: %w", err) + } + mapped, err := c.MappedPort(ctx, nat.Port(combinedHTTPPort)) + if err != nil { + _ = c.Terminate(ctx) + _ = net.Remove(ctx) + return nil, fmt.Errorf("mapped port: %w", err) + } + + return &Combined{ + container: c, + network: net, + BaseURL: fmt.Sprintf("http://%s:%s", host, mapped.Port()), + workDir: workDir, + }, nil +} + +// resolveImage returns the image to run for a component. By default it builds +// the image from the repo Dockerfile under localTag, so the e2e exercises the +// branch's code. The env override changes this: a value containing a "/" is a +// registry reference that testcontainers pulls (e.g. to test a published +// release); a bare tag is built under that name instead. +func resolveImage(ctx context.Context, root, envKey, localTag, dockerfile string) (string, error) { + if v := os.Getenv(envKey); v != "" { + if strings.Contains(v, "/") { + return v, nil + } + localTag = v + } + if err := buildImage(ctx, root, dockerfile, localTag); err != nil { + return "", err + } + return localTag, nil +} + +// buildImage builds an image from a repo Dockerfile via buildx with BuildKit, so +// the Dockerfile cache mounts are honored and unchanged layers are reused. The +// result is loaded into the docker image store so testcontainers runs it by tag. +// When NB_E2E_BUILDX_CACHE names a directory (CI, with a container-driver +// builder from docker/setup-buildx-action), layer cache is read from and written +// to it as a local cache so actions/cache can persist it across runs; the Go +// compile itself still re-runs, as BuildKit mount caches can't be exported. +func buildImage(ctx context.Context, root, dockerfile, tag string) error { + args := []string{"buildx", "build", "-f", dockerfile, "-t", tag, "--load"} + if dir := os.Getenv("NB_E2E_BUILDX_CACHE"); dir != "" { + args = append(args, + "--cache-from", "type=local,src="+dir, + "--cache-to", "type=local,dest="+dir+",mode=max", + ) + } + args = append(args, ".") + + cmd := exec.CommandContext(ctx, "docker", args...) + cmd.Dir = root + cmd.Env = append(os.Environ(), "DOCKER_BUILDKIT=1") + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("build image %s: %w\n%s", tag, err, string(out)) + } + return nil +} + +// CreateProxyTokenCLI mints a proxy access token via the server's `token +// create` CLI inside the container — the same path the manual install uses. +// This yields a GLOBAL (account-less) token, so the proxy serves the whole +// cluster (SynthesizeServicesForCluster); an account-scoped REST token instead +// drives the per-account path. Returns the plaintext token. +func (c *Combined) CreateProxyTokenCLI(ctx context.Context, name string) (string, error) { + code, reader, err := c.container.Exec(ctx, + []string{"/go/bin/netbird-server", "token", "create", "--name", name, "--config", "/nb/config.yaml"}, + tcexec.Multiplexed()) + if err != nil { + return "", fmt.Errorf("exec token create: %w", err) + } + out, _ := io.ReadAll(reader) + if code != 0 { + return "", fmt.Errorf("token create exited %d: %s", code, string(out)) + } + for _, line := range strings.Split(string(out), "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "Token:") { + tok := strings.TrimSpace(strings.TrimPrefix(line, "Token:")) + if tok != "" { + return tok, nil + } + } + } + return "", fmt.Errorf("token not found in CLI output: %s", string(out)) +} + +// Logs returns the combined server container logs, for diagnostics. +func (c *Combined) Logs(ctx context.Context) string { + return containerLogs(ctx, c.container) +} + +// Terminate stops the container, removes the shared network, and cleans the +// work dir. +func (c *Combined) Terminate(ctx context.Context) error { + var err error + if c.container != nil { + err = c.container.Terminate(ctx) + } + if c.network != nil { + _ = c.network.Remove(ctx) + } + if c.workDir != "" { + _ = os.RemoveAll(c.workDir) + } + return err +} diff --git a/e2e/harness/config.go b/e2e/harness/config.go new file mode 100644 index 000000000..b4bed60a2 --- /dev/null +++ b/e2e/harness/config.go @@ -0,0 +1,26 @@ +//go:build e2e + +package harness + +// combinedConfigYAML is a minimal combined-server config for tests: plain HTTP +// on :8080 (no TLS cert configured → the server serves HTTP and expects to sit +// behind a reverse proxy, which is exactly what we want for in-cluster tests), +// embedded IdP, local signal/relay/STUN, and a sqlite store under the mounted +// data dir. exposedAddress is the address peers use to reach this container; it +// is overridden per-run so the value matches the container's network alias. +const combinedConfigYAML = `server: + listenAddress: ":8080" + exposedAddress: "%s" + healthcheckAddress: ":9000" + metricsPort: 9090 + logLevel: "info" + logFile: "console" + authSecret: "e2e-relay-secret" + dataDir: "/nb/data" + disableAnonymousMetrics: true + disableGeoliteUpdate: true + auth: + issuer: "%s" + store: + engine: "sqlite" +` diff --git a/e2e/harness/doc.go b/e2e/harness/doc.go new file mode 100644 index 000000000..937d8e664 --- /dev/null +++ b/e2e/harness/doc.go @@ -0,0 +1,13 @@ +//go:build e2e + +// Package harness provides a self-contained, OIDC-free way to stand up NetBird +// components in containers for end-to-end tests. It is feature-agnostic: any +// suite can ask for a live management server (with an admin PAT minted through +// the unauthenticated /api/setup bootstrap) and, later, a proxy and client. +// +// The harness compiles each component once in a cached builder container and +// mounts the resulting binary into a slim runtime container, so iterating on a +// branch doesn't pay a full image rebuild per run. Everything is gated behind +// the `e2e` build tag so normal builds and unit tests never pull in +// testcontainers. +package harness diff --git a/e2e/harness/paths.go b/e2e/harness/paths.go new file mode 100644 index 000000000..d7df6bbfa --- /dev/null +++ b/e2e/harness/paths.go @@ -0,0 +1,29 @@ +//go:build e2e + +package harness + +import ( + "fmt" + "os" + "path/filepath" +) + +// repoRoot walks up from the working directory to the module root (the +// directory holding go.mod), so the Docker build context is correct no matter +// which package the test runs from. +func repoRoot() (string, error) { + dir, err := os.Getwd() + if err != nil { + return "", err + } + for { + if _, statErr := os.Stat(filepath.Join(dir, "go.mod")); statErr == nil { + return dir, nil + } + parent := filepath.Dir(dir) + if parent == dir { + return "", fmt.Errorf("go.mod not found above %s", dir) + } + dir = parent + } +} diff --git a/e2e/harness/proxy.go b/e2e/harness/proxy.go new file mode 100644 index 000000000..8db2c140f --- /dev/null +++ b/e2e/harness/proxy.go @@ -0,0 +1,122 @@ +//go:build e2e + +package harness + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/docker/docker/api/types/container" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + proxyDockerfile = "proxy/Dockerfile.multistage" + // defaultProxyImage is the local tag the reverse proxy is built under from + // proxyDockerfile. Override with NB_E2E_PROXY_IMAGE: a value with a "/" is + // pulled as a published image; a bare tag is built under that name. + defaultProxyImage = "netbird-reverse-proxy:e2e" + proxyAlias = "proxy" + + // AgentNetworkCluster is the proxy cluster the e2e provider bootstraps and + // the proxy serves. It must equal the management's exposed domain + // (combinedAlias) — the working manual setup uses one NETBIRD_DOMAIN for + // both. The agent-network endpoint is .. + AgentNetworkCluster = combinedAlias +) + +// Proxy is a running agent-network gateway (netbird proxy) container. +type Proxy struct { + container testcontainers.Container + workDir string +} + +// StartProxy builds the proxy image and runs it on the combined server's +// network, registered via the given account proxy token and serving the +// AgentNetworkCluster over a self-signed wildcard cert. It does not wait for +// peer connectivity — callers poll management for the proxy peer. +func StartProxy(ctx context.Context, c *Combined, proxyToken string) (*Proxy, error) { + root, err := repoRoot() + if err != nil { + return nil, err + } + proxyImage, err := resolveImage(ctx, root, "NB_E2E_PROXY_IMAGE", defaultProxyImage, proxyDockerfile) + if err != nil { + return nil, err + } + + workDir, err := os.MkdirTemp("/tmp", "nb-e2e-proxy-*") + if err != nil { + return nil, fmt.Errorf("create proxy work dir: %w", err) + } + // MkdirTemp creates the dir 0700; widen it so the non-root proxy container + // can traverse the bind-mounted cert dir on Linux CI runners. + if err := os.Chmod(workDir, 0o755); err != nil { //nolint:gosec // throwaway e2e cert dir, must be traversable by the proxy container uid + return nil, fmt.Errorf("chmod proxy cert dir: %w", err) + } + if err := writeSelfSignedCert(workDir, []string{"*." + AgentNetworkCluster, AgentNetworkCluster}); err != nil { + return nil, err + } + + req := testcontainers.ContainerRequest{ + Image: proxyImage, + Networks: []string{c.network.Name}, + NetworkAliases: map[string][]string{c.network.Name: {proxyAlias}}, + Env: map[string]string{ + "NB_PROXY_TOKEN": proxyToken, + "NB_PROXY_MANAGEMENT_ADDRESS": combinedExposedURL, + "NB_PROXY_DOMAIN": AgentNetworkCluster, + "NB_PROXY_ADDRESS": ":443", + "NB_PROXY_CERTIFICATE_DIRECTORY": "/certs", + "NB_PROXY_HEALTH_ADDRESS": ":8081", + "NB_PROXY_LOG_LEVEL": "debug", + "NB_PROXY_PRIVATE": "true", + // Management is plain HTTP in-cluster, so allow the proxy token to + // ride a non-TLS gRPC connection. + "NB_PROXY_ALLOW_INSECURE": "true", + // The combined server multiplexes the relay over WebSocket on :8080 + // (no QUIC listener). The proxy's embedded relay client defaults to + // QUIC, which fails here and flaps the relay link, churning the + // proxy peer so it never stably registers. Force WS transport. + "NB_RELAY_TRANSPORT": "ws", + // Trace the embedded client (relay / signal / handshake) so + // peer-registration issues are visible in the proxy logs. + "NB_PROXY_CLIENT_LOG_LEVEL": "trace", + }, + HostConfigModifier: func(hc *container.HostConfig) { + hc.Binds = append(hc.Binds, workDir+":/certs") + hc.CapAdd = append(hc.CapAdd, "NET_ADMIN", "SYS_ADMIN", "SYS_RESOURCE", "NET_BIND_SERVICE") + }, + WaitingFor: wait.ForLog("Initial mapping sync complete").WithStartupTimeout(90 * time.Second), + } + + ctr, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + if err != nil { + return nil, fmt.Errorf("start proxy container: %w", err) + } + + return &Proxy{container: ctr, workDir: workDir}, nil +} + +// Logs returns the proxy container logs, for diagnostics on failure. +func (p *Proxy) Logs(ctx context.Context) string { + return containerLogs(ctx, p.container) +} + +// Terminate stops the proxy container and cleans its work dir. +func (p *Proxy) Terminate(ctx context.Context) error { + var err error + if p.container != nil { + err = p.container.Terminate(ctx) + } + if p.workDir != "" { + _ = os.RemoveAll(p.workDir) + } + return err +} diff --git a/go.mod b/go.mod index 9a57de1c9..e1c762607 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/DeRuina/timberjack v1.4.2 github.com/awnumar/memguard v0.23.0 github.com/aws/aws-sdk-go-v2 v1.38.3 + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 github.com/aws/aws-sdk-go-v2/config v1.31.6 github.com/aws/aws-sdk-go-v2/credentials v1.18.10 github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3 @@ -49,6 +50,8 @@ require ( github.com/crowdsecurity/go-cs-bouncer v0.0.21 github.com/dexidp/dex v2.13.0+incompatible github.com/dexidp/dex/api/v2 v2.4.0 + github.com/docker/docker v28.0.1+incompatible + github.com/docker/go-connections v0.6.0 github.com/ebitengine/purego v0.8.4 github.com/eko/gocache/lib/v4 v4.2.0 github.com/eko/gocache/store/go_cache/v4 v4.2.2 @@ -158,7 +161,6 @@ require ( github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/awnumar/memcall v0.4.0 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6 // indirect @@ -188,8 +190,6 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect - github.com/docker/docker v28.0.1+incompatible // indirect - github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fredbi/uri v1.1.1 // indirect diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 0d8fb3c47..f1b1832d2 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -116,6 +116,24 @@ func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, p c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer) } +// injectAllProxyPolicies prepares an account for the per-peer network-map +// computation. It prepends the in-memory agent-network services synthesised +// from the account's current provider/policy state to account.Services so +// the existing InjectProxyPolicies + injectPrivateServicePolicies walks pick +// them up alongside persisted reverse-proxy services. Synthesised services +// are never persisted; the account is loaded fresh per cycle so re-prepending +// is safe and idempotent. Accounts without agent-network providers get an +// empty synth slice — no behaviour change. +func (c *Controller) injectAllProxyPolicies(ctx context.Context, account *types.Account) { + synth, err := c.repo.SynthesizeAgentNetworkServices(ctx, account.Id) + if err != nil { + log.WithContext(ctx).Warnf("synthesise agent-network services for account %s: %v", account.Id, err) + } else if len(synth) > 0 { + account.Services = append(synth, account.Services...) + } + account.InjectProxyPolicies(ctx) +} + func (c *Controller) CountStreams() int { return c.peersUpdateManager.CountStreams() } @@ -150,7 +168,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin var wg sync.WaitGroup semaphore := make(chan struct{}, 10) - account.InjectProxyPolicies(ctx) + c.injectAllProxyPolicies(ctx, account) dnsCache := &cache.DNSConfigCache{} dnsDomain := c.GetDNSDomain(account.Settings) peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain) @@ -281,7 +299,15 @@ func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID s var wg sync.WaitGroup semaphore := make(chan struct{}, 10) - account.InjectProxyPolicies(ctx) + // The affected-peer path MUST mirror sendUpdateAccountPeers (line 171) + // here: injectAllProxyPolicies prepends the synthesised agent-network + // services BEFORE InjectProxyPolicies + private-service policies run. + // Previously this path called only account.InjectProxyPolicies, which + // skipped the synth-services prepend — so peer-level changes + // (proxy restart, embedded peer connect/disconnect) propagated a + // network map that omitted the synth DNS zone, and the agent kept + // resolving against the stale or absent record. + c.injectAllProxyPolicies(ctx, account) dnsCache := &cache.DNSConfigCache{} dnsDomain := c.GetDNSDomain(account.Settings) peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain) @@ -399,7 +425,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe return fmt.Errorf("failed to get validated peers: %v", err) } - account.InjectProxyPolicies(ctx) + c.injectAllProxyPolicies(ctx, account) dnsCache := &cache.DNSConfigCache{} dnsDomain := c.GetDNSDomain(account.Settings) peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain) @@ -603,7 +629,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr return nil, nil, 0, err } - account.InjectProxyPolicies(ctx) + c.injectAllProxyPolicies(ctx, account) approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { @@ -874,7 +900,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N return nil, err } - account.InjectProxyPolicies(ctx) + c.injectAllProxyPolicies(ctx, account) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() groupIDToUserIDs := account.GetActiveGroupUsers() diff --git a/management/internals/controllers/network_map/controller/repository.go b/management/internals/controllers/network_map/controller/repository.go index caef362cb..c0fcefc7d 100644 --- a/management/internals/controllers/network_map/controller/repository.go +++ b/management/internals/controllers/network_map/controller/repository.go @@ -3,7 +3,9 @@ package controller import ( "context" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork" "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -16,6 +18,10 @@ type Repository interface { GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) + // SynthesizeAgentNetworkServices returns the in-memory reverse-proxy + // services synthesised from the account's agent-network provider/policy + // state. Empty for accounts without agent-network providers. + SynthesizeAgentNetworkServices(ctx context.Context, accountID string) ([]*service.Service, error) } type repository struct { @@ -50,6 +56,10 @@ func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID s return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) } +func (r *repository) SynthesizeAgentNetworkServices(ctx context.Context, accountID string) ([]*service.Service, error) { + return agentnetwork.SynthesizeServices(ctx, r.store, accountID) +} + func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) { return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID) } diff --git a/management/internals/modules/agentnetwork/accesslog_ingest.go b/management/internals/modules/agentnetwork/accesslog_ingest.go new file mode 100644 index 000000000..59e53efa2 --- /dev/null +++ b/management/internals/modules/agentnetwork/accesslog_ingest.go @@ -0,0 +1,215 @@ +package agentnetwork + +import ( + "context" + "math" + "strconv" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" + "github.com/netbirdio/netbird/management/server/store" +) + +// Metadata keys the proxy stamps on agent-network access-log entries. These +// mirror the constants in proxy/internal/middleware/keys.go and form the wire +// contract between the proxy and management; management flattens them into +// queryable columns. Keep in sync with the proxy side. +const ( + metaKeyProvider = "llm.provider" + metaKeyModel = "llm.model" + metaKeyResolvedProviderID = "llm.resolved_provider_id" + metaKeySelectedPolicyID = "llm.selected_policy_id" + metaKeyPolicyDecision = "llm_policy.decision" + metaKeyPolicyReason = "llm_policy.reason" + metaKeyInputTokens = "llm.input_tokens" //nolint:gosec // metadata key name, not a credential + metaKeyOutputTokens = "llm.output_tokens" //nolint:gosec // metadata key name, not a credential + metaKeyTotalTokens = "llm.total_tokens" //nolint:gosec // metadata key name, not a credential + metaKeyCostUSDTotal = "cost.usd_total" + metaKeyStream = "llm.stream" + metaKeySessionID = "llm.session_id" + metaKeyAuthorisingGroups = "llm.authorising_groups" + metaKeyRequestPrompt = "llm.request_prompt" + metaKeyResponseCompletion = "llm.response_completion" +) + +// IngestAccessLog flattens the metadata-bearing reverse-proxy access-log entry +// and persists it in the dedicated agent-network tables (instead of the shared +// reverse-proxy table), in two parts: +// +// - The stripped usage record is written unconditionally — usage/cost is +// collected on every request regardless of the account's log-collection +// toggle (the proxy ships a usage-only entry when logging is disabled). +// - The full access-log row (with request detail + prompt) is written only +// when the account's EnableLogCollection setting is on. This setting read +// is the authoritative gate; the proxy-side strip is defense in depth. +func IngestAccessLog(ctx context.Context, s store.Store, logEntry *accesslogs.AccessLogEntry) error { + entry, groups := flattenAccessLog(logEntry) + + usage, usageGroups := usageFromFlattenedLog(entry, groups) + if err := s.CreateAgentNetworkUsage(ctx, usage, usageGroups); err != nil { + log.WithContext(ctx).WithFields(log.Fields{ + "account_id": entry.AccountID, + "model": entry.Model, + }).Errorf("failed to save agent-network usage: %v", err) + return err + } + + settings, err := s.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, entry.AccountID) + if err != nil { + // No settings row (or a transient read error) means we can't confirm + // log collection is enabled — usage is already saved, so skip the full + // row rather than fail the whole ingest. + log.WithContext(ctx).Debugf("skipping full agent-network access-log row for account %s: %v", entry.AccountID, err) + return nil + } + if !settings.EnableLogCollection { + return nil + } + + if err := s.CreateAgentNetworkAccessLog(ctx, entry, groups); err != nil { + log.WithContext(ctx).WithFields(log.Fields{ + "account_id": entry.AccountID, + "service_id": entry.ServiceID, + "model": entry.Model, + "status": entry.StatusCode, + }).Errorf("failed to save agent-network access log: %v", err) + return err + } + return nil +} + +// flattenAccessLog converts a reverse-proxy AccessLogEntry (whose LLM +// dimensions live in the opaque Metadata map) into the flattened +// agent-network row + authorising-group child rows. +func flattenAccessLog(e *accesslogs.AccessLogEntry) (*types.AgentNetworkAccessLog, []types.AgentNetworkAccessLogGroup) { + meta := e.Metadata + + var sourceIP string + if e.GeoLocation.ConnectionIP != nil { + sourceIP = e.GeoLocation.ConnectionIP.String() + } + + entry := &types.AgentNetworkAccessLog{ + ID: e.ID, + AccountID: e.AccountID, + ServiceID: e.ServiceID, + Timestamp: e.Timestamp, + UserID: e.UserId, + SourceIP: sourceIP, + Method: e.Method, + Host: e.Host, + Path: e.Path, + Duration: e.Duration, + StatusCode: e.StatusCode, + AuthMethod: e.AuthMethodUsed, + BytesUpload: e.BytesUpload, + BytesDownload: e.BytesDownload, + + Provider: meta[metaKeyProvider], + Model: meta[metaKeyModel], + SessionID: meta[metaKeySessionID], + ResolvedProviderID: meta[metaKeyResolvedProviderID], + SelectedPolicyID: meta[metaKeySelectedPolicyID], + Decision: meta[metaKeyPolicyDecision], + DenyReason: meta[metaKeyPolicyReason], + InputTokens: parseMetaInt(meta, metaKeyInputTokens), + OutputTokens: parseMetaInt(meta, metaKeyOutputTokens), + TotalTokens: parseMetaInt(meta, metaKeyTotalTokens), + CostUSD: parseMetaFloat(meta, metaKeyCostUSDTotal), + Stream: parseMetaBool(meta, metaKeyStream), + RequestPrompt: meta[metaKeyRequestPrompt], + ResponseCompletion: meta[metaKeyResponseCompletion], + } + + var groups []types.AgentNetworkAccessLogGroup + for _, gid := range parseGroupCSV(meta[metaKeyAuthorisingGroups]) { + groups = append(groups, types.AgentNetworkAccessLogGroup{ + LogID: entry.ID, + GroupID: gid, + AccountID: entry.AccountID, + }) + } + return entry, groups +} + +// usageFromFlattenedLog derives the stripped usage record (and its group child +// rows) from an already-flattened access-log entry. The usage row shares the +// log's ID so the two correlate. +func usageFromFlattenedLog(e *types.AgentNetworkAccessLog, groups []types.AgentNetworkAccessLogGroup) (*types.AgentNetworkUsage, []types.AgentNetworkUsageGroup) { + usage := &types.AgentNetworkUsage{ + ID: e.ID, + AccountID: e.AccountID, + Timestamp: e.Timestamp, + UserID: e.UserID, + ResolvedProviderID: e.ResolvedProviderID, + Provider: e.Provider, + Model: e.Model, + SessionID: e.SessionID, + InputTokens: e.InputTokens, + OutputTokens: e.OutputTokens, + TotalTokens: e.TotalTokens, + CostUSD: e.CostUSD, + } + + usageGroups := make([]types.AgentNetworkUsageGroup, 0, len(groups)) + for _, g := range groups { + usageGroups = append(usageGroups, types.AgentNetworkUsageGroup{ + UsageID: usage.ID, + GroupID: g.GroupID, + AccountID: g.AccountID, + }) + } + return usage, usageGroups +} + +// parseMetaInt parses a non-negative token count. Negative or unparseable +// values are clamped to 0 so a malformed metric can't persist a negative +// counter. +func parseMetaInt(meta map[string]string, key string) int64 { + if v, err := strconv.ParseInt(strings.TrimSpace(meta[key]), 10, 64); err == nil && v >= 0 { + return v + } + return 0 +} + +// parseMetaFloat parses a non-negative, finite cost. Negative, NaN, Inf, or +// unparseable values are clamped to 0 so a malformed metric can't poison the +// stored cost. +func parseMetaFloat(meta map[string]string, key string) float64 { + if v, err := strconv.ParseFloat(strings.TrimSpace(meta[key]), 64); err == nil && v >= 0 && !math.IsInf(v, 0) { + return v + } + return 0 +} + +func parseMetaBool(meta map[string]string, key string) bool { + v, _ := strconv.ParseBool(strings.TrimSpace(meta[key])) + return v +} + +// parseGroupCSV splits the comma-separated authorising-group id list the proxy +// emits, trimming blanks and de-duplicating. Dedup matters because the group +// rows are keyed by (log_id, group_id) / (usage_id, group_id): a repeated id +// in the CSV would otherwise produce a duplicate primary key and fail the +// insert transaction. +func parseGroupCSV(raw string) []string { + if raw == "" { + return nil + } + parts := strings.Split(raw, ",") + out := make([]string, 0, len(parts)) + seen := make(map[string]struct{}, len(parts)) + for _, p := range parts { + if p = strings.TrimSpace(p); p != "" { + if _, dup := seen[p]; dup { + continue + } + seen[p] = struct{}{} + out = append(out, p) + } + } + return out +} diff --git a/management/internals/modules/agentnetwork/accesslog_ingest_realstore_test.go b/management/internals/modules/agentnetwork/accesslog_ingest_realstore_test.go new file mode 100644 index 000000000..431ce680e --- /dev/null +++ b/management/internals/modules/agentnetwork/accesslog_ingest_realstore_test.go @@ -0,0 +1,124 @@ +package agentnetwork + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" + "github.com/netbirdio/netbird/management/server/store" +) + +// newIngestTestEntry builds an agent-network reverse-proxy access-log entry whose +// LLM dimensions live in the opaque Metadata map, as the proxy ships it. +func newIngestTestEntry() *accesslogs.AccessLogEntry { + return &accesslogs.AccessLogEntry{ + ID: "log-1", + AccountID: testAccountID, + ServiceID: "svc-1", + Timestamp: time.Now().UTC(), + Method: "POST", + Host: testEndpoint, + Path: "/v1/chat/completions", + StatusCode: 200, + UserId: "user-1", + AgentNetwork: true, + Metadata: map[string]string{ + metaKeyProvider: "openai", + metaKeyModel: "gpt-5.4", + metaKeyResolvedProviderID: "prov-1", + metaKeySessionID: "sess-1", + metaKeyInputTokens: "100", + metaKeyOutputTokens: "50", + metaKeyTotalTokens: "150", + metaKeyCostUSDTotal: "0.0123", + metaKeyStream: "true", + metaKeyRequestPrompt: "hello", + metaKeyResponseCompletion: "world", + // repeated id must be de-duplicated before the group rows insert. + metaKeyAuthorisingGroups: "grp-eng,grp-eng,grp-ops", + }, + } +} + +// TestIngestAccessLog_RealStore_LogCollectionOff persists the usage ledger +// unconditionally but skips the full access-log row when the account hasn't +// opted into log collection. +func TestIngestAccessLog_RealStore_LogCollectionOff(t *testing.T) { + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + settings := newSynthTestSettings() + settings.EnableLogCollection = false + require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings)) + + require.NoError(t, IngestAccessLog(ctx, s, newIngestTestEntry())) + + usage, err := s.GetAgentNetworkUsageRows(ctx, store.LockingStrengthNone, testAccountID, types.AgentNetworkAccessLogFilter{}) + require.NoError(t, err) + require.Len(t, usage, 1, "usage row must be written even with log collection off") + assert.Equal(t, int64(100), usage[0].InputTokens, "input tokens must round-trip from metadata") + assert.Equal(t, int64(50), usage[0].OutputTokens, "output tokens must round-trip from metadata") + assert.InDelta(t, 0.0123, usage[0].CostUSD, 1e-9, "cost must round-trip from metadata") + + logs, _, err := s.GetAgentNetworkAccessLogs(ctx, store.LockingStrengthNone, testAccountID, types.AgentNetworkAccessLogFilter{}) + require.NoError(t, err) + assert.Empty(t, logs, "full access-log row must be skipped while log collection is off") +} + +// TestIngestAccessLog_RealStore_LogCollectionOn writes both the usage ledger and +// the full access-log row once the account opts in, carrying the request detail +// and prompt through. +func TestIngestAccessLog_RealStore_LogCollectionOn(t *testing.T) { + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + settings := newSynthTestSettings() + settings.EnableLogCollection = true + require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings)) + + require.NoError(t, IngestAccessLog(ctx, s, newIngestTestEntry())) + + usage, err := s.GetAgentNetworkUsageRows(ctx, store.LockingStrengthNone, testAccountID, types.AgentNetworkAccessLogFilter{}) + require.NoError(t, err) + require.Len(t, usage, 1, "usage row must be written when log collection is on") + + logs, total, err := s.GetAgentNetworkAccessLogs(ctx, store.LockingStrengthNone, testAccountID, types.AgentNetworkAccessLogFilter{}) + require.NoError(t, err) + require.Equal(t, int64(1), total, "exactly one access-log row expected") + require.Len(t, logs, 1, "full access-log row must be written when log collection is on") + assert.Equal(t, "gpt-5.4", logs[0].Model, "model must flatten from metadata") + assert.Equal(t, "hello", logs[0].RequestPrompt, "prompt must be retained when log collection is on") + assert.Equal(t, "world", logs[0].ResponseCompletion, "completion must be retained when log collection is on") + assert.True(t, logs[0].Stream, "stream flag must flatten from metadata") +} + +func TestParseGroupCSV_DedupAndTrim(t *testing.T) { + assert.Nil(t, parseGroupCSV(""), "empty CSV yields no groups") + assert.Equal(t, []string{"a", "b"}, parseGroupCSV(" a , b , a ,"), + "group CSV must trim, drop blanks, and de-duplicate preserving first-seen order") +} + +func TestParseMetaInt_ClampsNegativeAndJunk(t *testing.T) { + meta := map[string]string{"ok": " 42 ", "neg": "-5", "junk": "abc"} + assert.Equal(t, int64(42), parseMetaInt(meta, "ok"), "valid count parses with surrounding space trimmed") + assert.Equal(t, int64(0), parseMetaInt(meta, "neg"), "negative count clamps to 0") + assert.Equal(t, int64(0), parseMetaInt(meta, "junk"), "unparseable count clamps to 0") + assert.Equal(t, int64(0), parseMetaInt(meta, "missing"), "missing key clamps to 0") +} + +func TestParseMetaFloat_ClampsNegativeInfAndJunk(t *testing.T) { + meta := map[string]string{"ok": "1.5", "neg": "-1", "inf": "Inf", "junk": "x"} + assert.InDelta(t, 1.5, parseMetaFloat(meta, "ok"), 1e-9, "valid cost parses") + assert.Equal(t, float64(0), parseMetaFloat(meta, "neg"), "negative cost clamps to 0") + assert.Equal(t, float64(0), parseMetaFloat(meta, "inf"), "non-finite cost clamps to 0") + assert.Equal(t, float64(0), parseMetaFloat(meta, "junk"), "unparseable cost clamps to 0") +} diff --git a/management/internals/modules/agentnetwork/accesslog_sessions_realstore_test.go b/management/internals/modules/agentnetwork/accesslog_sessions_realstore_test.go new file mode 100644 index 000000000..7d53d7547 --- /dev/null +++ b/management/internals/modules/agentnetwork/accesslog_sessions_realstore_test.go @@ -0,0 +1,343 @@ +package agentnetwork + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/management/server/store" +) + +// baseTime is a fixed reference so session timestamps (and therefore the +// default MAX(timestamp) DESC ordering) are deterministic across runs. +var baseTime = time.Date(2026, 6, 30, 12, 0, 0, 0, time.UTC) + +// accessLogRow builds an agent-network access-log row for the shared test +// account. Functional options tweak the LLM dimensions a given test cares +// about; everything else gets a sane, allow/200 default. +func accessLogRow(id, sessionID string, ts time.Time, opts ...func(*types.AgentNetworkAccessLog)) *types.AgentNetworkAccessLog { + e := &types.AgentNetworkAccessLog{ + ID: id, + AccountID: testAccountID, + ServiceID: "svc-1", + Timestamp: ts, + UserID: "user-1", + SessionID: sessionID, + Method: "POST", + Host: testEndpoint, + Path: "/v1/chat/completions", + StatusCode: 200, + Decision: "allow", + Provider: "openai", + Model: "gpt-5.4", + ResolvedProviderID: "prov-1", + InputTokens: 100, + OutputTokens: 50, + TotalTokens: 150, + CostUSD: 0.01, + } + for _, o := range opts { + o(e) + } + return e +} + +func withUser(u string) func(*types.AgentNetworkAccessLog) { + return func(e *types.AgentNetworkAccessLog) { e.UserID = u } +} + +func withModel(m string) func(*types.AgentNetworkAccessLog) { + return func(e *types.AgentNetworkAccessLog) { e.Model = m } +} + +func withProvider(vendor, resolvedID string) func(*types.AgentNetworkAccessLog) { + return func(e *types.AgentNetworkAccessLog) { + e.Provider = vendor + e.ResolvedProviderID = resolvedID + } +} + +func withDeny(reason string) func(*types.AgentNetworkAccessLog) { + return func(e *types.AgentNetworkAccessLog) { + e.Decision = "deny" + e.DenyReason = reason + e.StatusCode = 403 + } +} + +func withTokens(in, out, total int64, cost float64) func(*types.AgentNetworkAccessLog) { + return func(e *types.AgentNetworkAccessLog) { + e.InputTokens = in + e.OutputTokens = out + e.TotalTokens = total + e.CostUSD = cost + } +} + +func withGroups(gids ...string) func(*types.AgentNetworkAccessLog) { + return func(e *types.AgentNetworkAccessLog) { e.GroupIDs = gids } +} + +// seedAccessLogs writes rows (and their authorising-group child rows) directly +// into the store, bypassing ingest so a test can control every dimension. +func seedAccessLogs(t *testing.T, s store.Store, rows ...*types.AgentNetworkAccessLog) { + t.Helper() + ctx := context.Background() + for _, r := range rows { + var groups []types.AgentNetworkAccessLogGroup + for _, g := range r.GroupIDs { + groups = append(groups, types.AgentNetworkAccessLogGroup{ + LogID: r.ID, + GroupID: g, + AccountID: r.AccountID, + }) + } + require.NoError(t, s.CreateAgentNetworkAccessLog(ctx, r, groups), "seed access-log row %s", r.ID) + } +} + +func newSessionsTestStore(t *testing.T) store.Store { + t.Helper() + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + t.Cleanup(cleanup) + return s +} + +// sessionIDs projects the session ids from a page of session summaries, in +// order, for concise ordering assertions. +func sessionIDs(sessions []*types.AgentNetworkAccessLogSession) []string { + out := make([]string, 0, len(sessions)) + for _, s := range sessions { + out = append(out, s.SessionID) + } + return out +} + +// TestAccessLogSessions_FoldAndAggregate verifies that multiple entries sharing +// a session id fold into one summary with summed usage, distinct +// provider/model lists, a deny rollup, and correct first/last activity bounds. +func TestAccessLogSessions_FoldAndAggregate(t *testing.T) { + ctx := context.Background() + s := newSessionsTestStore(t) + + // sess-a: three entries spanning 3 minutes, two providers/models, one deny. + seedAccessLogs(t, s, + accessLogRow("a1", "sess-a", baseTime, + withProvider("openai", "prov-openai"), withModel("gpt-5.4"), + withTokens(100, 50, 150, 0.01), withGroups("grp-eng")), + accessLogRow("a2", "sess-a", baseTime.Add(1*time.Minute), + withProvider("anthropic", "prov-anthropic"), withModel("claude-haiku-4-5"), + withTokens(200, 80, 280, 0.02), withGroups("grp-eng", "grp-ops")), + accessLogRow("a3", "sess-a", baseTime.Add(2*time.Minute), + withProvider("openai", "prov-openai"), withModel("gpt-5.4"), + withTokens(10, 5, 15, 0.001), withDeny("llm_policy.token_cap_exceeded")), + // sess-b: a single allow entry. + accessLogRow("b1", "sess-b", baseTime.Add(30*time.Minute), + withTokens(1, 2, 3, 0.5)), + ) + + sessions, total, err := s.GetAgentNetworkAccessLogSessions(ctx, store.LockingStrengthNone, testAccountID, types.AgentNetworkAccessLogFilter{}) + require.NoError(t, err) + require.Equal(t, int64(2), total, "two distinct sessions") + require.Len(t, sessions, 2) + + // Default sort is last-activity DESC, so sess-b (12:30) precedes sess-a (12:02). + require.Equal(t, []string{"sess-b", "sess-a"}, sessionIDs(sessions)) + + a := sessions[1] + assert.Equal(t, "sess-a", a.SessionID) + assert.Equal(t, 3, a.RequestCount, "three requests folded") + assert.Equal(t, int64(310), a.InputTokens, "input tokens summed") + assert.Equal(t, int64(135), a.OutputTokens, "output tokens summed") + assert.Equal(t, int64(445), a.TotalTokens, "total tokens summed") + assert.InDelta(t, 0.031, a.CostUSD, 1e-9, "cost summed") + assert.Equal(t, "deny", a.Decision, "any deny makes the session a deny") + assert.ElementsMatch(t, []string{"openai", "anthropic"}, a.Providers, "distinct providers") + assert.ElementsMatch(t, []string{"gpt-5.4", "claude-haiku-4-5"}, a.Models, "distinct models") + assert.ElementsMatch(t, []string{"grp-eng", "grp-ops"}, a.GroupIDs, "union of authorising groups") + assert.Equal(t, baseTime, a.StartedAt.UTC(), "started at is the earliest entry") + assert.Equal(t, baseTime.Add(2*time.Minute), a.EndedAt.UTC(), "ended at is the latest entry") + assert.Len(t, a.Entries, 3, "entries carried through") + + b := sessions[0] + assert.Equal(t, "sess-b", b.SessionID) + assert.Equal(t, 1, b.RequestCount) + assert.Equal(t, "allow", b.Decision) +} + +// TestAccessLogSessions_SessionlessRowsAreSingletons verifies that entries with +// no session id each form their own singleton session keyed by the row id. +func TestAccessLogSessions_SessionlessRowsAreSingletons(t *testing.T) { + ctx := context.Background() + s := newSessionsTestStore(t) + + seedAccessLogs(t, s, + accessLogRow("solo-1", "", baseTime), + accessLogRow("solo-2", "", baseTime.Add(time.Minute)), + // A real session with two entries, to prove they don't merge with the singletons. + accessLogRow("g1", "sess-x", baseTime.Add(2*time.Minute)), + accessLogRow("g2", "sess-x", baseTime.Add(3*time.Minute)), + ) + + sessions, total, err := s.GetAgentNetworkAccessLogSessions(ctx, store.LockingStrengthNone, testAccountID, types.AgentNetworkAccessLogFilter{}) + require.NoError(t, err) + require.Equal(t, int64(3), total, "two singletons + one grouped session") + require.Len(t, sessions, 3) + + for _, sess := range sessions { + if sess.SessionID == "sess-x" { + assert.Equal(t, 2, sess.RequestCount, "grouped session folds both entries") + } else { + assert.Empty(t, sess.SessionID, "singleton carries no session id") + assert.Equal(t, 1, sess.RequestCount, "singleton has exactly one request") + } + } +} + +// TestAccessLogSessions_Pagination verifies that paging returns the correct +// slice of sessions in stable order, with a stable total across pages and no +// overlap between pages. +func TestAccessLogSessions_Pagination(t *testing.T) { + ctx := context.Background() + s := newSessionsTestStore(t) + + // Five sessions, each a single entry, with increasing timestamps so the + // default MAX(timestamp) DESC order is sess-5, sess-4, sess-3, sess-2, sess-1. + rows := make([]*types.AgentNetworkAccessLog, 0, 5) + for i := 1; i <= 5; i++ { + rows = append(rows, accessLogRow( + "row-"+itoa(i), "sess-"+itoa(i), baseTime.Add(time.Duration(i)*time.Minute))) + } + seedAccessLogs(t, s, rows...) + + page := func(p int) []*types.AgentNetworkAccessLogSession { + sessions, total, err := s.GetAgentNetworkAccessLogSessions(ctx, store.LockingStrengthNone, testAccountID, + types.AgentNetworkAccessLogFilter{Page: p, PageSize: 2}) + require.NoError(t, err) + require.Equal(t, int64(5), total, "total session count is stable across pages") + return sessions + } + + assert.Equal(t, []string{"sess-5", "sess-4"}, sessionIDs(page(1)), "page 1: two newest") + assert.Equal(t, []string{"sess-3", "sess-2"}, sessionIDs(page(2)), "page 2: next two") + assert.Equal(t, []string{"sess-1"}, sessionIDs(page(3)), "page 3: remaining one") + assert.Empty(t, page(4), "page 4: past the end is empty") +} + +// TestAccessLogSessions_Filtering verifies each filter is applied before +// grouping, so the session set (and total) reflect only matching entries. +func TestAccessLogSessions_Filtering(t *testing.T) { + ctx := context.Background() + s := newSessionsTestStore(t) + + seedAccessLogs(t, s, + accessLogRow("r1", "sess-1", baseTime.Add(1*time.Minute), + withUser("alice"), withProvider("openai", "prov-openai"), withModel("gpt-5.4")), + accessLogRow("r2", "sess-2", baseTime.Add(2*time.Minute), + withUser("bob"), withProvider("anthropic", "prov-anthropic"), withModel("claude-haiku-4-5"), + withDeny("llm_policy.no_authorized_provider"), withGroups("grp-ops")), + accessLogRow("r3", "sess-3", baseTime.Add(3*time.Minute), + withUser("alice"), withProvider("openai", "prov-openai"), withModel("gpt-5.4"), + withGroups("grp-eng")), + ) + + filterCases := []struct { + name string + filter types.AgentNetworkAccessLogFilter + wantIDs []string + wantTot int64 + }{ + { + name: "by session id", + filter: types.AgentNetworkAccessLogFilter{SessionID: strp("sess-2")}, + wantIDs: []string{"sess-2"}, + wantTot: 1, + }, + { + name: "by user id", + filter: types.AgentNetworkAccessLogFilter{UserID: strp("alice")}, + wantIDs: []string{"sess-3", "sess-1"}, // last-activity DESC + wantTot: 2, + }, + { + name: "by model", + filter: types.AgentNetworkAccessLogFilter{Models: []string{"claude-haiku-4-5"}}, + wantIDs: []string{"sess-2"}, + wantTot: 1, + }, + { + name: "by resolved provider id", + filter: types.AgentNetworkAccessLogFilter{ProviderIDs: []string{"prov-openai"}}, + wantIDs: []string{"sess-3", "sess-1"}, + wantTot: 2, + }, + { + name: "by decision deny", + filter: types.AgentNetworkAccessLogFilter{Decision: strp("deny")}, + wantIDs: []string{"sess-2"}, + wantTot: 1, + }, + { + name: "by authorising group", + filter: types.AgentNetworkAccessLogFilter{GroupIDs: []string{"grp-eng"}}, + wantIDs: []string{"sess-3"}, + wantTot: 1, + }, + { + name: "by date range excludes earlier", + filter: types.AgentNetworkAccessLogFilter{ + StartDate: tp(baseTime.Add(90 * time.Second)), // after r1 (12:01), before r2 (12:02) + }, + wantIDs: []string{"sess-3", "sess-2"}, + wantTot: 2, + }, + } + + for _, tc := range filterCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + sessions, total, err := s.GetAgentNetworkAccessLogSessions(ctx, store.LockingStrengthNone, testAccountID, tc.filter) + require.NoError(t, err) + assert.Equal(t, tc.wantTot, total, "filtered total") + assert.Equal(t, tc.wantIDs, sessionIDs(sessions), "filtered session ids in order") + }) + } +} + +// TestAccessLogSessions_SortByCost verifies session-level aggregate sorting: +// ordering by summed cost, ascending and descending. +func TestAccessLogSessions_SortByCost(t *testing.T) { + ctx := context.Background() + s := newSessionsTestStore(t) + + // cheap: 0.01 total; mid: 0.05 total; pricey: 0.20 total (two entries). + seedAccessLogs(t, s, + accessLogRow("c1", "cheap", baseTime.Add(1*time.Minute), withTokens(1, 1, 2, 0.01)), + accessLogRow("m1", "mid", baseTime.Add(2*time.Minute), withTokens(1, 1, 2, 0.05)), + accessLogRow("p1", "pricey", baseTime.Add(3*time.Minute), withTokens(1, 1, 2, 0.15)), + accessLogRow("p2", "pricey", baseTime.Add(4*time.Minute), withTokens(1, 1, 2, 0.05)), + ) + + desc, total, err := s.GetAgentNetworkAccessLogSessions(ctx, store.LockingStrengthNone, testAccountID, + types.AgentNetworkAccessLogFilter{SortBy: "cost_usd", SortOrder: "desc"}) + require.NoError(t, err) + require.Equal(t, int64(3), total) + assert.Equal(t, []string{"pricey", "mid", "cheap"}, sessionIDs(desc), "descending by summed cost") + + asc, _, err := s.GetAgentNetworkAccessLogSessions(ctx, store.LockingStrengthNone, testAccountID, + types.AgentNetworkAccessLogFilter{SortBy: "cost_usd", SortOrder: "asc"}) + require.NoError(t, err) + assert.Equal(t, []string{"cheap", "mid", "pricey"}, sessionIDs(asc), "ascending by summed cost") +} + +// strp / tp / itoa are tiny local helpers to keep the filter table terse. +func strp(s string) *string { return &s } + +func tp(t time.Time) *time.Time { return &t } + +func itoa(i int) string { return string(rune('0' + i)) } diff --git a/management/internals/modules/agentnetwork/affectedpeers_hook.go b/management/internals/modules/agentnetwork/affectedpeers_hook.go new file mode 100644 index 000000000..58347666d --- /dev/null +++ b/management/internals/modules/agentnetwork/affectedpeers_hook.go @@ -0,0 +1,15 @@ +package agentnetwork + +import "github.com/netbirdio/netbird/management/server/affectedpeers" + +// init registers the agent-network service synthesiser with the affectedpeers +// resolver. Agent-network reverse-proxy services are synthesised on demand and +// never persisted, so the resolver can't load them from the store; without them +// it can't fold the embedded proxy peer into the affected set on a client +// group/peer change, and the proxy never learns a newly authorised client until +// it reconnects. Registered here (rather than via a direct +// affectedpeers→agentnetwork import) to avoid an import cycle +// (agentnetwork → account → affectedpeers). +func init() { + affectedpeers.SetAgentNetworkSynthesizer(SynthesizeServices) +} diff --git a/management/internals/modules/agentnetwork/catalog/catalog.go b/management/internals/modules/agentnetwork/catalog/catalog.go new file mode 100644 index 000000000..baf622778 --- /dev/null +++ b/management/internals/modules/agentnetwork/catalog/catalog.go @@ -0,0 +1,749 @@ +// Package catalog defines the static set of Agent Network providers +// recognized by the management server. The catalog is consulted both to +// validate provider_id on create/update and to surface the available +// providers (and their models) to the dashboard. +package catalog + +import "github.com/netbirdio/netbird/shared/management/http/api" + +// Model is the in-memory representation of a catalog model. +type Model struct { + ID string + Label string + InputPer1k float64 + OutputPer1k float64 + ContextWindow int +} + +// ProviderKind groups catalog entries for UI presentation. The split +// is semantic, not technical: +// - KindProvider: the upstream is a vendor's first-party API (OpenAI, +// Anthropic, Mistral, Bedrock, etc.) — NetBird talks straight to +// the model provider. +// - KindGateway: the upstream is itself a routing / aggregation layer +// in front of multiple providers (LiteLLM, Portkey, Helicone, …). +// These typically need NetBird identity stamped onto upstream +// requests so the gateway's analytics and budgets attribute to the +// real caller; that's what IdentityInjection is for. +// - KindCustom: the catch-all "OpenAI-compatible self-hosted endpoint" +// entry (vLLM, Ollama, custom inference servers). +// +// Frontend uses Kind to group the provider Select in the modal so an +// operator can spot at a glance which catalog entries proxy other +// providers vs. talk straight to one. Backend doesn't dispatch on Kind +// today; it's purely a presentation hint. +type ProviderKind string + +const ( + KindProvider ProviderKind = "provider" + KindGateway ProviderKind = "gateway" + KindCustom ProviderKind = "custom" +) + +// Provider is the in-memory representation of a catalog provider. +type Provider struct { + ID string + Name string + Description string + DefaultHost string + // Kind groups this entry for UI presentation; see ProviderKind. + Kind ProviderKind + // AuthHeaderName is the HTTP header the provider's API expects + // the credential under (e.g. "Authorization" for OpenAI, + // "x-api-key" for Anthropic). Combined with AuthHeaderTemplate + // at synthesis time to inject the auth header on every upstream + // request. + AuthHeaderName string + AuthHeaderTemplate string + DefaultContentType string + BrandColor string + // ParserID names the proxy LLM parser surface this provider + // speaks (matches llm.Parser.ProviderName: "openai", + // "anthropic"). Multiple catalog ids may share a parser surface + // (e.g. azure_openai_api and mistral_api both speak the OpenAI + // shape). Empty when no parser is yet implemented for the + // surface — the proxy middleware then falls back to URL sniffing + // or skips request-side enrichment. + ParserID string + // IdentityInjection, when non-nil, instructs the proxy to stamp + // the caller's NetBird identity onto upstream requests under the + // configured header names. Used for gateways like LiteLLM that + // key budgets and attribution off request headers (the gateway + // otherwise has no way to learn which user / group made the call). + // The proxy strips the same header names from the inbound request + // before stamping ours, so an app can't spoof identity by setting + // these headers itself. + IdentityInjection *IdentityInjection + // ExtraHeaders is a catalog-declared list of additional per- + // provider routing/config headers the proxy stamps on every + // upstream request. Distinct from AuthHeaderName/Template (which + // always carries the API_KEY) and from IdentityInjection (caller + // identity). Each entry surfaces an optional input on the + // dashboard's provider modal whose value lives on the provider + // record's ExtraValues map (keyed by ExtraHeader.Name). Empty + // list = no extra inputs rendered. Used today by Portkey for + // "x-portkey-config: pc-..." (a saved-config id that resolves + // upstream provider + credentials on Portkey's hosted side). + ExtraHeaders []ExtraHeader + Models []Model +} + +// ExtraHeader names a single optional per-provider routing/config +// header. Catalog declares N of these per provider type; the operator +// fills any subset on the provider record (see Provider.ExtraValues). +// At synth time, only entries with a non-empty operator value are +// stamped; the proxy's identity-inject middleware applies anti-spoof +// (Remove + Add) so a client can't supply these headers themselves. +// +// UI copy (label / help text / tooltip) for each known Name lives on +// the dashboard, not here — the backend's job is just to declare +// which wire headers are accepted. New provider needs an extra +// header? Add the Name here AND the matching UI copy on the dashboard. +type ExtraHeader struct { + // Name is the wire header name, e.g. "x-portkey-config". + Name string +} + +// IdentityInjection describes how the proxy stamps NetBird identity onto +// upstream gateway requests. Exactly one shape must be set — they're +// mutually exclusive and dispatched by the inject middleware. +// +// Shape choice tracks the wire convention the upstream gateway uses, +// not the vendor name. New gateways with a known shape become a catalog +// entry, not a new code path. +type IdentityInjection struct { + // HeaderPair emits separate headers per identity dimension + // (end-user id, tags as CSV). LiteLLM and OpenAI-compatible + // self-hosted gateways that read identity from dedicated headers. + HeaderPair *HeaderPairInjection + // JSONMetadata emits a single header carrying a JSON object with + // reserved keys for user / groups / etc. Portkey, Helicone-style + // metadata headers, anything that wants a structured envelope. + JSONMetadata *JSONMetadataInjection +} + +// HeaderPairInjection is the LiteLLM-style wire convention. +type HeaderPairInjection struct { + // Customizable, when true, marks the wire header names as + // operator-overridable: the dashboard surfaces EndUserIDHeader + // and TagsHeader as editable inputs (defaults shown as + // placeholders) and the synthesizer pulls the actual values from + // the provider record's IdentityHeader* fields rather than from + // these defaults. An empty operator value disables stamping for + // that dimension. Used today for Bifrost, whose log-metadata / + // telemetry header prefix (x-bf-lh-* vs x-bf-dim-*) is a + // per-operator choice; LiteLLM and similar gateways with a fixed + // wire protocol leave this false so the catalog defaults are + // authoritative. + Customizable bool + // EndUserIDHeader receives the caller's display identity (user + // email when the peer is attached to a user, else peer.Name), + // e.g. "x-litellm-end-user-id". + EndUserIDHeader string + // TagsHeader receives the caller's NetBird group display names + // as a CSV, e.g. "x-litellm-tags". + TagsHeader string + // TagsInBody, when true, additionally writes the tag list into + // the request body's metadata.tags array (a JSON path the + // gateway parses for budget enforcement). LiteLLM only honours + // metadata.tags for tag-budget gating — its x-litellm-tags + // header path feeds spend tracking but bypasses + // _tag_max_budget_check entirely. Body inject is skipped when + // the request body is empty, truncated, non-JSON, or when an + // existing metadata field is a non-object value (defensive: we + // never clobber a client-supplied non-object). The header path + // remains a robust fallback for spend tracking in those cases. + TagsInBody bool + // EndUserIDInBody, when true, additionally writes the display + // identity into the request body's top-level "user" field (the + // OpenAI-standard end-user identifier). LiteLLM resolves the end + // user id from headers first then body, so for LiteLLM this is + // belt-and-suspenders. It matters when an OpenAI-compatible + // gateway downstream of LiteLLM (or OpenAI direct, bypassing + // LiteLLM) only reads the body, and as anti-spoof: client- + // supplied "user" values are overwritten with our trusted + // identity. Same skip rules as TagsInBody. + EndUserIDInBody bool +} + +// JSONMetadataInjection is the Portkey-style wire convention: a single +// header carrying a JSON object. NetBird identity fields land under the +// configured reserved keys; missing keys (empty string) are skipped at +// emit time. +type JSONMetadataInjection struct { + // Customizable, when true, marks the JSON keys as operator- + // overridable. The dashboard surfaces UserKey and GroupsKey as + // editable inputs (the catalog values shown as placeholders) and + // the synthesizer pulls the actual JSON-key names from the + // provider record's IdentityHeader* fields. Same field reuse as + // HeaderPair's customizable path — the dimensions (user identity, + // groups) are the same, only the wire encoding differs (JSON key + // vs HTTP header name). An empty operator value disables emission + // for that dimension. Used today for Cloudflare AI Gateway, whose + // cf-aig-metadata header accepts arbitrary JSON keys; Portkey + // leaves this false because its keys are reserved by the Portkey + // schema. + Customizable bool + // Header is the wire header name carrying the JSON payload, e.g. + // "x-portkey-metadata". + Header string + // UserKey is the JSON key for the caller's display identity. + // Portkey reserves "_user" for this dimension. + UserKey string + // GroupsKey is the JSON key for the caller's NetBird groups, + // emitted as a CSV string value (Portkey requires string values). + GroupsKey string + // MaxValueLength caps each emitted JSON value, in bytes. Portkey + // enforces a 128-char limit per value; oversized values are + // truncated rather than failing the request. 0 disables the cap. + MaxValueLength int +} + +// providers is the canonical list of supported Agent Network providers. +// Update this list together with the dashboard's PROVIDER_CATALOG. +var providers = []Provider{ + { + ID: "openai_api", + Kind: KindProvider, + Name: "OpenAI API", + Description: "GPT, Responses API, and Embeddings", + DefaultHost: "api.openai.com", + AuthHeaderName: "Authorization", + AuthHeaderTemplate: "Bearer ${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#10A37F", + ParserID: "openai", + // Pricing + context windows cross-checked against LiteLLM's + // model_prices_and_context_window.json. Notable corrections from + // earlier values: o4-mini repriced from $4/$16 to $1.10/$4.40 + // per MTok, gpt-4o from $5/$15 to $2.50/$10, and the GPT-5 + // family context windows split between 1.05M for full-size + // models and 272K for mini/nano/codex variants. + Models: []Model{ + {ID: "gpt-5.5", Label: "GPT-5.5", InputPer1k: 0.005, OutputPer1k: 0.030, ContextWindow: 1050000}, + {ID: "gpt-5.5-pro", Label: "GPT-5.5 Pro", InputPer1k: 0.030, OutputPer1k: 0.180, ContextWindow: 1050000}, + {ID: "gpt-5.4", Label: "GPT-5.4", InputPer1k: 0.0025, OutputPer1k: 0.015, ContextWindow: 1050000}, + {ID: "gpt-5.4-pro", Label: "GPT-5.4 Pro", InputPer1k: 0.030, OutputPer1k: 0.180, ContextWindow: 1050000}, + {ID: "gpt-5.4-mini", Label: "GPT-5.4 Mini", InputPer1k: 0.00075, OutputPer1k: 0.0045, ContextWindow: 272000}, + {ID: "gpt-5.4-nano", Label: "GPT-5.4 Nano", InputPer1k: 0.0002, OutputPer1k: 0.00125, ContextWindow: 272000}, + {ID: "gpt-5.3-codex", Label: "GPT-5.3 Codex", InputPer1k: 0.00175, OutputPer1k: 0.014, ContextWindow: 272000}, + {ID: "gpt-5.3-chat-latest", Label: "GPT-5.3 Chat", InputPer1k: 0.00175, OutputPer1k: 0.014, ContextWindow: 128000}, + {ID: "o4-mini", Label: "o4-mini", InputPer1k: 0.0011, OutputPer1k: 0.0044, ContextWindow: 200000}, + {ID: "gpt-4.1", Label: "GPT-4.1", InputPer1k: 0.002, OutputPer1k: 0.008, ContextWindow: 1047576}, + {ID: "gpt-4.1-mini", Label: "GPT-4.1 mini", InputPer1k: 0.0004, OutputPer1k: 0.0016, ContextWindow: 1047576}, + {ID: "gpt-4.1-nano", Label: "GPT-4.1 nano", InputPer1k: 0.0001, OutputPer1k: 0.0004, ContextWindow: 1047576}, + {ID: "gpt-4o", Label: "GPT-4o", InputPer1k: 0.0025, OutputPer1k: 0.010, ContextWindow: 128000}, + {ID: "gpt-4o-mini", Label: "GPT-4o mini", InputPer1k: 0.00015, OutputPer1k: 0.0006, ContextWindow: 128000}, + {ID: "gpt-4-turbo", Label: "GPT-4 Turbo", InputPer1k: 0.01, OutputPer1k: 0.03, ContextWindow: 128000}, + {ID: "gpt-3.5-turbo", Label: "GPT-3.5 Turbo", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 16385}, + {ID: "text-embedding-3-large", Label: "text-embedding-3-large", InputPer1k: 0.00013, OutputPer1k: 0, ContextWindow: 8191}, + {ID: "text-embedding-3-small", Label: "text-embedding-3-small", InputPer1k: 0.00002, OutputPer1k: 0, ContextWindow: 8191}, + }, + }, + { + ID: "anthropic_api", + Kind: KindProvider, + Name: "Anthropic API", + Description: "Claude Messages API", + DefaultHost: "api.anthropic.com", + AuthHeaderName: "x-api-key", + AuthHeaderTemplate: "${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#D97757", + ParserID: "anthropic", + // Per Anthropic's current model lineup. Pricing in USD per 1k + // tokens. Context windows: 4.6+ family is 1M; Haiku 4.5 stays at + // 200K. claude-3-7-sonnet and claude-3-5-haiku retired + // 2026-02-19 — dropped from the catalog. claude-opus-4-1 + // deprecated, retires 2026-08-05 — kept until the cutover. + // claude-mythos-5 omitted: Project Glasswing access only, not a + // general-availability target. claude-fable-5 requires the + // account to be on >= 30-day data retention or all requests + // 400. + Models: []Model{ + {ID: "claude-fable-5", Label: "Claude Fable 5", InputPer1k: 0.010, OutputPer1k: 0.050, ContextWindow: 1000000}, + {ID: "claude-opus-4-8", Label: "Claude Opus 4.8", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000}, + {ID: "claude-opus-4-7", Label: "Claude Opus 4.7", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000}, + {ID: "claude-opus-4-6", Label: "Claude Opus 4.6", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000}, + {ID: "claude-opus-4-1", Label: "Claude Opus 4.1 (deprecated, retires 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000}, + {ID: "claude-sonnet-4-6", Label: "Claude Sonnet 4.6", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000}, + {ID: "claude-sonnet-4-5", Label: "Claude Sonnet 4.5", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000}, + {ID: "claude-haiku-4-5", Label: "Claude Haiku 4.5", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000}, + }, + }, + { + ID: "azure_openai_api", + Kind: KindProvider, + Name: "Azure OpenAI API", + Description: "Azure-hosted OpenAI deployments", + DefaultHost: ".openai.azure.com", + AuthHeaderName: "api-key", + AuthHeaderTemplate: "${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#0078D4", + ParserID: "openai", + // Mirrors openai_api pricing — Azure resells OpenAI models at the + // same per-token rates, just under different deployment names. + Models: []Model{ + {ID: "gpt-5.5", Label: "GPT-5.5 (Azure)", InputPer1k: 0.005, OutputPer1k: 0.030, ContextWindow: 1050000}, + {ID: "gpt-5.4", Label: "GPT-5.4 (Azure)", InputPer1k: 0.0025, OutputPer1k: 0.015, ContextWindow: 1050000}, + {ID: "gpt-5.4-mini", Label: "GPT-5.4 Mini (Azure)", InputPer1k: 0.00075, OutputPer1k: 0.0045, ContextWindow: 272000}, + {ID: "gpt-5.4-nano", Label: "GPT-5.4 Nano (Azure)", InputPer1k: 0.0002, OutputPer1k: 0.00125, ContextWindow: 272000}, + {ID: "o4-mini", Label: "o4-mini (Azure)", InputPer1k: 0.0011, OutputPer1k: 0.0044, ContextWindow: 200000}, + {ID: "gpt-4.1", Label: "GPT-4.1 (Azure)", InputPer1k: 0.002, OutputPer1k: 0.008, ContextWindow: 1047576}, + {ID: "gpt-4.1-mini", Label: "GPT-4.1 mini (Azure)", InputPer1k: 0.0004, OutputPer1k: 0.0016, ContextWindow: 1047576}, + {ID: "gpt-4o", Label: "GPT-4o (Azure)", InputPer1k: 0.0025, OutputPer1k: 0.010, ContextWindow: 128000}, + {ID: "gpt-4o-mini", Label: "GPT-4o mini (Azure)", InputPer1k: 0.00015, OutputPer1k: 0.0006, ContextWindow: 128000}, + {ID: "gpt-35-turbo", Label: "GPT-3.5 Turbo (Azure)", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 16385}, + }, + }, + { + ID: "bedrock_api", + Kind: KindProvider, + Name: "AWS Bedrock API", + Description: "Anthropic, Meta, Cohere via Bedrock", + DefaultHost: "bedrock-runtime..amazonaws.com", + AuthHeaderName: "Authorization", + AuthHeaderTemplate: "Bearer ${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#FF9900", + // Anthropic models on Bedrock take the anthropic.* prefix and + // follow the same lineup / pricing as the first-party Anthropic + // catalog entry above. claude-3-7-sonnet and claude-3-5-haiku + // were retired upstream on 2026-02-19 — dropped from the + // Bedrock list too. Amazon Nova entries cross-checked against + // LiteLLM (added Nova Micro + the new Nova 2 Lite preview). + // Llama 3.3 70B entry kept unchanged — LiteLLM tracks only + // per-region Llama 3 entries; standalone 3.3 not yet listed. + Models: []Model{ + {ID: "anthropic.claude-opus-4-8", Label: "Claude Opus 4.8 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000}, + {ID: "anthropic.claude-opus-4-7", Label: "Claude Opus 4.7 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000}, + {ID: "anthropic.claude-opus-4-6", Label: "Claude Opus 4.6 (Bedrock)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000}, + {ID: "anthropic.claude-opus-4-1", Label: "Claude Opus 4.1 (Bedrock, deprecated 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000}, + {ID: "anthropic.claude-sonnet-4-6", Label: "Claude Sonnet 4.6 (Bedrock)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000}, + {ID: "anthropic.claude-sonnet-4-5", Label: "Claude Sonnet 4.5 (Bedrock)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000}, + {ID: "anthropic.claude-haiku-4-5", Label: "Claude Haiku 4.5 (Bedrock)", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000}, + {ID: "meta.llama3-3-70b-instruct", Label: "Llama 3.3 70B (Bedrock)", InputPer1k: 0.00072, OutputPer1k: 0.00072, ContextWindow: 128000}, + {ID: "amazon.nova-2-lite", Label: "Amazon Nova 2 Lite (Bedrock, preview)", InputPer1k: 0.0003, OutputPer1k: 0.0025, ContextWindow: 1000000}, + {ID: "amazon.nova-pro", Label: "Amazon Nova Pro (Bedrock)", InputPer1k: 0.0008, OutputPer1k: 0.0032, ContextWindow: 300000}, + {ID: "amazon.nova-lite", Label: "Amazon Nova Lite (Bedrock)", InputPer1k: 0.00006, OutputPer1k: 0.00024, ContextWindow: 300000}, + {ID: "amazon.nova-micro", Label: "Amazon Nova Micro (Bedrock)", InputPer1k: 0.000035, OutputPer1k: 0.00014, ContextWindow: 128000}, + }, + }, + { + ID: "vertex_ai_api", + Kind: KindProvider, + Name: "Google Vertex AI API", + Description: "Anthropic Claude models hosted on Vertex AI", + DefaultHost: "-aiplatform.googleapis.com", + AuthHeaderName: "Authorization", + AuthHeaderTemplate: "Bearer ${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#4285F4", + // Vertex carries the model in the URL path and authenticates with a + // service-account-minted OAuth token (api_key = "keyfile::"). + // Only Anthropic-on-Vertex is metered today: the request parser maps the + // anthropic publisher to the Anthropic parser, so the lineup + prices + // mirror the first-party Anthropic catalog (LiteLLM vertex_ai/claude-* + // confirms the same per-token rates; cross-region profiles in eu/apac + // carry a ~10% premium that base pricing does not model). Gemini (the + // google publisher) is intentionally omitted until a Gemini parser + // exists — the router denies unmeterable publishers rather than forward + // them uncounted. + Models: []Model{ + {ID: "claude-fable-5", Label: "Claude Fable 5 (Vertex)", InputPer1k: 0.010, OutputPer1k: 0.050, ContextWindow: 1000000}, + {ID: "claude-opus-4-8", Label: "Claude Opus 4.8 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000}, + {ID: "claude-opus-4-7", Label: "Claude Opus 4.7 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000}, + {ID: "claude-opus-4-6", Label: "Claude Opus 4.6 (Vertex)", InputPer1k: 0.005, OutputPer1k: 0.025, ContextWindow: 1000000}, + {ID: "claude-opus-4-1", Label: "Claude Opus 4.1 (Vertex, deprecated 2026-08-05)", InputPer1k: 0.015, OutputPer1k: 0.075, ContextWindow: 200000}, + {ID: "claude-sonnet-4-6", Label: "Claude Sonnet 4.6 (Vertex)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 1000000}, + {ID: "claude-sonnet-4-5", Label: "Claude Sonnet 4.5 (Vertex)", InputPer1k: 0.003, OutputPer1k: 0.015, ContextWindow: 200000}, + {ID: "claude-haiku-4-5", Label: "Claude Haiku 4.5 (Vertex)", InputPer1k: 0.001, OutputPer1k: 0.005, ContextWindow: 200000}, + }, + }, + { + ID: "mistral_api", + Kind: KindProvider, + Name: "Mistral API", + Description: "Mistral cloud API", + DefaultHost: "api.mistral.ai", + AuthHeaderName: "Authorization", + AuthHeaderTemplate: "Bearer ${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#FF7000", + ParserID: "openai", + // Pricing + context windows cross-checked against LiteLLM. Key + // gotchas the marketing page hides: + // - `mistral-medium-latest` aliases to Medium 3.1 ($0.40/$2), + // NOT Medium 3.5 ($1.50/$7.50). Catalog exposes both. + // - `mistral-large-latest` aliases to Large 3 — 262K context, + // cheaper than Medium 3.5. + // - Magistral models are tuned for reasoning but cap context + // at only 40K (vs 128K-262K elsewhere). + // - `codestral-latest` still routes to the old 2405 build + // ($1/$3) per LiteLLM; the newer codestral-2508 is both + // cheaper and longer-context. Both exposed. + // - Pixtral was folded into the main Large/Medium series; no + // standalone vision entry. + Models: []Model{ + {ID: "mistral-large-latest", Label: "Mistral Large 3", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 262144}, + {ID: "mistral-medium-latest", Label: "Mistral Medium 3.1", InputPer1k: 0.0004, OutputPer1k: 0.002, ContextWindow: 131072}, + {ID: "mistral-medium-3-5", Label: "Mistral Medium 3.5", InputPer1k: 0.0015, OutputPer1k: 0.0075, ContextWindow: 262144}, + {ID: "mistral-small-latest", Label: "Mistral Small 3.2", InputPer1k: 0.00006, OutputPer1k: 0.00018, ContextWindow: 131072}, + {ID: "magistral-medium-latest", Label: "Magistral Medium (reasoning)", InputPer1k: 0.002, OutputPer1k: 0.005, ContextWindow: 40000}, + {ID: "magistral-small-latest", Label: "Magistral Small (reasoning)", InputPer1k: 0.0005, OutputPer1k: 0.0015, ContextWindow: 40000}, + {ID: "devstral-medium-latest", Label: "Devstral Medium 2 (coding)", InputPer1k: 0.0004, OutputPer1k: 0.002, ContextWindow: 256000}, + {ID: "devstral-small-latest", Label: "Devstral Small 2 (coding)", InputPer1k: 0.0001, OutputPer1k: 0.0003, ContextWindow: 256000}, + {ID: "codestral-2508", Label: "Codestral 2508", InputPer1k: 0.0003, OutputPer1k: 0.0009, ContextWindow: 256000}, + {ID: "codestral-latest", Label: "Codestral (legacy 2405)", InputPer1k: 0.001, OutputPer1k: 0.003, ContextWindow: 32000}, + {ID: "ministral-3-14b-2512", Label: "Ministral 3 14B", InputPer1k: 0.0002, OutputPer1k: 0.0002, ContextWindow: 262144}, + {ID: "ministral-8b-latest", Label: "Ministral 8B", InputPer1k: 0.00015, OutputPer1k: 0.00015, ContextWindow: 262144}, + {ID: "ministral-3-3b-2512", Label: "Ministral 3 3B", InputPer1k: 0.0001, OutputPer1k: 0.0001, ContextWindow: 131072}, + {ID: "mistral-embed", Label: "Mistral Embed", InputPer1k: 0.0001, OutputPer1k: 0, ContextWindow: 8192}, + }, + }, + { + ID: "litellm_proxy", + Kind: KindGateway, + Name: "LiteLLM Proxy", + Description: "Bring your own LiteLLM proxy with NetBird identity stamped on every request", + DefaultHost: "", + AuthHeaderName: "Authorization", + AuthHeaderTemplate: "Bearer ${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#0EA5E9", + ParserID: "openai", + // IdentityInjection requires a LiteLLM virtual key minted with + // metadata.allow_client_tags=true; the master key silently drops + // caller tags. Tags go out via both the x-litellm-tags header and + // body metadata.tags: LiteLLM enforces budgets from the body only, + // so the header is the spend-tracking fallback when body injection + // can't run. See the Agent Network provider docs for key setup. + IdentityInjection: &IdentityInjection{ + HeaderPair: &HeaderPairInjection{ + EndUserIDHeader: "x-litellm-end-user-id", + TagsHeader: "x-litellm-tags", + TagsInBody: true, + EndUserIDInBody: true, + }, + }, + Models: []Model{}, + }, + { + ID: "portkey", + Kind: KindGateway, + Name: "Portkey AI Gateway", + Description: "Portkey AI Gateway with NetBird identity stamped via x-portkey-metadata", + DefaultHost: "api.portkey.ai", + // Portkey hosted requires x-portkey-api-key (account key) + // plus a routing decision per request. The simplest routing + // path is a saved Portkey config id stamped via + // x-portkey-config — operators paste the pc-... id once and + // Portkey resolves the upstream provider + virtual key from + // it. ExtraHeaders below surfaces the input. Alternative: + // callers author "@org/model" in the body; both flows + // coexist (per-request authoring still works without a + // configured value). + AuthHeaderName: "x-portkey-api-key", + AuthHeaderTemplate: "${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#FF5C00", + ParserID: "openai", + IdentityInjection: &IdentityInjection{ + JSONMetadata: &JSONMetadataInjection{ + Header: "x-portkey-metadata", + UserKey: "_user", + GroupsKey: "groups", + MaxValueLength: 128, + }, + }, + ExtraHeaders: []ExtraHeader{ + {Name: "x-portkey-config"}, + }, + Models: []Model{}, + }, + { + ID: "bifrost", + Kind: KindGateway, + Name: "Bifrost", + Description: "Maxim AI's Bifrost gateway. Point upstream URL at /openai/v1 or /anthropic/v1 on your Bifrost host depending on which body shape your apps use.", + DefaultHost: "", + AuthHeaderName: "Authorization", + AuthHeaderTemplate: "Bearer ${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#7C3AED", + // ParserID empty: the proxy's request parser sniffs the URL + // path. Bifrost's /openai/v1/... contains "/v1/chat/completions" + // (matches OpenAIParser.DetectFromURL); /anthropic/v1/messages + // contains "/v1/messages" (matches AnthropicParser). Operators + // who paste a different prefix get no usage parsing and the + // cost meter skips with skipMissingProvider — degraded but + // non-fatal. + ParserID: "", + // Identity-injection headers are operator-customisable. The + // HeaderPair values below are PLACEHOLDERS surfaced by the + // dashboard; the actual values stamped on the wire come from + // the provider record's IdentityHeaderUserID / + // IdentityHeaderGroups fields. An empty operator value + // disables stamping for that dimension (the inject middleware + // already no-ops on empty header names). Defaulting to the + // x-bf-dim- family so the values land in Bifrost's + // Prometheus/OTEL pipelines when the operator declares the + // label names in their client.prometheus_labels config — see + // docs.getbifrost.ai/features/telemetry. Operators who use + // the always-on x-bf-lh- log-metadata family (no Bifrost-side + // declaration required) just edit the inputs. + // + // Bifrost virtual keys (sk-bf-*) ride Authorization: Bearer. + // Operators provision the VK on their Bifrost (UI / + // config.json / POST /api/governance/virtual-keys) and paste + // the returned sk-bf-... as ${API_KEY}. Pin v1.4+ to avoid + // the v1.3.0 x-bf-vk regression (maximhq/bifrost#632). + IdentityInjection: &IdentityInjection{ + HeaderPair: &HeaderPairInjection{ + EndUserIDHeader: "x-bf-dim-netbird_user_id", + TagsHeader: "x-bf-dim-netbird_groups", + Customizable: true, + }, + }, + Models: []Model{}, + }, + { + ID: "cloudflare_ai_gateway", + Kind: KindGateway, + Name: "Cloudflare AI Gateway", + Description: "Cloudflare AI Gateway. Operator pastes the gateway URL (with the upstream provider slug like /openai or /anthropic so the URL sniffer dispatches to the right parser) and a per-gateway authentication token. Recommended setup is BYOK / Stored Keys: Cloudflare manages the upstream provider credential and the gateway token is the only secret NetBird needs.", + DefaultHost: "", + AuthHeaderName: "cf-aig-authorization", + AuthHeaderTemplate: "Bearer ${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#F38020", + // ParserID empty: like Bifrost, the proxy's parser-detect + // sniffs the URL path. /openai/... contains the OpenAI hint + // substrings; /anthropic/v1/messages contains /v1/messages + // (matches AnthropicParser). The /compat universal endpoint + // also speaks OpenAI shape so OpenAIParser handles it. + // Operators who paste a different prefix degrade to no-cost + // (skipMissingProvider) but the request still flows. + ParserID: "", + // cf-aig-metadata is a single header carrying a JSON object; + // up to five string/number/boolean values per request. NetBird + // occupies two slots (user id + groups CSV) and leaves three + // for operator-added context. JSON keys are operator- + // customisable so Cloudflare-side log filters can use the + // operator's existing label conventions instead of NetBird's + // defaults — hence Customizable=true. The dashboard surfaces + // the catalog values as placeholders; only the values stored + // on the provider record's IdentityHeader* fields land on the + // wire (empty operator value = key is omitted from the JSON, + // since applyJSONMetadata already skips empty keys). + IdentityInjection: &IdentityInjection{ + JSONMetadata: &JSONMetadataInjection{ + Header: "cf-aig-metadata", + UserKey: "netbird_user_id", + GroupsKey: "netbird_groups", + Customizable: true, + // Cloudflare's docs don't specify a per-value cap; + // leaving 0 disables the truncate path. Header-level + // constraint is "5 entries max" rather than length. + MaxValueLength: 0, + }, + }, + Models: []Model{}, + }, + { + ID: "vercel_ai_gateway", + Kind: KindGateway, + Name: "Vercel AI Gateway", + Description: "Vercel's unified API for hundreds of models. Single endpoint, OpenAI-compatible body, model dispatch via prefix (openai/..., anthropic/..., google/..., xai/...). Per-user / per-tag attribution lands in Vercel's Custom Reporting API and observability dashboard.", + DefaultHost: "", + AuthHeaderName: "Authorization", + AuthHeaderTemplate: "Bearer ${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#000000", + // Vercel always speaks OpenAI shape on /v1/chat/completions — + // the model prefix in the body picks the upstream provider. + // No URL sniffing needed; pin the parser directly. + ParserID: "openai", + // HeaderPair shape with fixed wire names dictated by Vercel's + // Custom Reporting API contract. Customizable=false because + // renaming the headers makes Vercel silently stop attributing + // — the gateway's reporting endpoint only matches its own + // header names. Same fixed-protocol position as LiteLLM. + // + // Caveats operators should know: + // - up to 10 tags total per request (deduped); 11+ → HTTP 400 + // - each tag must be 1-64 chars + // - user up to 256 chars (NetBird user emails fit) + // - $0.075 per 1k unique user/tag values written + // We don't enforce the caps in the inject middleware today; + // operators in groups beyond the 10-tag limit will see Vercel + // 400s and need to re-scope their group memberships. + IdentityInjection: &IdentityInjection{ + HeaderPair: &HeaderPairInjection{ + EndUserIDHeader: "ai-reporting-user", + TagsHeader: "ai-reporting-tags", + }, + }, + Models: []Model{}, + }, + { + ID: "openrouter", + Kind: KindGateway, + Name: "OpenRouter", + Description: "OpenRouter's unified API for hundreds of models. Single endpoint at openrouter.ai/api/v1, OpenAI-compatible body, model dispatch via prefix (anthropic/claude-..., openai/gpt-..., google/gemini-..., etc.). Per-user attribution lands in OpenRouter's analytics via the OpenAI-standard `user` body field; OpenRouter has no groups / tags dimension at request time.", + DefaultHost: "openrouter.ai/api/v1", + AuthHeaderName: "Authorization", + AuthHeaderTemplate: "Bearer ${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#6F4FF2", + // OpenRouter is single-endpoint OpenAI-shape on /api/v1/chat/completions — + // model prefix in the body picks the upstream provider. + // Pinning the parser saves URL sniffing. + ParserID: "openai", + // HeaderPair shape with EndUserIDInBody as the only active + // dimension. OpenRouter's per-user attribution is the + // OpenAI-standard `user` body field, not a header — and + // OpenRouter offers no per-request groups / tags dimension at + // all. Customizable=false because the field name is locked by + // OpenAI's spec; renaming would just defeat the inject. + IdentityInjection: &IdentityInjection{ + HeaderPair: &HeaderPairInjection{ + EndUserIDInBody: true, + }, + }, + // HTTP-Referer + X-OpenRouter-Title surface in OpenRouter's + // app rankings and per-app analytics. Operators paste their + // own app URL + display name on the provider record so their + // requests show under their brand instead of "no app". Both + // are static per-deployment, not per-request, hence the + // ExtraHeaders mechanism (operator-typed value, stamped on + // every request to this provider). Skip X-OpenRouter-Categories + // for now — the marketplace-categories dimension is + // niche-enough that we'd add it on demand. + ExtraHeaders: []ExtraHeader{ + {Name: "HTTP-Referer"}, + {Name: "X-OpenRouter-Title"}, + }, + Models: []Model{}, + }, + { + ID: "custom", + Kind: KindCustom, + Name: "Custom / Self-hosted", + Description: "OpenAI-compatible endpoint (vLLM, Ollama, …)", + DefaultHost: "", + AuthHeaderName: "Authorization", + AuthHeaderTemplate: "Bearer ${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#9CA3AF", + Models: []Model{}, + }, +} + +// All returns a copy of the full catalog. +func All() []Provider { + out := make([]Provider, len(providers)) + copy(out, providers) + return out +} + +// Lookup returns the catalog entry with the given id, if any. +func Lookup(id string) (Provider, bool) { + for _, p := range providers { + if p.ID == id { + return p, true + } + } + return Provider{}, false +} + +// IsKnown reports whether the given id refers to a catalog entry. +func IsKnown(id string) bool { + _, ok := Lookup(id) + return ok +} + +// IsVertexPathStyle reports whether a provider uses the Google Vertex AI +// request shape — the model is carried in the URL path +// (/v1/projects/{p}/locations/{r}/publishers/{pub}/models/{model}:{action}) +// rather than the body, so the proxy routes it by path instead of by model. +func IsVertexPathStyle(providerID string) bool { + return providerID == "vertex_ai_api" +} + +// IsBedrockPathStyle reports whether a provider uses the AWS Bedrock request +// shape — the model is carried in the URL path (/model/{modelId}/{action}, +// action being invoke, invoke-with-response-stream, converse, or +// converse-stream) rather than the body, so the proxy routes it by path. +func IsBedrockPathStyle(providerID string) bool { + return providerID == "bedrock_api" +} + +// ToAPIResponse renders a catalog provider as the API representation. +func (p Provider) ToAPIResponse() api.AgentNetworkCatalogProvider { + models := make([]api.AgentNetworkCatalogModel, 0, len(p.Models)) + for _, m := range p.Models { + models = append(models, api.AgentNetworkCatalogModel{ + Id: m.ID, + Label: m.Label, + InputPer1k: m.InputPer1k, + OutputPer1k: m.OutputPer1k, + ContextWindow: m.ContextWindow, + }) + } + kind := api.AgentNetworkCatalogProviderKindProvider + switch p.Kind { + case KindGateway: + kind = api.AgentNetworkCatalogProviderKindGateway + case KindCustom: + kind = api.AgentNetworkCatalogProviderKindCustom + } + resp := api.AgentNetworkCatalogProvider{ + Id: p.ID, + Name: p.Name, + Description: p.Description, + DefaultHost: p.DefaultHost, + Kind: kind, + AuthHeaderTemplate: p.AuthHeaderTemplate, + DefaultContentType: p.DefaultContentType, + BrandColor: p.BrandColor, + Models: models, + } + if len(p.ExtraHeaders) > 0 { + extras := make([]api.AgentNetworkCatalogExtraHeader, 0, len(p.ExtraHeaders)) + for _, h := range p.ExtraHeaders { + extras = append(extras, api.AgentNetworkCatalogExtraHeader{ + Name: h.Name, + }) + } + resp.ExtraHeaders = &extras + } + // Surface IdentityInjection so the dashboard can decide whether + // to render editable inputs vs. a read-only mappings strip per + // shape's customizable flag. HeaderPair (Bifrost) and + // JSONMetadata (Cloudflare, Portkey) are mutually exclusive on a + // given catalog entry; emit whichever shape is set. + if p.IdentityInjection != nil { + injection := &api.AgentNetworkCatalogIdentityInjection{} + if hp := p.IdentityInjection.HeaderPair; hp != nil { + injection.HeaderPair = &api.AgentNetworkCatalogHeaderPairInjection{ + Customizable: hp.Customizable, + EndUserIdHeader: hp.EndUserIDHeader, + TagsHeader: hp.TagsHeader, + } + } + if jm := p.IdentityInjection.JSONMetadata; jm != nil { + injection.JsonMetadata = &api.AgentNetworkCatalogJSONMetadataInjection{ + Customizable: jm.Customizable, + Header: jm.Header, + UserKey: jm.UserKey, + GroupsKey: jm.GroupsKey, + } + } + if injection.HeaderPair != nil || injection.JsonMetadata != nil { + resp.IdentityInjection = injection + } + } + return resp +} diff --git a/management/internals/modules/agentnetwork/handlers/access_log_handler.go b/management/internals/modules/agentnetwork/handlers/access_log_handler.go new file mode 100644 index 000000000..4484d8c91 --- /dev/null +++ b/management/internals/modules/agentnetwork/handlers/access_log_handler.go @@ -0,0 +1,134 @@ +package handlers + +import ( + "net/http" + "time" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +// addAccessLogEndpoints registers the read-only, server-side-filtered +// agent-network access-log listing and the aggregated usage overview. +func (h *handler) addAccessLogEndpoints(router *mux.Router) { + router.HandleFunc("/agent-network/access-logs", h.listAccessLogs).Methods("GET", "OPTIONS") + router.HandleFunc("/agent-network/access-log-sessions", h.listAccessLogSessions).Methods("GET", "OPTIONS") + router.HandleFunc("/agent-network/usage/overview", h.getUsageOverview).Methods("GET", "OPTIONS") +} + +func (h *handler) getUsageOverview(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + // Reuse the access-log filter for the shared date/user/group/provider/model + // params; pagination/sort/search are irrelevant for an aggregate. + var filter types.AgentNetworkAccessLogFilter + if err := filter.ParseFromRequest(r); err != nil { + util.WriteError(r.Context(), err, w) + return + } + // Bound the aggregation window so an unbounded or over-wide query can't load + // an account's entire usage history into memory. + filter.ApplyUsageOverviewBounds(time.Now()) + granularity := types.ParseUsageGranularity(r.URL.Query().Get("granularity")) + + buckets, err := h.manager.GetUsageOverview(r.Context(), userAuth.AccountId, userAuth.UserId, filter, granularity) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + out := make([]api.AgentNetworkUsageBucket, 0, len(buckets)) + for _, b := range buckets { + out = append(out, b.ToAPIResponse()) + } + util.WriteJSONObject(r.Context(), w, out) +} + +func (h *handler) listAccessLogs(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var filter types.AgentNetworkAccessLogFilter + if err := filter.ParseFromRequest(r); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + rows, total, err := h.manager.ListAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId, filter) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + data := make([]api.AgentNetworkAccessLog, 0, len(rows)) + for _, row := range rows { + data = append(data, row.ToAPIResponse()) + } + + pageSize := filter.GetLimit() + totalPages := 0 + if pageSize > 0 { + totalPages = int((total + int64(pageSize) - 1) / int64(pageSize)) + } + + util.WriteJSONObject(r.Context(), w, api.AgentNetworkAccessLogsResponse{ + Data: data, + Page: filter.Page, + PageSize: pageSize, + TotalRecords: int(total), + TotalPages: totalPages, + }) +} + +// listAccessLogSessions returns the access logs grouped by session: the page +// unit is a session (total counts sessions), each carrying an aggregate summary +// and its ordered entries. Accepts the same filters as listAccessLogs. +func (h *handler) listAccessLogSessions(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var filter types.AgentNetworkAccessLogFilter + if err := filter.ParseFromRequest(r); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + sessions, total, err := h.manager.ListAccessLogSessions(r.Context(), userAuth.AccountId, userAuth.UserId, filter) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + data := make([]api.AgentNetworkAccessLogSession, 0, len(sessions)) + for _, sess := range sessions { + data = append(data, sess.ToAPIResponse()) + } + + pageSize := filter.GetLimit() + totalPages := 0 + if pageSize > 0 { + totalPages = int((total + int64(pageSize) - 1) / int64(pageSize)) + } + + util.WriteJSONObject(r.Context(), w, api.AgentNetworkAccessLogSessionsResponse{ + Data: data, + Page: filter.Page, + PageSize: pageSize, + TotalRecords: int(total), + TotalPages: totalPages, + }) +} diff --git a/management/internals/modules/agentnetwork/handlers/budget_handler.go b/management/internals/modules/agentnetwork/handlers/budget_handler.go new file mode 100644 index 000000000..5630de17f --- /dev/null +++ b/management/internals/modules/agentnetwork/handlers/budget_handler.go @@ -0,0 +1,172 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "strings" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +// addBudgetRuleEndpoints registers the account-level budget rule routes. +func (h *handler) addBudgetRuleEndpoints(router *mux.Router) { + router.HandleFunc("/agent-network/budget-rules", h.getAllBudgetRules).Methods("GET", "OPTIONS") + router.HandleFunc("/agent-network/budget-rules", h.createBudgetRule).Methods("POST", "OPTIONS") + router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.getBudgetRule).Methods("GET", "OPTIONS") + router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.updateBudgetRule).Methods("PUT", "OPTIONS") + router.HandleFunc("/agent-network/budget-rules/{ruleId}", h.deleteBudgetRule).Methods("DELETE", "OPTIONS") +} + +func (h *handler) getAllBudgetRules(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + rules, err := h.manager.GetAllBudgetRules(r.Context(), userAuth.AccountId, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + out := make([]*api.AgentNetworkBudgetRule, 0, len(rules)) + for _, rule := range rules { + out = append(out, rule.ToAPIResponse()) + } + util.WriteJSONObject(r.Context(), w, out) +} + +func (h *handler) getBudgetRule(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ruleID := mux.Vars(r)["ruleId"] + if ruleID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w) + return + } + + rule, err := h.manager.GetBudgetRule(r.Context(), userAuth.AccountId, userAuth.UserId, ruleID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + util.WriteJSONObject(r.Context(), w, rule.ToAPIResponse()) +} + +func (h *handler) createBudgetRule(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.AgentNetworkBudgetRuleRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + if err := validateBudgetRule(&req); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + rule := types.NewAccountBudgetRule(userAuth.AccountId) + rule.FromAPIRequest(&req) + + created, err := h.manager.CreateBudgetRule(r.Context(), userAuth.UserId, rule) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + util.WriteJSONObject(r.Context(), w, created.ToAPIResponse()) +} + +func (h *handler) updateBudgetRule(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ruleID := mux.Vars(r)["ruleId"] + if ruleID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w) + return + } + + var req api.AgentNetworkBudgetRuleRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + if err := validateBudgetRule(&req); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + rule := &types.AccountBudgetRule{ID: ruleID, AccountID: userAuth.AccountId} + rule.FromAPIRequest(&req) + + updated, err := h.manager.UpdateBudgetRule(r.Context(), userAuth.UserId, rule) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse()) +} + +func (h *handler) deleteBudgetRule(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ruleID := mux.Vars(r)["ruleId"] + if ruleID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "budget rule ID is required"), w) + return + } + + if err := h.manager.DeleteBudgetRule(r.Context(), userAuth.AccountId, userAuth.UserId, ruleID); err != nil { + util.WriteError(r.Context(), err, w) + return + } + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +// validateBudgetRule rejects malformed budget rules. It reuses the policy limit +// validation since the cap shape is identical, and rejects empty target entries. +func validateBudgetRule(req *api.AgentNetworkBudgetRuleRequest) error { + if strings.TrimSpace(req.Name) == "" { + return status.Errorf(status.InvalidArgument, "name is required") + } + if req.TargetGroups != nil { + for _, id := range *req.TargetGroups { + if strings.TrimSpace(id) == "" { + return status.Errorf(status.InvalidArgument, "target_groups must not contain empty entries") + } + } + } + if req.TargetUsers != nil { + for _, id := range *req.TargetUsers { + if strings.TrimSpace(id) == "" { + return status.Errorf(status.InvalidArgument, "target_users must not contain empty entries") + } + } + } + return validatePolicyLimits(req.Limits) +} diff --git a/management/internals/modules/agentnetwork/handlers/budget_handler_test.go b/management/internals/modules/agentnetwork/handlers/budget_handler_test.go new file mode 100644 index 000000000..4038761c5 --- /dev/null +++ b/management/internals/modules/agentnetwork/handlers/budget_handler_test.go @@ -0,0 +1,131 @@ +package handlers + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// TestBudgetRuleHandler_RoundTrip seeds a budget rule via the store and asserts +// the GET wire shape carries targets and the reused PolicyLimits cap shape. The +// create/update/delete success paths go through accountManager.StoreEvent which +// this fixture doesn't wire — they are covered by the manager-level no-mock +// test (TestAgentNetwork_BudgetRuleCRUD_RealManager). +func TestBudgetRuleHandler_RoundTrip(t *testing.T) { + f := newAgentNetworkHandlerFixture(t) + + rule := &agentNetworkTypes.AccountBudgetRule{ + ID: "ainbud_test", + AccountID: testAccountID, + Name: "org-monthly", + Enabled: true, + TargetGroups: []string{"grp-eng"}, + TargetUsers: []string{"user-alice"}, + Limits: agentNetworkTypes.PolicyLimits{ + TokenLimit: agentNetworkTypes.PolicyTokenLimit{Enabled: true, GroupCap: 100000, UserCap: 10000, WindowSeconds: 2_592_000}, + BudgetLimit: agentNetworkTypes.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 500, WindowSeconds: 2_592_000}, + }, + } + require.NoError(t, f.store.SaveAgentNetworkBudgetRule(context.Background(), rule)) + + rec := f.do(t, http.MethodGet, "/agent-network/budget-rules/"+rule.ID, "") + require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String()) + + var got api.AgentNetworkBudgetRule + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + assert.Equal(t, "org-monthly", got.Name, "name must round-trip") + assert.Equal(t, []string{"grp-eng"}, got.TargetGroups, "target groups must round-trip") + assert.Equal(t, []string{"user-alice"}, got.TargetUsers, "target users must round-trip") + assert.Equal(t, int64(100000), got.Limits.TokenLimit.GroupCap, "token group cap must round-trip") + assert.Equal(t, int64(2_592_000), got.Limits.BudgetLimit.WindowSeconds, "budget window must round-trip") +} + +// TestBudgetRuleHandler_ListReturnsArray asserts the list endpoint returns a +// JSON array (never null) for an account with no rules. +func TestBudgetRuleHandler_ListReturnsArray(t *testing.T) { + f := newAgentNetworkHandlerFixture(t) + + rec := f.do(t, http.MethodGet, "/agent-network/budget-rules", "") + require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String()) + assert.Equal(t, "[]", trimSpace(rec.Body.String()), "empty account must return an empty array, not null") +} + +// TestBudgetRuleHandler_RejectsMissingName covers the validation path (which +// runs before the manager call, so it works without a wired accountManager). +func TestBudgetRuleHandler_RejectsMissingName(t *testing.T) { + f := newAgentNetworkHandlerFixture(t) + + body := `{ + "name": "", + "limits": { + "token_limit": {"enabled": false, "group_cap": 0, "user_cap": 0, "window_seconds": 0}, + "budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0} + } + }` + rec := f.do(t, http.MethodPost, "/agent-network/budget-rules", body) + assert.Equal(t, http.StatusUnprocessableEntity, rec.Code, + "missing name must be rejected as a validation error (not a route/auth 4xx): got %d body=%s", rec.Code, rec.Body.String()) + assert.Contains(t, rec.Body.String(), "name", + "rejection body must name the offending field, proving the validation path: %s", rec.Body.String()) +} + +// TestBudgetRuleHandler_RejectsSubMinuteWindow proves budget rules reuse the +// policy-limit validation (enabled limit needs window >= 60s). +func TestBudgetRuleHandler_RejectsSubMinuteWindow(t *testing.T) { + f := newAgentNetworkHandlerFixture(t) + + body := `{ + "name": "bad-window", + "limits": { + "token_limit": {"enabled": true, "group_cap": 1000, "user_cap": 0, "window_seconds": 30}, + "budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0} + } + }` + rec := f.do(t, http.MethodPost, "/agent-network/budget-rules", body) + assert.Equal(t, http.StatusUnprocessableEntity, rec.Code, + "sub-minute window must be rejected as a validation error (not a route/auth 4xx): got %d body=%s", rec.Code, rec.Body.String()) + assert.Contains(t, rec.Body.String(), "window_seconds", + "rejection body must name the offending window_seconds field, proving the validation path: %s", rec.Body.String()) +} + +// TestSettingsHandler_GetExposesCollectionToggles asserts the GET settings wire +// shape carries the account-level collection toggles after a store seed. +func TestSettingsHandler_GetExposesCollectionToggles(t *testing.T) { + f := newAgentNetworkHandlerFixture(t) + + require.NoError(t, f.store.SaveAgentNetworkSettings(context.Background(), &agentNetworkTypes.Settings{ + AccountID: testAccountID, + Cluster: "eu.proxy.netbird.io", + Subdomain: "violet", + EnableLogCollection: true, + EnablePromptCollection: true, + RedactPii: false, + })) + + rec := f.do(t, http.MethodGet, "/agent-network/settings", "") + require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String()) + + var got api.AgentNetworkSettings + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + assert.True(t, got.EnableLogCollection, "log collection toggle must surface on the wire") + assert.True(t, got.EnablePromptCollection, "prompt collection toggle must surface on the wire") + assert.False(t, got.RedactPii, "redact toggle must surface its false value") + assert.Equal(t, "violet.eu.proxy.netbird.io", got.Endpoint, "endpoint stays computed from immutable cluster+subdomain") +} + +func trimSpace(s string) string { + for len(s) > 0 && (s[len(s)-1] == '\n' || s[len(s)-1] == ' ' || s[len(s)-1] == '\t' || s[len(s)-1] == '\r') { + s = s[:len(s)-1] + } + for len(s) > 0 && (s[0] == '\n' || s[0] == ' ' || s[0] == '\t' || s[0] == '\r') { + s = s[1:] + } + return s +} diff --git a/management/internals/modules/agentnetwork/handlers/consumption_handler.go b/management/internals/modules/agentnetwork/handlers/consumption_handler.go new file mode 100644 index 000000000..654f23109 --- /dev/null +++ b/management/internals/modules/agentnetwork/handlers/consumption_handler.go @@ -0,0 +1,53 @@ +package handlers + +import ( + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +// addConsumptionEndpoints registers the read-only Agent Network +// consumption listing — backs the dashboard's basic counter view. +func (h *handler) addConsumptionEndpoints(router *mux.Router) { + router.HandleFunc("/agent-network/consumption", h.listConsumption).Methods("GET", "OPTIONS") +} + +func (h *handler) listConsumption(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + rows, err := h.manager.ListConsumption(r.Context(), userAuth.AccountId, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + out := make([]api.AgentNetworkConsumption, 0, len(rows)) + for _, row := range rows { + out = append(out, consumptionToAPI(row)) + } + util.WriteJSONObject(r.Context(), w, out) +} + +func consumptionToAPI(c *types.Consumption) api.AgentNetworkConsumption { + windowStart := c.WindowStartUTC + updatedAt := c.UpdatedAt + return api.AgentNetworkConsumption{ + DimensionKind: api.AgentNetworkConsumptionDimensionKind(c.DimensionKind), + DimensionId: c.DimensionID, + WindowSeconds: c.WindowSeconds, + WindowStartUtc: windowStart, + TokensInput: c.TokensInput, + TokensOutput: c.TokensOutput, + CostUsd: c.CostUSD, + UpdatedAt: &updatedAt, + } +} diff --git a/management/internals/modules/agentnetwork/handlers/guardrails_handler.go b/management/internals/modules/agentnetwork/handlers/guardrails_handler.go new file mode 100644 index 000000000..81f19b9f1 --- /dev/null +++ b/management/internals/modules/agentnetwork/handlers/guardrails_handler.go @@ -0,0 +1,171 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "strings" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +// addGuardrailEndpoints registers all Agent Network guardrail routes. +func (h *handler) addGuardrailEndpoints(router *mux.Router) { + router.HandleFunc("/agent-network/guardrails", h.getAllGuardrails).Methods("GET", "OPTIONS") + router.HandleFunc("/agent-network/guardrails", h.createGuardrail).Methods("POST", "OPTIONS") + router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.getGuardrail).Methods("GET", "OPTIONS") + router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.updateGuardrail).Methods("PUT", "OPTIONS") + router.HandleFunc("/agent-network/guardrails/{guardrailId}", h.deleteGuardrail).Methods("DELETE", "OPTIONS") +} + +func (h *handler) getAllGuardrails(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + guardrails, err := h.manager.GetAllGuardrails(r.Context(), userAuth.AccountId, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + out := make([]*api.AgentNetworkGuardrail, 0, len(guardrails)) + for _, g := range guardrails { + out = append(out, g.ToAPIResponse()) + } + util.WriteJSONObject(r.Context(), w, out) +} + +func (h *handler) getGuardrail(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + guardrailID := mux.Vars(r)["guardrailId"] + if guardrailID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w) + return + } + + guardrail, err := h.manager.GetGuardrail(r.Context(), userAuth.AccountId, userAuth.UserId, guardrailID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + util.WriteJSONObject(r.Context(), w, guardrail.ToAPIResponse()) +} + +func (h *handler) createGuardrail(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.AgentNetworkGuardrailRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + if err := validateGuardrail(&req); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + guardrail := types.NewGuardrail(userAuth.AccountId) + guardrail.FromAPIRequest(&req) + + created, err := h.manager.CreateGuardrail(r.Context(), userAuth.UserId, guardrail) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, created.ToAPIResponse()) +} + +func (h *handler) updateGuardrail(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + guardrailID := mux.Vars(r)["guardrailId"] + if guardrailID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w) + return + } + + var req api.AgentNetworkGuardrailRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + if err := validateGuardrail(&req); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + guardrail := &types.Guardrail{ + ID: guardrailID, + AccountID: userAuth.AccountId, + } + guardrail.FromAPIRequest(&req) + + updated, err := h.manager.UpdateGuardrail(r.Context(), userAuth.UserId, guardrail) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse()) +} + +func (h *handler) deleteGuardrail(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + guardrailID := mux.Vars(r)["guardrailId"] + if guardrailID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "guardrail ID is required"), w) + return + } + + if err := h.manager.DeleteGuardrail(r.Context(), userAuth.AccountId, userAuth.UserId, guardrailID); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +func validateGuardrail(req *api.AgentNetworkGuardrailRequest) error { + if strings.TrimSpace(req.Name) == "" { + return status.Errorf(status.InvalidArgument, "name is required") + } + + c := req.Checks + if c.ModelAllowlist.Enabled { + for _, id := range c.ModelAllowlist.Models { + if strings.TrimSpace(id) == "" { + return status.Errorf(status.InvalidArgument, "model_allowlist.models must not contain empty entries") + } + } + } + return nil +} diff --git a/management/internals/modules/agentnetwork/handlers/handlers_test.go b/management/internals/modules/agentnetwork/handlers/handlers_test.go new file mode 100644 index 000000000..27ebea5dd --- /dev/null +++ b/management/internals/modules/agentnetwork/handlers/handlers_test.go @@ -0,0 +1,256 @@ +package handlers + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "runtime" + "strings" + "testing" + + "github.com/golang/mock/gomock" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork" + agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" + nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +const ( + testAccountID = "acc-1" + testUserID = "user-bob" +) + +// agentNetworkHandlerFixture builds a real agentnetwork.Manager with +// a sqlite store and an always-allow permissions mock, then exposes +// the HTTP handlers via a gorilla router. Tests issue requests +// through httptest and assert on the wire shape — the same path the +// dashboard exercises. +type agentNetworkHandlerFixture struct { + store store.Store + manager agentnetwork.Manager + router *mux.Router +} + +func newAgentNetworkHandlerFixture(t *testing.T) *agentNetworkHandlerFixture { + t.Helper() + if runtime.GOOS == "windows" { + t.Skip("sqlite store not properly supported on Windows yet") + } + t.Setenv("NETBIRD_STORE_ENGINE", string(nbtypes.SqliteStoreEngine)) + + st, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + require.NoError(t, err) + t.Cleanup(cleanUp) + + ctrl := gomock.NewController(t) + perms := permissions.NewMockManager(ctrl) + // Always-allow: the handler tests are about wire shape, not + // authz. Authz is covered by the manager's own tests. + perms.EXPECT(). + ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(true, context.Background(), nil). + AnyTimes() + + manager := agentnetwork.NewManager(st, perms, nil, nil) + h := &handler{manager: manager} + + router := mux.NewRouter() + h.addPolicyEndpoints(router) + h.addConsumptionEndpoints(router) + h.addBudgetRuleEndpoints(router) + h.addSettingsEndpoints(router) + + return &agentNetworkHandlerFixture{ + store: st, + manager: manager, + router: router, + } +} + +func (f *agentNetworkHandlerFixture) do(t *testing.T, method, path, body string) *httptest.ResponseRecorder { + t.Helper() + var reader io.Reader + if body != "" { + reader = strings.NewReader(body) + } + req := httptest.NewRequest(method, path, reader) + if body != "" { + req.Header.Set("Content-Type", "application/json") + } + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + rec := httptest.NewRecorder() + f.router.ServeHTTP(rec, req) + return rec +} + +// seedProvider persists a minimal provider record so policy create +// passes the manager's destination_provider_ids existence check. +func (f *agentNetworkHandlerFixture) seedProvider(t *testing.T, id string) { + t.Helper() + require.NoError(t, f.store.SaveAgentNetworkProvider(context.Background(), &agentNetworkTypes.Provider{ + ID: id, + AccountID: testAccountID, + ProviderID: "openai_api", + Name: "test-" + id, + UpstreamURL: "https://api.openai.com", + APIKey: "sk-test", + Enabled: true, + SessionPrivateKey: "test-priv-key", + SessionPublicKey: "test-pub-key", + })) +} + +// TestPolicyHandler_WindowSecondsRoundTrip ports bash 10 to Go: +// assert that a policy with window_seconds on both Token + Budget +// halves round-trips through GET unchanged AND that legacy +// window_hours / window_days are absent from the JSON response. We +// seed the policy directly via the store rather than POST-ing +// because the create path goes through the manager's +// accountManager.StoreEvent which we don't wire in this fixture; the +// on-wire shape is what matters here, and the POST validation path +// is covered separately by the RejectsSubMinuteWindow test. +func TestPolicyHandler_WindowSecondsRoundTrip(t *testing.T) { + f := newAgentNetworkHandlerFixture(t) + + policy := &agentNetworkTypes.Policy{ + ID: "ainpol_test", + AccountID: testAccountID, + Name: "round-trip", + Enabled: true, + SourceGroups: []string{"grp-engineers"}, + DestinationProviderIDs: []string{"prov-1"}, + Limits: agentNetworkTypes.PolicyLimits{ + TokenLimit: agentNetworkTypes.PolicyTokenLimit{Enabled: true, GroupCap: 10000, UserCap: 5000, WindowSeconds: 86_400}, + BudgetLimit: agentNetworkTypes.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 10.0, UserCapUsd: 2.5, WindowSeconds: 2_592_000}, + }, + } + require.NoError(t, f.store.SaveAgentNetworkPolicy(context.Background(), policy)) + + rec := f.do(t, http.MethodGet, "/agent-network/policies/"+policy.ID, "") + require.Equal(t, http.StatusOK, rec.Code, "GET must succeed: %s", rec.Body.String()) + + var got api.AgentNetworkPolicy + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + assert.Equal(t, int64(86_400), got.Limits.TokenLimit.WindowSeconds, "token_limit.window_seconds must round-trip") + assert.Equal(t, int64(2_592_000), got.Limits.BudgetLimit.WindowSeconds, "budget_limit.window_seconds must round-trip") + + // Legacy field names must NOT appear in the response — would + // signal that the management server is still emitting the old + // shape and would fool a v1 dashboard into rendering days/hours. + assert.NotContains(t, rec.Body.String(), "window_hours", + "legacy window_hours field must be absent from the on-wire response") + assert.NotContains(t, rec.Body.String(), "window_days", + "legacy window_days field must be absent from the on-wire response") +} + +// TestPolicyHandler_RejectsSubMinuteWindow ports bash 20 to Go: an +// enabled limit with window_seconds < 60 must surface as a 4xx +// because anything finer than per-minute produces an untenable +// volume of consumption rows for a feature whose value comes from +// per-window cap enforcement. +func TestPolicyHandler_RejectsSubMinuteWindow(t *testing.T) { + f := newAgentNetworkHandlerFixture(t) + f.seedProvider(t, "prov-1") + + body := `{ + "name": "sub-minute-window", + "enabled": true, + "source_groups": ["grp-engineers"], + "destination_provider_ids": ["prov-1"], + "guardrail_ids": [], + "limits": { + "token_limit": {"enabled": true, "group_cap": 10000, "user_cap": 5000, "window_seconds": 30}, + "budget_limit": {"enabled": false, "group_cap_usd": 0, "user_cap_usd": 0, "window_seconds": 0} + } + }` + rec := f.do(t, http.MethodPost, "/agent-network/policies", body) + // 422 specifically (InvalidArgument) proves the window-validation path — + // a route miss would be 404 and an auth failure 403, so a generic 4xx + // would let those false-pass. + assert.Equal(t, http.StatusUnprocessableEntity, rec.Code, + "enabled token_limit with window_seconds<60 must be rejected as a validation error: got %d body=%s", rec.Code, rec.Body.String()) + assert.Contains(t, rec.Body.String(), "window_seconds", + "rejection body must name the offending window_seconds field, proving it's the validation path: %s", rec.Body.String()) +} + +// TestConsumptionHandler_EmptyAccountReturnsArray ports bash 30 to +// Go: GET /agent-network/consumption on a clean account always +// returns a JSON array (possibly empty), never a 404 / 500. The +// dashboard depends on this shape to render its empty state. +func TestConsumptionHandler_EmptyAccountReturnsArray(t *testing.T) { + f := newAgentNetworkHandlerFixture(t) + + rec := f.do(t, http.MethodGet, "/agent-network/consumption", "") + require.Equal(t, http.StatusOK, rec.Code) + + var rows []api.AgentNetworkConsumption + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rows), + "response must always be a JSON array — even when empty: %s", rec.Body.String()) + assert.Empty(t, rows) +} + +// TestConsumptionHandler_PopulatedAccountListsRows mirrors the +// /consumption read after a few RecordConsumption calls. Validates +// the wire shape carries every field the dashboard reads (dim_kind, +// dim_id, window_seconds, window_start_utc, tokens, cost_usd) and +// rows are ordered window-newest-first. +func TestConsumptionHandler_PopulatedAccountListsRows(t *testing.T) { + f := newAgentNetworkHandlerFixture(t) + + require.NoError(t, f.manager.RecordConsumption( + context.Background(), testAccountID, + agentNetworkTypes.DimensionGroup, "grp-engineers", + 86_400, 100, 50, 0.0125, + )) + require.NoError(t, f.manager.RecordConsumption( + context.Background(), testAccountID, + agentNetworkTypes.DimensionUser, testUserID, + 86_400, 100, 50, 0.0125, + )) + + rec := f.do(t, http.MethodGet, "/agent-network/consumption", "") + require.Equal(t, http.StatusOK, rec.Code) + + var rows []api.AgentNetworkConsumption + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rows)) + require.Len(t, rows, 2, "two RecordConsumption calls must yield two rows") + + // Index by dim_kind so we can assert the full wire shape of each row, + // including the dimension id and the aligned window start the dashboard + // keys on. Both rows share totals and window. + byKind := make(map[string]api.AgentNetworkConsumption, len(rows)) + for _, row := range rows { + assert.Equal(t, int64(100), row.TokensInput) + assert.Equal(t, int64(50), row.TokensOutput) + assert.InDelta(t, 0.0125, row.CostUsd, 1e-9) + assert.Equal(t, int64(86_400), row.WindowSeconds) + assert.False(t, row.WindowStartUtc.IsZero(), "window_start_utc must be set on every row") + byKind[string(row.DimensionKind)] = row + } + + groupRow, ok := byKind["group"] + require.True(t, ok, "group dimension must surface") + assert.Equal(t, "grp-engineers", groupRow.DimensionId, "group row must carry the source group id as dimension_id") + + userRow, ok := byKind["user"] + require.True(t, ok, "user dimension must surface") + assert.Equal(t, testUserID, userRow.DimensionId, "user row must carry the user id as dimension_id") + + // Both rows fall in the same aligned window (same length, recorded + // together), so window_start_utc must match across them. + assert.Equal(t, groupRow.WindowStartUtc, userRow.WindowStartUtc, + "rows recorded in the same window must share the aligned window_start_utc") +} diff --git a/management/internals/modules/agentnetwork/handlers/policies_handler.go b/management/internals/modules/agentnetwork/handlers/policies_handler.go new file mode 100644 index 000000000..b821a5295 --- /dev/null +++ b/management/internals/modules/agentnetwork/handlers/policies_handler.go @@ -0,0 +1,228 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "strings" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +// minWindowSeconds is the floor enforced on enabled token / budget +// limit windows. One minute is short enough for fine-grained burst +// control without producing untenable consumption-row volume at scale. +const minWindowSeconds int64 = 60 + +// addPolicyEndpoints registers all Agent Network policy routes on the +// shared handler. +func (h *handler) addPolicyEndpoints(router *mux.Router) { + router.HandleFunc("/agent-network/policies", h.getAllPolicies).Methods("GET", "OPTIONS") + router.HandleFunc("/agent-network/policies", h.createPolicy).Methods("POST", "OPTIONS") + router.HandleFunc("/agent-network/policies/{policyId}", h.getPolicy).Methods("GET", "OPTIONS") + router.HandleFunc("/agent-network/policies/{policyId}", h.updatePolicy).Methods("PUT", "OPTIONS") + router.HandleFunc("/agent-network/policies/{policyId}", h.deletePolicy).Methods("DELETE", "OPTIONS") +} + +func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policies, err := h.manager.GetAllPolicies(r.Context(), userAuth.AccountId, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + out := make([]*api.AgentNetworkPolicy, 0, len(policies)) + for _, p := range policies { + out = append(out, p.ToAPIResponse()) + } + util.WriteJSONObject(r.Context(), w, out) +} + +func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyID := mux.Vars(r)["policyId"] + if policyID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w) + return + } + + policy, err := h.manager.GetPolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policyID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + util.WriteJSONObject(r.Context(), w, policy.ToAPIResponse()) +} + +func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.AgentNetworkPolicyRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + if err := validatePolicy(&req); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policy := types.NewPolicy(userAuth.AccountId) + policy.FromAPIRequest(&req) + + created, err := h.manager.CreatePolicy(r.Context(), userAuth.UserId, policy) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, created.ToAPIResponse()) +} + +func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyID := mux.Vars(r)["policyId"] + if policyID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w) + return + } + + var req api.AgentNetworkPolicyRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + if err := validatePolicy(&req); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policy := &types.Policy{ + ID: policyID, + AccountID: userAuth.AccountId, + } + policy.FromAPIRequest(&req) + + updated, err := h.manager.UpdatePolicy(r.Context(), userAuth.UserId, policy) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse()) +} + +func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyID := mux.Vars(r)["policyId"] + if policyID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy ID is required"), w) + return + } + + if err := h.manager.DeletePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policyID); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +func validatePolicy(req *api.AgentNetworkPolicyRequest) error { + if strings.TrimSpace(req.Name) == "" { + return status.Errorf(status.InvalidArgument, "name is required") + } + if len(req.SourceGroups) == 0 { + return status.Errorf(status.InvalidArgument, "source_groups must contain at least one group id") + } + for _, id := range req.SourceGroups { + if strings.TrimSpace(id) == "" { + return status.Errorf(status.InvalidArgument, "source_groups must not contain empty entries") + } + } + if len(req.DestinationProviderIds) == 0 { + return status.Errorf(status.InvalidArgument, "destination_provider_ids must contain at least one provider id") + } + for _, id := range req.DestinationProviderIds { + if strings.TrimSpace(id) == "" { + return status.Errorf(status.InvalidArgument, "destination_provider_ids must not contain empty entries") + } + } + if req.GuardrailIds != nil { + for _, id := range *req.GuardrailIds { + if strings.TrimSpace(id) == "" { + return status.Errorf(status.InvalidArgument, "guardrail_ids must not contain empty entries") + } + } + } + if req.Limits != nil { + if err := validatePolicyLimits(*req.Limits); err != nil { + return err + } + } + return nil +} + +func validatePolicyLimits(l api.AgentNetworkPolicyLimits) error { + if l.TokenLimit.Enabled { + if l.TokenLimit.WindowSeconds < minWindowSeconds { + return status.Errorf(status.InvalidArgument, "limits.token_limit.window_seconds must be at least %d (one minute) when enabled", minWindowSeconds) + } + if l.TokenLimit.GroupCap < 0 { + return status.Errorf(status.InvalidArgument, "limits.token_limit.group_cap must not be negative") + } + if l.TokenLimit.UserCap < 0 { + return status.Errorf(status.InvalidArgument, "limits.token_limit.user_cap must not be negative") + } + if l.TokenLimit.GroupCap == 0 && l.TokenLimit.UserCap == 0 { + return status.Errorf(status.InvalidArgument, "limits.token_limit requires group_cap or user_cap to be greater than zero when enabled") + } + } + if l.BudgetLimit.Enabled { + if l.BudgetLimit.WindowSeconds < minWindowSeconds { + return status.Errorf(status.InvalidArgument, "limits.budget_limit.window_seconds must be at least %d (one minute) when enabled", minWindowSeconds) + } + if l.BudgetLimit.GroupCapUsd < 0 { + return status.Errorf(status.InvalidArgument, "limits.budget_limit.group_cap_usd must not be negative") + } + if l.BudgetLimit.UserCapUsd < 0 { + return status.Errorf(status.InvalidArgument, "limits.budget_limit.user_cap_usd must not be negative") + } + if l.BudgetLimit.GroupCapUsd == 0 && l.BudgetLimit.UserCapUsd == 0 { + return status.Errorf(status.InvalidArgument, "limits.budget_limit requires group_cap_usd or user_cap_usd to be greater than zero when enabled") + } + } + return nil +} diff --git a/management/internals/modules/agentnetwork/handlers/providers_handler.go b/management/internals/modules/agentnetwork/handlers/providers_handler.go new file mode 100644 index 000000000..13da137d5 --- /dev/null +++ b/management/internals/modules/agentnetwork/handlers/providers_handler.go @@ -0,0 +1,217 @@ +// Package handlers serves the Agent Network HTTP API. +// +// All persistence is delegated to agentnetwork.Manager so this layer only +// translates between the wire format (api.AgentNetworkProvider*) and the +// domain types. +package handlers + +import ( + "encoding/json" + "net/http" + "net/url" + "strings" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/catalog" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +type handler struct { + manager agentnetwork.Manager +} + +// RegisterEndpoints registers all Agent Network routes. +func RegisterEndpoints(manager agentnetwork.Manager, router *mux.Router) { + h := &handler{manager: manager} + router.HandleFunc("/agent-network/catalog/providers", h.getCatalogProviders).Methods("GET", "OPTIONS") + router.HandleFunc("/agent-network/providers", h.getAllProviders).Methods("GET", "OPTIONS") + router.HandleFunc("/agent-network/providers", h.createProvider).Methods("POST", "OPTIONS") + router.HandleFunc("/agent-network/providers/{providerId}", h.getProvider).Methods("GET", "OPTIONS") + router.HandleFunc("/agent-network/providers/{providerId}", h.updateProvider).Methods("PUT", "OPTIONS") + router.HandleFunc("/agent-network/providers/{providerId}", h.deleteProvider).Methods("DELETE", "OPTIONS") + h.addPolicyEndpoints(router) + h.addGuardrailEndpoints(router) + h.addSettingsEndpoints(router) + h.addConsumptionEndpoints(router) + h.addAccessLogEndpoints(router) + h.addBudgetRuleEndpoints(router) +} + +func (h *handler) getCatalogProviders(w http.ResponseWriter, r *http.Request) { + if _, err := nbcontext.GetUserAuthFromContext(r.Context()); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + entries := catalog.All() + out := make([]api.AgentNetworkCatalogProvider, 0, len(entries)) + for _, e := range entries { + out = append(out, e.ToAPIResponse()) + } + util.WriteJSONObject(r.Context(), w, out) +} + +func (h *handler) getAllProviders(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + providers, err := h.manager.GetAllProviders(r.Context(), userAuth.AccountId, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + out := make([]*api.AgentNetworkProvider, 0, len(providers)) + for _, p := range providers { + out = append(out, p.ToAPIResponse()) + } + util.WriteJSONObject(r.Context(), w, out) +} + +func (h *handler) getProvider(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + providerID := mux.Vars(r)["providerId"] + if providerID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w) + return + } + + provider, err := h.manager.GetProvider(r.Context(), userAuth.AccountId, userAuth.UserId, providerID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + util.WriteJSONObject(r.Context(), w, provider.ToAPIResponse()) +} + +func (h *handler) createProvider(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.AgentNetworkProviderRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + if err := validate(&req, true); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + provider := types.NewProvider(userAuth.AccountId) + provider.FromAPIRequest(&req) + + bootstrapCluster := "" + if req.BootstrapCluster != nil { + bootstrapCluster = *req.BootstrapCluster + } + + created, err := h.manager.CreateProvider(r.Context(), userAuth.UserId, provider, bootstrapCluster) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, created.ToAPIResponse()) +} + +func (h *handler) updateProvider(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + providerID := mux.Vars(r)["providerId"] + if providerID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w) + return + } + + var req api.AgentNetworkProviderRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + if err := validate(&req, false); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + provider := &types.Provider{ + ID: providerID, + AccountID: userAuth.AccountId, + } + provider.FromAPIRequest(&req) + + updated, err := h.manager.UpdateProvider(r.Context(), userAuth.UserId, provider) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse()) +} + +func (h *handler) deleteProvider(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + providerID := mux.Vars(r)["providerId"] + if providerID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "provider ID is required"), w) + return + } + + if err := h.manager.DeleteProvider(r.Context(), userAuth.AccountId, userAuth.UserId, providerID); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +func validate(req *api.AgentNetworkProviderRequest, requireAPIKey bool) error { + if strings.TrimSpace(req.ProviderId) == "" { + return status.Errorf(status.InvalidArgument, "provider_id is required") + } + if !catalog.IsKnown(req.ProviderId) { + return status.Errorf(status.InvalidArgument, "provider_id %q is not a known catalog provider", req.ProviderId) + } + if strings.TrimSpace(req.Name) == "" { + return status.Errorf(status.InvalidArgument, "name is required") + } + if strings.TrimSpace(req.UpstreamUrl) == "" { + return status.Errorf(status.InvalidArgument, "upstream_url is required") + } + u, err := url.Parse(strings.TrimSpace(req.UpstreamUrl)) + if err != nil || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") { + return status.Errorf(status.InvalidArgument, "upstream_url must be a full http(s) URL") + } + if requireAPIKey && (req.ApiKey == nil || strings.TrimSpace(*req.ApiKey) == "") { + return status.Errorf(status.InvalidArgument, "api_key is required") + } + return nil +} diff --git a/management/internals/modules/agentnetwork/handlers/settings_handler.go b/management/internals/modules/agentnetwork/handlers/settings_handler.go new file mode 100644 index 000000000..c65efad0f --- /dev/null +++ b/management/internals/modules/agentnetwork/handlers/settings_handler.go @@ -0,0 +1,74 @@ +package handlers + +import ( + "encoding/json" + "errors" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +// addSettingsEndpoints registers the Agent Network settings routes. The +// settings row is bootstrapped server-side on first provider create; GET reads +// it and PUT updates the mutable collection toggles (cluster/subdomain stay +// immutable). +func (h *handler) addSettingsEndpoints(router *mux.Router) { + router.HandleFunc("/agent-network/settings", h.getSettings).Methods("GET", "OPTIONS") + router.HandleFunc("/agent-network/settings", h.updateSettings).Methods("PUT", "OPTIONS") +} + +// updateSettings applies the collection toggles to the account's settings row. +func (h *handler) updateSettings(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.AgentNetworkSettingsRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + settings := &types.Settings{AccountID: userAuth.AccountId} + settings.FromAPIRequest(&req) + + updated, err := h.manager.UpdateSettings(r.Context(), userAuth.UserId, settings) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + util.WriteJSONObject(r.Context(), w, updated.ToAPIResponse()) +} + +// getSettings returns the account's agent-network settings. The settings +// row is bootstrapped on first provider create, so freshly-onboarded +// accounts have nothing to read. Rather than 404-ing in that case (which +// the dashboard would have to special-case), return a JSON null with 200 +// so consumers can branch on the body alone. +func (h *handler) getSettings(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + settings, err := h.manager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId) + if err != nil { + var sErr *status.Error + if errors.As(err, &sErr) && sErr.Type() == status.NotFound { + util.WriteJSONObject(r.Context(), w, nil) + return + } + util.WriteError(r.Context(), err, w) + return + } + util.WriteJSONObject(r.Context(), w, settings.ToAPIResponse()) +} diff --git a/management/internals/modules/agentnetwork/labelgen/labelgen.go b/management/internals/modules/agentnetwork/labelgen/labelgen.go new file mode 100644 index 000000000..b45ff4ea8 --- /dev/null +++ b/management/internals/modules/agentnetwork/labelgen/labelgen.go @@ -0,0 +1,66 @@ +// Package labelgen produces DNS-safe Agent Network subdomain labels. +package labelgen + +import ( + "fmt" + "math/rand" + "sort" + "sync" +) + +// pickAttempts caps the random retries before falling back to the +// suffixed form. Eight is a soft compromise: with a near-empty taken +// set the very first pick almost always succeeds; when the wordlist is +// densely populated the fallback eventually fires anyway. +const pickAttempts = 8 + +var ( + dedupOnce sync.Once + uniqWords []string +) + +// uniqueWords returns the wordlist deduplicated and sorted for +// deterministic exhaustion behaviour. Lazy-built once per process. +func uniqueWords() []string { + dedupOnce.Do(func() { + seen := make(map[string]struct{}, len(words)) + uniqWords = make([]string, 0, len(words)) + for _, w := range words { + if _, ok := seen[w]; ok { + continue + } + seen[w] = struct{}{} + uniqWords = append(uniqWords, w) + } + sort.Strings(uniqWords) + }) + return uniqWords +} + +// PickUnique selects a label not already in `taken`. It tries up to +// pickAttempts random picks; on exhaustion it scans the deduplicated +// wordlist for any remaining free entry, and if none is left appends +// `-` to a deterministic word and returns. The caller +// is responsible for seeding rng (math/rand). +func PickUnique(rng *rand.Rand, taken map[string]struct{}, fallbackSuffix string) string { + pool := uniqueWords() + if len(pool) == 0 { + return fallbackSuffix + } + + for i := 0; i < pickAttempts; i++ { + w := pool[rng.Intn(len(pool))] + if _, ok := taken[w]; !ok { + return w + } + } + + for _, w := range pool { + if _, ok := taken[w]; !ok { + return w + } + } + + w := pool[rng.Intn(len(pool))] + return fmt.Sprintf("%s-%s", w, fallbackSuffix) +} diff --git a/management/internals/modules/agentnetwork/labelgen/labelgen_test.go b/management/internals/modules/agentnetwork/labelgen/labelgen_test.go new file mode 100644 index 000000000..f03a3501d --- /dev/null +++ b/management/internals/modules/agentnetwork/labelgen/labelgen_test.go @@ -0,0 +1,101 @@ +package labelgen + +import ( + "math/rand" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestPickUnique_DeterministicWithSeededRng locks the property the +// caller relies on: same seed + same taken set → same pick. Without +// that, the bootstrap flow can't reproduce a label across retries. +func TestPickUnique_DeterministicWithSeededRng(t *testing.T) { + taken := map[string]struct{}{} + + rngA := rand.New(rand.NewSource(42)) + rngB := rand.New(rand.NewSource(42)) + + a := PickUnique(rngA, taken, "abcd") + b := PickUnique(rngB, taken, "abcd") + + assert.Equal(t, a, b, "Same seed and taken set must produce identical pick") +} + +// TestPickUnique_AvoidsTakenWordsWhenMostAreReserved seeds taken with +// every word in the pool except a handful and confirms PickUnique +// finds one of the remaining free entries instead of returning the +// fallback form. +func TestPickUnique_AvoidsTakenWordsWhenMostAreReserved(t *testing.T) { + pool := uniqueWords() + require.NotEmpty(t, pool, "wordlist must be populated for the test to mean anything") + + free := map[string]struct{}{ + pool[0]: {}, + pool[len(pool)/2]: {}, + pool[len(pool)-1]: {}, + } + + taken := make(map[string]struct{}, len(pool)) + for _, w := range pool { + if _, ok := free[w]; ok { + continue + } + taken[w] = struct{}{} + } + + rng := rand.New(rand.NewSource(7)) + got := PickUnique(rng, taken, "abcd") + + _, isFree := free[got] + assert.True(t, isFree, "PickUnique must return one of the free words; got %q", got) + assert.NotContains(t, got, "-", "Free pick must not be the suffix fallback form") +} + +// TestPickUnique_FallsBackWhenAllReserved exhausts the pool and +// confirms PickUnique appends the supplied suffix instead of +// returning a duplicate. +func TestPickUnique_FallsBackWhenAllReserved(t *testing.T) { + pool := uniqueWords() + + taken := make(map[string]struct{}, len(pool)) + for _, w := range pool { + taken[w] = struct{}{} + } + + rng := rand.New(rand.NewSource(99)) + got := PickUnique(rng, taken, "abcd") + + assert.True(t, strings.HasSuffix(got, "-abcd"), "Exhausted pool must produce -; got %q", got) + + prefix := strings.TrimSuffix(got, "-abcd") + found := false + for _, w := range pool { + if w == prefix { + found = true + break + } + } + assert.True(t, found, "Fallback prefix must be drawn from the wordlist; got %q", prefix) +} + +// TestUniqueWords_DropsDuplicates guards against authoring slips in +// words.go: every entry must be unique and DNS-safe. +func TestUniqueWords_DropsDuplicates(t *testing.T) { + pool := uniqueWords() + seen := make(map[string]struct{}, len(pool)) + for _, w := range pool { + _, dup := seen[w] + assert.False(t, dup, "Duplicate entry %q in deduplicated pool", w) + seen[w] = struct{}{} + assert.GreaterOrEqual(t, len(w), 4, "Word %q is shorter than 4 chars", w) + assert.LessOrEqual(t, len(w), 12, "Word %q is longer than 12 chars", w) + for _, r := range w { + ok := r >= 'a' && r <= 'z' + assert.True(t, ok, "Word %q contains non-lowercase-ASCII rune %q", w, r) + } + } + assert.GreaterOrEqual(t, len(pool), 500, "Pool must contain at least 500 unique words") +} diff --git a/management/internals/modules/agentnetwork/labelgen/words.go b/management/internals/modules/agentnetwork/labelgen/words.go new file mode 100644 index 000000000..2028ff23d --- /dev/null +++ b/management/internals/modules/agentnetwork/labelgen/words.go @@ -0,0 +1,136 @@ +// Package labelgen produces DNS-safe Agent Network subdomain labels. +// +// The wordlist below is a curated subset drawn from public-domain +// nature / common-noun pools (e.g. EFF's diceware lists). Every entry +// is lowercase ASCII, 4–12 chars, no hyphens, no digits, and was +// hand-checked to avoid offensive, brand, or region-specific terms. +package labelgen + +// words is the pool PickUnique selects from. The slice is intentionally +// not sorted — random picks distribute across the list naturally. +var words = []string{ + "acorn", "adobe", "agate", "alder", "almond", "alpine", "amber", "amethyst", + "anchor", "antler", "apple", "apricot", "arcade", "arctic", "arrow", "ashen", + "aspen", "atlas", "atom", "aurora", "autumn", "azure", + "badger", "bamboo", "banana", "banjo", "barley", "barn", "basalt", "basil", + "basin", "bayou", "beach", "beacon", "beaver", "beech", "beetle", "berry", + "birch", "bison", "blossom", "blue", "bobcat", "bonsai", "boulder", "branch", + "brass", "breeze", "bridge", "bright", "brook", "broom", "brown", "buffalo", + "bumble", "burrow", "butter", "button", + "cabin", "cactus", "calm", "camel", "campfire", "canary", "candle", "canoe", + "canyon", "cardinal", "carrot", "cascade", "castle", "cedar", "celery", "cello", + "cement", "cherry", "chestnut", "chime", "cinnamon", "cinder", "citron", "clay", + "clear", "cliff", "clock", "cloud", "clover", "coast", "cobalt", "cobble", + "cocoa", "coffee", "comet", "compass", "copper", "coral", "corner", "cosmos", + "cotton", "cougar", "country", "coyote", "cove", "crane", "crater", "creek", + "crescent", "crimson", "crocus", "crystal", "cypress", + "daffodil", "dahlia", "daisy", "dawn", "deer", "delta", "denim", "desert", + "dewdrop", "diamond", "dolphin", "doodle", "dove", "dragon", "drift", "drop", + "dune", "dusk", "dusty", + "eagle", "earth", "echo", "elder", "elkhorn", "ember", "emerald", "emperor", + "evergreen", "evening", + "falcon", "fawn", "feather", "fern", "fiddle", "field", "fiesta", "finch", + "firepit", "firefly", "fjord", "flame", "flax", "fleece", "flint", "floral", + "flower", "flute", "foal", "foggy", "forest", "fountain", "foxglove", "fresh", + "frost", "fuchsia", "fudge", + "gable", "galaxy", "garden", "garnet", "gazelle", "geode", "geyser", "ginger", + "glacier", "glade", "glass", "glow", "gold", "goose", "gorge", "gourd", + "granite", "grape", "grass", "gravel", "grayling", "greenery", "grizzly", "grove", + "gull", "gumdrop", "gust", + "hammock", "harbor", "harvest", "hawk", "hazel", "heather", "hedge", "heron", + "hibiscus", "hickory", "hideaway", "highland", "hill", "hive", "hollow", "honey", + "hopper", "horizon", "hummingbird", "husky", + "iceberg", "indigo", "iris", "island", "ivory", "ivybush", + "jade", "jasmine", "jasper", "jaybird", "jelly", "jewel", "jonquil", "journey", + "juniper", "jupiter", "jute", + "kale", "kangaroo", "kayak", "kelp", "kestrel", "kettle", "khaki", "kindling", + "kingfisher", "kiwi", "knapweed", "koala", + "lagoon", "lake", "lantern", "larch", "lark", "laurel", "lava", "lavender", + "leaf", "lemon", "lichen", "light", "lilac", "lily", "lime", "limestone", + "linden", "linen", "lion", "lobster", "locust", "loon", "lotus", "lumber", + "lunar", "lupine", "lynx", + "madrone", "magenta", "magnolia", "mahogany", "mallow", "mango", "manor", "maple", + "marble", "marigold", "marina", "marlin", "marsh", "mauve", "meadow", "melody", + "melon", "merlin", "metal", "midnight", "milk", "millet", "mineral", "mint", + "mirror", "mist", "mitten", "molasses", "moon", "moose", "morning", "moss", + "mountain", "mulberry", "muscat", "mustard", + "narwhal", "navy", "nectar", "needle", "nest", "nettle", "newt", "nightfall", + "noon", "nook", "north", "nova", "nutmeg", + "oaken", "oasis", "oatmeal", "ocean", "ochre", "octagon", "olive", "onyx", + "opal", "orange", "orbit", "orchard", "orchid", "oregano", "orion", "osprey", + "otter", "outpost", "owlet", "oyster", + "painter", "palace", "palm", "pansy", "panther", "papaya", "paprika", "parsley", + "partridge", "passage", "pastel", "patio", "peach", "peacock", "pear", "pearl", + "pebble", "pecan", "pelican", "penguin", "peony", "pepper", "perch", "peridot", + "pewter", "phoenix", "pier", "pillar", "pine", "pineapple", "pinto", "piper", + "pistachio", "plain", "planet", "plateau", "platinum", "plum", "plume", "polar", + "pollen", "pond", "poplar", "poppy", "porcelain", "portal", "portrait", "potato", + "prairie", "primrose", "prism", "puffin", "pumpkin", + "quail", "quartz", "quaver", "quill", "quince", "quinoa", + "rabbit", "raccoon", "radish", "rain", "rainbow", "raindrop", "rapids", "raspberry", + "raven", "ravine", "redwood", "reed", "reef", "ridge", "river", "robin", + "rocket", "rubyred", "rose", "rosemary", "rosewood", "ruffle", "rugby", "russet", + "rustic", "ryefield", + "saffron", "sage", "salmon", "sand", "sandstone", "sapphire", "savanna", "scarlet", + "scout", "seal", "season", "seaweed", "sequoia", "shadow", "shamrock", "shell", + "sherbet", "shore", "silver", "siskin", "skybloom", "skyline", "sleet", "smoke", + "snail", "snapdragon", "snow", "snowflake", "snowy", "solar", "song", "sonic", + "sorrel", "south", "sparkle", "sparrow", "spice", "spider", "spinach", "spire", + "spring", "sprout", "spruce", "squirrel", "starfish", "starlight", "stoat", "stone", + "stork", "storm", "stream", "studio", "summer", "sunbeam", "sundew", "sunny", + "sunrise", "sunset", "swallow", "swan", "sweet", "sycamore", + "tangelo", "tangerine", "tansy", "taupe", "teak", "teal", "thicket", "thistle", + "thrush", "thunder", "tide", "tiger", "tinder", "topaz", "torch", "tortoise", + "tower", "trail", "tranquil", "tundra", "tulip", "turquoise", "turtle", "twig", + "twilight", + "umber", "uplands", + "valley", "vanilla", "velvet", "venus", "verdant", "verdigris", "vermilion", "violet", + "vista", "vivid", "volcano", "vortex", + "walnut", "warbler", "watercress", "waterfall", "wave", "waxwing", "weasel", "westwind", + "whale", "whisker", "whisper", "wicker", "wildwood", "willow", "winter", "wisp", + "wisteria", "wolf", "wombat", "woodland", "woolly", "wren", "wreath", + "yarrow", "yellow", "yewtree", "yodel", + "zebra", "zenith", "zephyr", "zinnia", + "alabaster", "alfalfa", "almanac", "anise", "antelope", "arbor", "arena", "armadillo", + "avocet", "azalea", "balsam", "bayou", "beacon", "blizzard", "bluebell", "bluebird", + "bluejay", "bobolink", "borage", "boreal", "buckeye", "buckthorn", "buttercup", + "cabana", "calico", "canopy", "caraway", "cardamom", "cattail", "celadon", "centaur", + "chambray", "chamois", "champlain", "chestnuts", "chickadee", "chinook", "chipmunk", "cinnabar", + "cirrus", "citrine", "clematis", "copperhead", + "crocodile", "currant", "cuttlebone", "daffy", "dapple", "delphinium", "dervish", "diamondback", + "dogwood", "dolphins", "dragonfly", "driftwood", "dusk", "dustpan", "ebony", "edelweiss", + "emperor", "endive", "estuary", "everglade", "fairway", "feldspar", "fennel", "fieldstone", + "firebrand", "firefly", "fireweed", "firework", "flagstone", "fossil", "frostbite", "galleon", + "gardener", "geranium", "gingko", "ginseng", "goldfish", "goldfinch", "goldenrod", "graphite", + "greenfinch", "guppy", "haiku", "halibut", "hammerhead", "harbinger", "harvest", "hatchling", + "havana", "hawthorn", "hazelnut", "heartwood", "henna", "heron", "highrise", "homestead", + "honeycomb", "honeydew", "horseshoe", "hyacinth", "iceland", "icicle", "indigobird", "ironwood", + "jacaranda", "jamboree", "javelina", "jellyfish", "junebug", "kaleido", "kayaker", "kerchief", + "keystone", "kingdom", "labrador", "lacewing", "ladybug", "lakeside", "lamplight", "leopard", + "lighthouse", "lilypad", "lullaby", "magnet", "mahonia", "mandolin", "manzanita", "maraschino", + "mariner", "marsupial", "mastodon", "matterhorn", "mayflower", "mayfly", "meadowlark", "merlot", + "meteor", "midshipman", "millpond", "mimosa", "minnow", "mockingbird", "molten", "monarch", + "monsoon", "moondust", "moonlight", "moorland", "morning", "mossland", "mountain", "mulch", + "narcissus", "nautilus", "nettlebush", "northstar", "nuthatch", "obsidian", "okra", "olivine", + "opalescent", "orchidea", "orchard", "ornament", "outrigger", "oxalis", "paddler", "paintbrush", + "papyrus", "paradise", "pasture", "patchwork", "pathway", "peridot", "periwinkle", "petalbloom", + "petrel", "petunia", "phlox", "pikeperch", "pinecone", "pioneer", "pipevine", "platypus", + "pomelo", "pondweed", "porpoise", "powder", "promise", "puddle", "pumice", "puzzle", + "quetzal", "quicksilver", "raccoon", "ragwort", "rainforest", "ramble", "rapid", "rascal", + "raspberry", "redbud", "redfern", "redpoll", "reedling", "ringtail", "riverbed", "riverbird", + "riverstone", "rockcress", "roebuck", "rosebay", "rosehip", "rosemary", "rowan", "rumble", + "runaway", "rustler", "sagebrush", "sailcloth", "salamander", "salsify", "samphire", "sandbar", + "sanddollar", "sandpiper", "santolina", "sapodilla", "sassafras", "scallion", "schooner", "seafoam", + "seafrost", "seagrass", "seahorse", "seaport", "seashell", "seaspray", "shamble", "shimmer", + "shoreline", "silkmoth", "silverfox", "skylark", "snapdragon", "snowberry", "snowdrop", "snowfall", + "snowmelt", "softwood", "songbird", "sorghum", "southwind", "speedwell", "spinnaker", "spruce", + "starlight", "starling", "stormcloud", "summit", "sundance", "sundew", "sundial", "sunflower", + "surface", "swallowtail", "sweetcorn", "sycamore", "tabletop", "tamarack", "tamarind", "tangerine", + "tarragon", "telescope", "thicket", "thrasher", "thunder", "thyme", "tideline", "timberland", + "tinderbox", "topiary", "torchwood", "totem", "tradewind", "treasure", "tremolo", "trinket", + "trumpetvine", "tugboat", "tundra", "turnstone", "underbrush", "vagabond", "valerian", "vanilla", + "velveteen", "vermilion", "vinca", "vineyard", "violet", "voyager", "wagonwheel", "walnutwood", + "watermark", "watershed", "waterway", "wavefront", "westerly", "whaleback", "whetstone", "wicker", + "wildbloom", "wildflower", "wilderness", "windsong", "windward", "winterberry", "woodbine", "woodfern", + "woodland", "woodthrush", "woolgrass", "yellowfin", "zenithal", "zucchini", +} diff --git a/management/internals/modules/agentnetwork/manager.go b/management/internals/modules/agentnetwork/manager.go new file mode 100644 index 000000000..d88e0d77c --- /dev/null +++ b/management/internals/modules/agentnetwork/manager.go @@ -0,0 +1,911 @@ +package agentnetwork + +import ( + "context" + "errors" + "fmt" + "math/rand" + "slices" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/labelgen" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/management/status" +) + +// ensureSessionKeys mints an ed25519 session keypair on the provider +// when one is missing. Idempotent: skips when both fields are already +// populated (e.g. update or migrated rows). The keys are used by the +// synthesised reverse-proxy service to sign / verify session JWTs +// after a successful OIDC handshake. +func ensureSessionKeys(p *types.Provider) error { + if p.SessionPrivateKey != "" && p.SessionPublicKey != "" { + return nil + } + pair, err := sessionkey.GenerateKeyPair() + if err != nil { + return fmt.Errorf("generate provider session keys: %w", err) + } + p.SessionPrivateKey = pair.PrivateKey + p.SessionPublicKey = pair.PublicKey + return nil +} + +// Manager governs the lifecycle of Agent Network providers and policies. +type Manager interface { + GetAllProviders(ctx context.Context, accountID, userID string) ([]*types.Provider, error) + GetProvider(ctx context.Context, accountID, userID, providerID string) (*types.Provider, error) + CreateProvider(ctx context.Context, userID string, provider *types.Provider, bootstrapCluster string) (*types.Provider, error) + UpdateProvider(ctx context.Context, userID string, provider *types.Provider) (*types.Provider, error) + DeleteProvider(ctx context.Context, accountID, userID, providerID string) error + + GetAllPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) + GetPolicy(ctx context.Context, accountID, userID, policyID string) (*types.Policy, error) + CreatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error) + UpdatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error) + DeletePolicy(ctx context.Context, accountID, userID, policyID string) error + + GetAllGuardrails(ctx context.Context, accountID, userID string) ([]*types.Guardrail, error) + GetGuardrail(ctx context.Context, accountID, userID, guardrailID string) (*types.Guardrail, error) + CreateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error) + UpdateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error) + DeleteGuardrail(ctx context.Context, accountID, userID, guardrailID string) error + + GetAllBudgetRules(ctx context.Context, accountID, userID string) ([]*types.AccountBudgetRule, error) + GetBudgetRule(ctx context.Context, accountID, userID, ruleID string) (*types.AccountBudgetRule, error) + CreateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error) + UpdateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error) + DeleteBudgetRule(ctx context.Context, accountID, userID, ruleID string) error + + GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) + UpdateSettings(ctx context.Context, userID string, settings *types.Settings) (*types.Settings, error) + + ListConsumption(ctx context.Context, accountID, userID string) ([]*types.Consumption, error) + ListAccessLogs(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error) + ListAccessLogSessions(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLogSession, int64, error) + GetUsageOverview(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter, granularity types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error) + StartAccessLogCleanup(ctx context.Context, cleanupIntervalHours int) + RecordConsumption(ctx context.Context, accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds, tokensIn, tokensOut int64, costUSD float64) error + RecordAccountBudgetUsage(ctx context.Context, accountID, userID string, groupIDs []string, tokensIn, tokensOut int64, costUSD float64) error + RecordUsage(ctx context.Context, in RecordUsageInput) error + SelectPolicyForRequest(ctx context.Context, in PolicySelectionInput) (*PolicySelectionResult, error) +} + +// PolicySelectionInput is the per-request selection envelope. The +// proxy populates it from CapturedData (account, user, groups) plus +// the provider llm_router resolved. +type PolicySelectionInput struct { + AccountID string + UserID string + GroupIDs []string + ProviderID string +} + +// PolicySelectionResult names the policy that "pays" for this request +// plus the deny envelope when every applicable policy has exhausted +// every cap. AttributionGroupID is the lowest group id (string sort) +// of caller_groups ∩ selected_policy.source_groups; empty when no +// group dimension applies. WindowSeconds is the chosen policy's +// effective window length in seconds (token_limit's wins when both +// halves are enabled with mismatched windows; budget_limit's +// otherwise; 0 when no caps are configured at all). +type PolicySelectionResult struct { + Allow bool + SelectedPolicyID string + AttributionGroupID string + WindowSeconds int64 + DenyCode string + DenyReason string +} + +type managerImpl struct { + store store.Store + accountManager account.Manager + permissionsManager permissions.Manager + proxyController proxy.Controller + + // reconcileCache holds the last set of synthesised proxy mappings + // per account so reconcile can emit precise Create/Update/Delete + // updates instead of a full re-push on every mutation. Keyed by + // accountID, then by synthesised service ID. + reconcileMu sync.Mutex + reconcileCache map[string]map[string]*proto.ProxyMapping + + // labelRngMu guards labelRng. PickUnique consumes math/rand.Source + // state; concurrent provider creates would otherwise race. + labelRngMu sync.Mutex + labelRng *rand.Rand +} + +// NewManager constructs the persistent Agent Network manager. The +// manager persists provider/policy/guardrail configuration and, on +// every mutation, reconciles the in-memory synthesised reverse-proxy +// services with the proxy cluster via proxyController. Pass nil for +// proxyController to disable the reconcile push (useful in tests). +func NewManager( + store store.Store, + permissionsManager permissions.Manager, + accountManager account.Manager, + proxyController proxy.Controller, +) Manager { + return &managerImpl{ + store: store, + accountManager: accountManager, + permissionsManager: permissionsManager, + proxyController: proxyController, + reconcileCache: make(map[string]map[string]*proto.ProxyMapping), + labelRng: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} + +func (m *managerImpl) GetAllProviders(ctx context.Context, accountID, userID string) ([]*types.Provider, error) { + if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil { + return nil, err + } + return m.store.GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, accountID) +} + +func (m *managerImpl) GetProvider(ctx context.Context, accountID, userID, providerID string) (*types.Provider, error) { + if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil { + return nil, err + } + return m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthNone, accountID, providerID) +} + +// CreateProvider persists a new provider for the account. bootstrapCluster +// is used only when the per-account agent-network Settings row hasn't +// been created yet; otherwise it is ignored (the cluster is pinned on +// Settings and every provider in the account routes through it). +func (m *managerImpl) CreateProvider(ctx context.Context, userID string, provider *types.Provider, bootstrapCluster string) (*types.Provider, error) { + if err := m.requirePermission(ctx, provider.AccountID, userID, operations.Create); err != nil { + return nil, err + } + + // An empty api_key would silently produce a synthesised service + // that 401s on every upstream request. Surface the misconfiguration + // at create time instead. + if strings.TrimSpace(provider.APIKey) == "" { + return nil, status.Errorf(status.InvalidArgument, "api_key is required when creating an agent network provider") + } + + if provider.ID == "" { + fresh := types.NewProvider(provider.AccountID) + provider.ID = fresh.ID + provider.CreatedAt = fresh.CreatedAt + provider.UpdatedAt = fresh.UpdatedAt + } + + if err := ensureSessionKeys(provider); err != nil { + return nil, err + } + + if err := m.store.SaveAgentNetworkProvider(ctx, provider); err != nil { + return nil, fmt.Errorf("save agent network provider: %w", err) + } + + if strings.TrimSpace(bootstrapCluster) != "" { + if _, err := m.bootstrapSettingsIfNeeded(ctx, provider.AccountID, bootstrapCluster); err != nil { + // The provider create has already succeeded; logging the + // bootstrap miss matches the plan's PoC behaviour. The synth + // path treats a missing settings row as a no-op, and the next + // provider create retries the bootstrap. + log.WithContext(ctx).Debugf("agent-network bootstrap settings for account %s on cluster %s: %v", provider.AccountID, bootstrapCluster, err) + } + } + + m.accountManager.StoreEvent(ctx, userID, provider.ID, provider.AccountID, activity.AgentNetworkProviderCreated, provider.EventMeta()) + m.reconcile(ctx, provider.AccountID) + + return provider, nil +} + +func (m *managerImpl) UpdateProvider(ctx context.Context, userID string, provider *types.Provider) (*types.Provider, error) { + if err := m.requirePermission(ctx, provider.AccountID, userID, operations.Update); err != nil { + return nil, err + } + + existing, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthUpdate, provider.AccountID, provider.ID) + if err != nil { + return nil, fmt.Errorf("failed to get agent network provider: %w", err) + } + + // Preserve the API key if the caller didn't rotate it. A + // whitespace-only value is treated as "not rotated" rather than a + // real key, but it must not silently overwrite a valid stored key. + if provider.APIKey == "" { + provider.APIKey = existing.APIKey + } else if strings.TrimSpace(provider.APIKey) == "" { + return nil, status.Errorf(status.InvalidArgument, "api_key must be non-blank when rotating an agent network provider") + } + // Always preserve the session keypair across updates so existing + // session cookies stay valid. The keys are server-managed and + // never surfaced through the API. + provider.SessionPrivateKey = existing.SessionPrivateKey + provider.SessionPublicKey = existing.SessionPublicKey + if err := ensureSessionKeys(provider); err != nil { + return nil, err + } + provider.CreatedAt = existing.CreatedAt + provider.UpdatedAt = time.Now().UTC() + + if err := m.store.SaveAgentNetworkProvider(ctx, provider); err != nil { + return nil, fmt.Errorf("save agent network provider: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, provider.ID, provider.AccountID, activity.AgentNetworkProviderUpdated, provider.EventMeta()) + m.reconcile(ctx, provider.AccountID) + + return provider, nil +} + +func (m *managerImpl) DeleteProvider(ctx context.Context, accountID, userID, providerID string) error { + if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil { + return err + } + + provider, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthUpdate, accountID, providerID) + if err != nil { + return fmt.Errorf("failed to get agent network provider: %w", err) + } + + // Refuse to delete while any policy still references this provider. + // The operator must detach it first. + policies, err := m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return fmt.Errorf("failed to get agent network policies: %w", err) + } + var blocking []string + for _, p := range policies { + if slices.Contains(p.DestinationProviderIDs, providerID) { + blocking = append(blocking, p.Name) + } + } + if len(blocking) > 0 { + return status.Errorf( + status.InvalidArgument, + "provider is in use by %d %s (%s); detach it before deleting", + len(blocking), + pluralize(len(blocking), "policy", "policies"), + strings.Join(blocking, ", "), + ) + } + + if err := m.store.DeleteAgentNetworkProvider(ctx, accountID, providerID); err != nil { + return fmt.Errorf("failed to delete agent network provider: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, providerID, accountID, activity.AgentNetworkProviderDeleted, provider.EventMeta()) + m.reconcile(ctx, accountID) + + return nil +} + +func pluralize(n int, singular, plural string) string { + if n == 1 { + return singular + } + return plural +} + +func (m *managerImpl) GetAllPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) { + if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil { + return nil, err + } + return m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, accountID) +} + +func (m *managerImpl) GetPolicy(ctx context.Context, accountID, userID, policyID string) (*types.Policy, error) { + if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil { + return nil, err + } + return m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID) +} + +func (m *managerImpl) CreatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error) { + if err := m.requirePermission(ctx, policy.AccountID, userID, operations.Create); err != nil { + return nil, err + } + + if policy.ID == "" { + fresh := types.NewPolicy(policy.AccountID) + policy.ID = fresh.ID + policy.CreatedAt = fresh.CreatedAt + policy.UpdatedAt = fresh.UpdatedAt + } + + if err := m.validateProviderRefs(ctx, policy.AccountID, policy.DestinationProviderIDs); err != nil { + return nil, err + } + + if err := m.store.SaveAgentNetworkPolicy(ctx, policy); err != nil { + return nil, fmt.Errorf("failed to save agent network policy: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, policy.ID, policy.AccountID, activity.AgentNetworkPolicyCreated, policy.EventMeta()) + m.reconcile(ctx, policy.AccountID) + + return policy, nil +} + +func (m *managerImpl) UpdatePolicy(ctx context.Context, userID string, policy *types.Policy) (*types.Policy, error) { + if err := m.requirePermission(ctx, policy.AccountID, userID, operations.Update); err != nil { + return nil, err + } + + existing, err := m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthUpdate, policy.AccountID, policy.ID) + if err != nil { + return nil, fmt.Errorf("failed to get agent network policy: %w", err) + } + + if err := m.validateProviderRefs(ctx, policy.AccountID, policy.DestinationProviderIDs); err != nil { + return nil, err + } + + policy.CreatedAt = existing.CreatedAt + policy.UpdatedAt = time.Now().UTC() + + if err := m.store.SaveAgentNetworkPolicy(ctx, policy); err != nil { + return nil, fmt.Errorf("failed to save agent network policy: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, policy.ID, policy.AccountID, activity.AgentNetworkPolicyUpdated, policy.EventMeta()) + m.reconcile(ctx, policy.AccountID) + + return policy, nil +} + +func (m *managerImpl) DeletePolicy(ctx context.Context, accountID, userID, policyID string) error { + if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil { + return err + } + + policy, err := m.store.GetAgentNetworkPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID) + if err != nil { + return fmt.Errorf("failed to get agent network policy: %w", err) + } + + if err := m.store.DeleteAgentNetworkPolicy(ctx, accountID, policyID); err != nil { + return fmt.Errorf("failed to delete agent network policy: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, policyID, accountID, activity.AgentNetworkPolicyDeleted, policy.EventMeta()) + m.reconcile(ctx, accountID) + + return nil +} + +func (m *managerImpl) GetAllGuardrails(ctx context.Context, accountID, userID string) ([]*types.Guardrail, error) { + if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil { + return nil, err + } + return m.store.GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, accountID) +} + +func (m *managerImpl) GetGuardrail(ctx context.Context, accountID, userID, guardrailID string) (*types.Guardrail, error) { + if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil { + return nil, err + } + return m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthNone, accountID, guardrailID) +} + +func (m *managerImpl) CreateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error) { + if err := m.requirePermission(ctx, guardrail.AccountID, userID, operations.Create); err != nil { + return nil, err + } + + if guardrail.ID == "" { + fresh := types.NewGuardrail(guardrail.AccountID) + guardrail.ID = fresh.ID + guardrail.CreatedAt = fresh.CreatedAt + guardrail.UpdatedAt = fresh.UpdatedAt + } + + if err := m.store.SaveAgentNetworkGuardrail(ctx, guardrail); err != nil { + return nil, fmt.Errorf("failed to save agent network guardrail: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, guardrail.ID, guardrail.AccountID, activity.AgentNetworkGuardrailCreated, guardrail.EventMeta()) + m.reconcile(ctx, guardrail.AccountID) + + return guardrail, nil +} + +func (m *managerImpl) UpdateGuardrail(ctx context.Context, userID string, guardrail *types.Guardrail) (*types.Guardrail, error) { + if err := m.requirePermission(ctx, guardrail.AccountID, userID, operations.Update); err != nil { + return nil, err + } + + existing, err := m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthUpdate, guardrail.AccountID, guardrail.ID) + if err != nil { + return nil, fmt.Errorf("failed to get agent network guardrail: %w", err) + } + + guardrail.CreatedAt = existing.CreatedAt + guardrail.UpdatedAt = time.Now().UTC() + + if err := m.store.SaveAgentNetworkGuardrail(ctx, guardrail); err != nil { + return nil, fmt.Errorf("failed to save agent network guardrail: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, guardrail.ID, guardrail.AccountID, activity.AgentNetworkGuardrailUpdated, guardrail.EventMeta()) + m.reconcile(ctx, guardrail.AccountID) + + return guardrail, nil +} + +func (m *managerImpl) DeleteGuardrail(ctx context.Context, accountID, userID, guardrailID string) error { + if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil { + return err + } + + guardrail, err := m.store.GetAgentNetworkGuardrailByID(ctx, store.LockingStrengthUpdate, accountID, guardrailID) + if err != nil { + return fmt.Errorf("failed to get agent network guardrail: %w", err) + } + + if err := m.store.DeleteAgentNetworkGuardrail(ctx, accountID, guardrailID); err != nil { + return fmt.Errorf("failed to delete agent network guardrail: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, guardrailID, accountID, activity.AgentNetworkGuardrailDeleted, guardrail.EventMeta()) + m.reconcile(ctx, accountID) + + return nil +} + +// GetAllBudgetRules returns every account-level budget rule for the account. +func (m *managerImpl) GetAllBudgetRules(ctx context.Context, accountID, userID string) ([]*types.AccountBudgetRule, error) { + if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil { + return nil, err + } + return m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, accountID) +} + +// GetBudgetRule returns a single account-level budget rule. +func (m *managerImpl) GetBudgetRule(ctx context.Context, accountID, userID, ruleID string) (*types.AccountBudgetRule, error) { + if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil { + return nil, err + } + return m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthNone, accountID, ruleID) +} + +// CreateBudgetRule persists a new account-level budget rule. Budget rules are +// enforced at request time (CheckLLMPolicyLimits), not baked into the synth +// proxy config, so no reconcile is needed. +func (m *managerImpl) CreateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error) { + if err := m.requirePermission(ctx, rule.AccountID, userID, operations.Create); err != nil { + return nil, err + } + + if rule.ID == "" { + fresh := types.NewAccountBudgetRule(rule.AccountID) + rule.ID = fresh.ID + rule.CreatedAt = fresh.CreatedAt + rule.UpdatedAt = fresh.UpdatedAt + } + + if err := m.store.SaveAgentNetworkBudgetRule(ctx, rule); err != nil { + return nil, fmt.Errorf("save agent network budget rule: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, rule.ID, rule.AccountID, activity.AgentNetworkBudgetRuleCreated, rule.EventMeta()) + + return rule, nil +} + +// UpdateBudgetRule updates an existing account-level budget rule. +func (m *managerImpl) UpdateBudgetRule(ctx context.Context, userID string, rule *types.AccountBudgetRule) (*types.AccountBudgetRule, error) { + if err := m.requirePermission(ctx, rule.AccountID, userID, operations.Update); err != nil { + return nil, err + } + + existing, err := m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthUpdate, rule.AccountID, rule.ID) + if err != nil { + return nil, fmt.Errorf("get agent network budget rule: %w", err) + } + + rule.CreatedAt = existing.CreatedAt + rule.UpdatedAt = time.Now().UTC() + + if err := m.store.SaveAgentNetworkBudgetRule(ctx, rule); err != nil { + return nil, fmt.Errorf("save agent network budget rule: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, rule.ID, rule.AccountID, activity.AgentNetworkBudgetRuleUpdated, rule.EventMeta()) + + return rule, nil +} + +// DeleteBudgetRule removes an account-level budget rule. +func (m *managerImpl) DeleteBudgetRule(ctx context.Context, accountID, userID, ruleID string) error { + if err := m.requirePermission(ctx, accountID, userID, operations.Delete); err != nil { + return err + } + + rule, err := m.store.GetAgentNetworkBudgetRuleByID(ctx, store.LockingStrengthUpdate, accountID, ruleID) + if err != nil { + return fmt.Errorf("get agent network budget rule: %w", err) + } + + if err := m.store.DeleteAgentNetworkBudgetRule(ctx, accountID, ruleID); err != nil { + return fmt.Errorf("delete agent network budget rule: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, ruleID, accountID, activity.AgentNetworkBudgetRuleDeleted, rule.EventMeta()) + + return nil +} + +// UpdateSettings applies the mutable account-level settings — the collection +// toggles — onto the existing row. Cluster and Subdomain are immutable and are +// preserved from the persisted row regardless of the input. Because the +// collection toggles change the synthesised service config (prompt-capture +// gating, access-log emission), a reconcile is triggered so the proxy and peer +// network maps converge on the new state. +func (m *managerImpl) UpdateSettings(ctx context.Context, userID string, settings *types.Settings) (*types.Settings, error) { + if err := m.requirePermission(ctx, settings.AccountID, userID, operations.Update); err != nil { + return nil, err + } + + existing, err := m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthUpdate, settings.AccountID) + if err != nil { + return nil, fmt.Errorf("get agent network settings: %w", err) + } + + existing.EnableLogCollection = settings.EnableLogCollection + existing.EnablePromptCollection = settings.EnablePromptCollection + existing.RedactPii = settings.RedactPii + existing.AccessLogRetentionDays = settings.AccessLogRetentionDays + existing.UpdatedAt = time.Now().UTC() + + if err := m.store.SaveAgentNetworkSettings(ctx, existing); err != nil { + return nil, fmt.Errorf("save agent network settings: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, settings.AccountID, settings.AccountID, activity.AgentNetworkSettingsUpdated, map[string]any{ + "log_collection": existing.EnableLogCollection, + "prompt_collection": existing.EnablePromptCollection, + "redact_pii": existing.RedactPii, + }) + m.reconcile(ctx, settings.AccountID) + + return existing, nil +} + +// validateProviderRefs ensures every destination provider id refers to a +// provider that exists in the same account. +func (m *managerImpl) validateProviderRefs(ctx context.Context, accountID string, providerIDs []string) error { + if len(providerIDs) == 0 { + return nil + } + for _, id := range providerIDs { + if _, err := m.store.GetAgentNetworkProviderByID(ctx, store.LockingStrengthNone, accountID, id); err != nil { + // Only a genuine not-found means the reference is invalid; a + // store/runtime error must propagate as-is rather than be + // masked as a client validation error. + var sErr *status.Error + if errors.As(err, &sErr) && sErr.Type() == status.NotFound { + return status.Errorf(status.InvalidArgument, "destination_provider_ids: provider %s does not exist", id) + } + return fmt.Errorf("get destination provider %s: %w", id, err) + } + } + return nil +} + +// GetSettings returns the agent-network settings row for the account. +// Returns the underlying status.NotFound when no row has been +// bootstrapped yet (i.e. the account has no providers). +func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) { + if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil { + return nil, err + } + return m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID) +} + +// bootstrapSettingsIfNeeded creates the per-account agent-network +// settings row when missing. The cluster comes from the create-time +// hint the dashboard sends (auto-picked from the active cluster list); +// the subdomain is picked from the curated wordlist avoiding +// collisions on the same cluster. Idempotent: if a row already exists +// it is returned untouched and the hint is ignored. +func (m *managerImpl) bootstrapSettingsIfNeeded(ctx context.Context, accountID, providerCluster string) (*types.Settings, error) { + if accountID == "" { + return nil, fmt.Errorf("bootstrap settings: account id is required") + } + if strings.TrimSpace(providerCluster) == "" { + return nil, fmt.Errorf("bootstrap settings: provider cluster is required") + } + + existing, err := m.store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID) + if err == nil { + return existing, nil + } + var sErr *status.Error + if !errors.As(err, &sErr) || sErr.Type() != status.NotFound { + return nil, fmt.Errorf("get agent network settings: %w", err) + } + + siblings, err := m.store.GetAgentNetworkSettingsByCluster(ctx, store.LockingStrengthNone, providerCluster) + if err != nil { + return nil, fmt.Errorf("list agent network settings on cluster: %w", err) + } + taken := make(map[string]struct{}, len(siblings)) + for _, s := range siblings { + taken[s.Subdomain] = struct{}{} + } + + suffix := accountID + if len(suffix) > 4 { + suffix = suffix[:4] + } + + m.labelRngMu.Lock() + subdomain := labelgen.PickUnique(m.labelRng, taken, suffix) + m.labelRngMu.Unlock() + + now := time.Now().UTC() + settings := &types.Settings{ + AccountID: accountID, + Cluster: providerCluster, + Subdomain: subdomain, + // Logs on by default; usage is collected regardless. Retention bounds + // how long full log rows are kept. + EnableLogCollection: true, + AccessLogRetentionDays: types.DefaultAccessLogRetentionDays, + CreatedAt: now, + UpdatedAt: now, + } + if err := m.store.SaveAgentNetworkSettings(ctx, settings); err != nil { + return nil, fmt.Errorf("save agent network settings: %w", err) + } + return settings, nil +} + +// ListConsumption returns every consumption row recorded for the +// account, ordered window-newest-first. Backs the dashboard's basic +// counter view; permission gate is the same Read role that gates +// every other agent-network surface. +func (m *managerImpl) ListConsumption(ctx context.Context, accountID, userID string) ([]*types.Consumption, error) { + if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil { + return nil, err + } + return m.store.ListAgentNetworkConsumption(ctx, store.LockingStrengthNone, accountID) +} + +// ListAccessLogs returns a paginated, server-side-filtered page of +// agent-network access logs plus the total count matching the filter. +func (m *managerImpl) ListAccessLogs(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error) { + if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil { + return nil, 0, err + } + return m.store.GetAgentNetworkAccessLogs(ctx, store.LockingStrengthNone, accountID, filter) +} + +// ListAccessLogSessions returns a paginated, server-side-filtered page of +// agent-network access logs grouped by session, plus the total number of +// sessions matching the filter. +func (m *managerImpl) ListAccessLogSessions(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLogSession, int64, error) { + if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil { + return nil, 0, err + } + return m.store.GetAgentNetworkAccessLogSessions(ctx, store.LockingStrengthNone, accountID, filter) +} + +// GetUsageOverview returns the filtered usage rows aggregated into time buckets +// at the requested granularity, oldest-first. +func (m *managerImpl) GetUsageOverview(ctx context.Context, accountID, userID string, filter types.AgentNetworkAccessLogFilter, granularity types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error) { + if err := m.requirePermission(ctx, accountID, userID, operations.Read); err != nil { + return nil, err + } + rows, err := m.store.GetAgentNetworkUsageRows(ctx, store.LockingStrengthNone, accountID, filter) + if err != nil { + return nil, err + } + return types.AggregateUsageByGranularity(rows, granularity), nil +} + +// StartAccessLogCleanup launches a background sweep that periodically deletes +// each account's agent-network access-log rows older than that account's +// AccessLogRetentionDays. Usage records are never swept. A non-positive +// interval defaults to 24h. +func (m *managerImpl) StartAccessLogCleanup(ctx context.Context, cleanupIntervalHours int) { + if cleanupIntervalHours <= 0 { + cleanupIntervalHours = 24 + } + interval := time.Duration(cleanupIntervalHours) * time.Hour + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + m.cleanupAccessLogsOnce(ctx) // run once on startup + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + m.cleanupAccessLogsOnce(ctx) + } + } + }() +} + +// cleanupAccessLogsOnce sweeps every account's expired access-log rows against +// its configured retention. Best-effort: a per-account failure is logged and +// the sweep continues. +func (m *managerImpl) cleanupAccessLogsOnce(ctx context.Context) { + settings, err := m.store.GetAllAgentNetworkSettings(ctx, store.LockingStrengthNone) + if err != nil { + log.WithContext(ctx).Errorf("agent-network access-log cleanup: list settings: %v", err) + return + } + for _, s := range settings { + if s.AccessLogRetentionDays <= 0 { + continue // keep indefinitely + } + cutoff := time.Now().UTC().AddDate(0, 0, -s.AccessLogRetentionDays) + deleted, err := m.store.DeleteOldAgentNetworkAccessLogs(ctx, s.AccountID, cutoff) + if err != nil { + log.WithContext(ctx).Warnf("agent-network access-log cleanup for account %s: %v", s.AccountID, err) + continue + } + if deleted > 0 { + log.WithContext(ctx).Infof("agent-network access-log cleanup: deleted %d rows for account %s (retention %d days)", deleted, s.AccountID, s.AccessLogRetentionDays) + } + } +} + +// RecordConsumption increments the (dim, window) counter by the +// supplied deltas. The window_start is computed from time.Now under +// the supplied window_seconds so callers don't have to pre-align — +// the proxy's post-flight path simply hands us tokens + cost and +// which dimension we're booking against. +func (m *managerImpl) RecordConsumption(ctx context.Context, accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds, tokensIn, tokensOut int64, costUSD float64) error { + if accountID == "" || dimID == "" || windowSeconds <= 0 { + return status.Errorf(status.InvalidArgument, "account_id, dim_id and window_seconds must be set") + } + windowStart := types.WindowStart(time.Now(), windowSeconds) + return m.store.IncrementAgentNetworkConsumption(ctx, accountID, kind, dimID, windowSeconds, windowStart, tokensIn, tokensOut, costUSD) +} + +func (m *managerImpl) requirePermission(ctx context.Context, accountID, userID string, op operations.Operation) error { + ok, _, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.AgentNetwork, op) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + return nil +} + +type mockManager struct{} + +// NewManagerMock returns a no-op manager useful for tests. +func NewManagerMock() Manager { + return &mockManager{} +} + +func (*mockManager) GetAllProviders(_ context.Context, _, _ string) ([]*types.Provider, error) { + return []*types.Provider{}, nil +} + +func (*mockManager) GetProvider(_ context.Context, _, _, _ string) (*types.Provider, error) { + return &types.Provider{}, nil +} + +func (*mockManager) CreateProvider(_ context.Context, _ string, p *types.Provider, _ string) (*types.Provider, error) { + return p, nil +} + +func (*mockManager) UpdateProvider(_ context.Context, _ string, p *types.Provider) (*types.Provider, error) { + return p, nil +} + +func (*mockManager) DeleteProvider(_ context.Context, _, _, _ string) error { return nil } + +func (*mockManager) GetAllPolicies(_ context.Context, _, _ string) ([]*types.Policy, error) { + return []*types.Policy{}, nil +} + +func (*mockManager) GetPolicy(_ context.Context, _, _, _ string) (*types.Policy, error) { + return &types.Policy{}, nil +} + +func (*mockManager) CreatePolicy(_ context.Context, _ string, p *types.Policy) (*types.Policy, error) { + return p, nil +} + +func (*mockManager) UpdatePolicy(_ context.Context, _ string, p *types.Policy) (*types.Policy, error) { + return p, nil +} + +func (*mockManager) DeletePolicy(_ context.Context, _, _, _ string) error { return nil } + +func (*mockManager) GetAllGuardrails(_ context.Context, _, _ string) ([]*types.Guardrail, error) { + return []*types.Guardrail{}, nil +} + +func (*mockManager) GetGuardrail(_ context.Context, _, _, _ string) (*types.Guardrail, error) { + return &types.Guardrail{}, nil +} + +func (*mockManager) CreateGuardrail(_ context.Context, _ string, g *types.Guardrail) (*types.Guardrail, error) { + return g, nil +} + +func (*mockManager) UpdateGuardrail(_ context.Context, _ string, g *types.Guardrail) (*types.Guardrail, error) { + return g, nil +} + +func (*mockManager) DeleteGuardrail(_ context.Context, _, _, _ string) error { return nil } + +func (*mockManager) GetAllBudgetRules(_ context.Context, _, _ string) ([]*types.AccountBudgetRule, error) { + return []*types.AccountBudgetRule{}, nil +} + +func (*mockManager) GetBudgetRule(_ context.Context, _, _, _ string) (*types.AccountBudgetRule, error) { + return &types.AccountBudgetRule{}, nil +} + +func (*mockManager) CreateBudgetRule(_ context.Context, _ string, r *types.AccountBudgetRule) (*types.AccountBudgetRule, error) { + return r, nil +} + +func (*mockManager) UpdateBudgetRule(_ context.Context, _ string, r *types.AccountBudgetRule) (*types.AccountBudgetRule, error) { + return r, nil +} + +func (*mockManager) DeleteBudgetRule(_ context.Context, _, _, _ string) error { return nil } + +func (*mockManager) GetSettings(_ context.Context, _, _ string) (*types.Settings, error) { + return nil, status.Errorf(status.NotFound, "agent network settings not found") +} + +func (*mockManager) UpdateSettings(_ context.Context, _ string, s *types.Settings) (*types.Settings, error) { + return s, nil +} + +func (*mockManager) ListConsumption(_ context.Context, _, _ string) ([]*types.Consumption, error) { + return nil, nil +} + +func (*mockManager) ListAccessLogs(_ context.Context, _, _ string, _ types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLog, int64, error) { + return nil, 0, nil +} + +func (*mockManager) ListAccessLogSessions(_ context.Context, _, _ string, _ types.AgentNetworkAccessLogFilter) ([]*types.AgentNetworkAccessLogSession, int64, error) { + return nil, 0, nil +} + +func (*mockManager) GetUsageOverview(_ context.Context, _, _ string, _ types.AgentNetworkAccessLogFilter, _ types.UsageGranularity) ([]*types.AgentNetworkUsageBucket, error) { + return nil, nil +} + +func (*mockManager) StartAccessLogCleanup(_ context.Context, _ int) {} + +func (*mockManager) RecordConsumption(_ context.Context, _ string, _ types.ConsumptionDimension, _ string, _, _, _ int64, _ float64) error { + return nil +} + +func (*mockManager) RecordAccountBudgetUsage(_ context.Context, _, _ string, _ []string, _, _ int64, _ float64) error { + return nil +} + +func (*mockManager) RecordUsage(_ context.Context, _ RecordUsageInput) error { + return nil +} diff --git a/management/internals/modules/agentnetwork/policyselect.go b/management/internals/modules/agentnetwork/policyselect.go new file mode 100644 index 000000000..9203a1910 --- /dev/null +++ b/management/internals/modules/agentnetwork/policyselect.go @@ -0,0 +1,660 @@ +package agentnetwork + +import ( + "context" + "fmt" + "math" + "sort" + "time" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" +) + +// validateUsageDeltas rejects negative or non-finite usage counters before they +// reach the consumption store, so a bad delta can't decrement or poison totals. +// The store batch method enforces the same invariant; this is the manager-level +// guard so direct callers fail fast with a clear error. +func validateUsageDeltas(tokensIn, tokensOut int64, costUSD float64) error { + if tokensIn < 0 || tokensOut < 0 || costUSD < 0 || math.IsNaN(costUSD) || math.IsInf(costUSD, 0) { + return status.Errorf(status.InvalidArgument, "usage deltas must be non-negative and finite") + } + return nil +} + +// Deny codes the proxy surfaces back to the caller when every +// applicable policy is exhausted. The proxy converts these into +// upstream-shaped error responses. +const ( + //nolint:gosec // policy deny code label, not a credential + denyCodeTokenCapExceeded = "llm_policy.token_cap_exceeded" + //nolint:gosec // policy deny code label, not a credential + denyCodeBudgetCapExceeded = "llm_policy.budget_cap_exceeded" + //nolint:gosec // account deny code label, not a credential + denyCodeAccountTokenCapExceeded = "llm_account.token_cap_exceeded" + //nolint:gosec // account deny code label, not a credential + denyCodeAccountBudgetCapExceeded = "llm_account.budget_cap_exceeded" +) + +// consumptionCache holds the consumption counters prefetched for one +// policy-selection request, keyed by ConsumptionKey. A miss returns a zero +// counter — the same contract the store's single-row getter uses for absent +// rows — so the eval logic is identical whether a counter exists yet or not. +type consumptionCache map[types.ConsumptionKey]*types.Consumption + +func (c consumptionCache) get(accountID string, kind types.ConsumptionDimension, dimID string, windowSeconds int64, windowStart time.Time) *types.Consumption { + key := types.ConsumptionKey{Kind: kind, DimID: dimID, WindowSeconds: windowSeconds, WindowStartUTC: windowStart.UTC()} + if row, ok := c[key]; ok && row != nil { + return row + } + return &types.Consumption{ + AccountID: accountID, + DimensionKind: kind, + DimensionID: dimID, + WindowSeconds: windowSeconds, + WindowStartUTC: windowStart.UTC(), + } +} + +// addLimitKeys records the user/group consumption keys a single enabled (token +// or budget) limit window reads for the given attribution group, into a dedup +// set. attrGroup may be empty (no group dimension applies). +func addLimitKeys(set map[types.ConsumptionKey]struct{}, userID, attrGroup string, windowSeconds int64, now time.Time) { + if windowSeconds <= 0 { + return + } + ws := types.WindowStart(now, windowSeconds) + if userID != "" { + set[types.ConsumptionKey{Kind: types.DimensionUser, DimID: userID, WindowSeconds: windowSeconds, WindowStartUTC: ws}] = struct{}{} + } + if attrGroup != "" { + set[types.ConsumptionKey{Kind: types.DimensionGroup, DimID: attrGroup, WindowSeconds: windowSeconds, WindowStartUTC: ws}] = struct{}{} + } +} + +// prefetchConsumption loads, in one store round-trip, every consumption counter +// that the account-budget ceiling and the candidate policies will read while +// scoring this request. This replaces the per-cap point reads the selector +// previously issued one at a time (the N+1 on the hot path). +func (m *managerImpl) prefetchConsumption(ctx context.Context, in PolicySelectionInput, rules []*types.AccountBudgetRule, candidates []*types.Policy, now time.Time) (consumptionCache, error) { + set := make(map[types.ConsumptionKey]struct{}) + for _, p := range candidates { + attr := lowestIntersect(p.SourceGroups, in.GroupIDs) + if p.Limits.TokenLimit.Enabled { + addLimitKeys(set, in.UserID, attr, p.Limits.TokenLimit.WindowSeconds, now) + } + if p.Limits.BudgetLimit.Enabled { + addLimitKeys(set, in.UserID, attr, p.Limits.BudgetLimit.WindowSeconds, now) + } + } + for _, r := range rules { + if r == nil || !r.Enabled || !budgetRuleApplies(r, in) { + continue + } + attr := lowestIntersect(r.TargetGroups, in.GroupIDs) + if r.Limits.TokenLimit.Enabled { + addLimitKeys(set, in.UserID, attr, r.Limits.TokenLimit.WindowSeconds, now) + } + if r.Limits.BudgetLimit.Enabled { + addLimitKeys(set, in.UserID, attr, r.Limits.BudgetLimit.WindowSeconds, now) + } + } + if len(set) == 0 { + return consumptionCache{}, nil + } + keys := make([]types.ConsumptionKey, 0, len(set)) + for k := range set { + keys = append(keys, k) + } + rows, err := m.store.GetAgentNetworkConsumptionBatch(ctx, store.LockingStrengthNone, in.AccountID, keys) + if err != nil { + return nil, fmt.Errorf("batch read consumption: %w", err) + } + return consumptionCache(rows), nil +} + +// SelectPolicyForRequest picks the policy that "pays" for the +// incoming request. The chosen policy is the one with the largest +// pool that still has headroom — drain the bigger bucket first, +// fall through to the next-biggest only when the current one's +// group cap or shared per-user cap is exhausted. This matches +// operator intuition for layered tiers ("privileged group has the +// 10k budget, regular group has 1k as the safety net") and avoids +// the load-balancer flapping that fraction-based scoring produces +// once any cap has been touched. +// +// Ordering across non-exhausted candidates: +// 1. Policies with NO enabled caps (catch-all-allow) win over any +// capped policy — operators who configure unlimited access +// expect requests to attribute there until they explicitly add +// caps. +// 2. Larger group token cap wins. +// 3. Larger group budget USD cap wins. +// 4. Larger user token cap wins. +// 5. Larger user budget USD cap wins. +// 6. Older created_at wins (deterministic final tiebreak so +// multi-node selection converges). +// +// Returns Allow=true with empty SelectedPolicyID when no policy in +// the account targets the (provider, caller-groups) combination — +// llm_router is the gate that owns "no policy authorises this +// request" semantics; this function trusts that authorisation has +// already happened upstream and only does the limit-aware +// attribution. +func (m *managerImpl) SelectPolicyForRequest(ctx context.Context, in PolicySelectionInput) (*PolicySelectionResult, error) { + if in.AccountID == "" { + return nil, status.Errorf(status.InvalidArgument, "account_id is required") + } + + now := time.Now().UTC() + + rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, in.AccountID) + if err != nil { + return nil, fmt.Errorf("list account budget rules: %w", err) + } + policies, err := m.store.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, in.AccountID) + if err != nil { + return nil, fmt.Errorf("list account policies: %w", err) + } + candidates := filterApplicablePolicies(policies, in) + + // Prefetch every consumption counter the ceiling + candidate policies will + // read, in a single store round-trip, then score against the cache. + cache, err := m.prefetchConsumption(ctx, in, rules, candidates, now) + if err != nil { + return nil, err + } + + // Account-level budget rules are an always-on ceiling, evaluated + // independently of policy selection (they bind even for catch-all-allow + // policies or requests that match no policy). All applicable rules must + // pass — this is where min-wins lives. + if deny, code, reason := checkAccountBudget(in, rules, cache, now); deny { + return &PolicySelectionResult{Allow: false, DenyCode: code, DenyReason: reason}, nil + } + + if len(candidates) == 0 { + return &PolicySelectionResult{Allow: true}, nil + } + scored, lastDenyCode, lastDenyReason := scoreCandidates(in, candidates, cache, now) + if len(scored) == 0 { + return &PolicySelectionResult{ + Allow: false, + DenyCode: lastDenyCode, + DenyReason: lastDenyReason, + }, nil + } + + sort.SliceStable(scored, func(i, j int) bool { + // Catch-all-allow (no caps configured) wins outright over + // any capped policy. + iNoCap := isUncapped(scored[i].policy) + jNoCap := isUncapped(scored[j].policy) + if iNoCap != jNoCap { + return iNoCap + } + // Bigger pool drains first. Group caps dominate (shared + // across the group) before individual caps. + if a, b := groupCapTokens(scored[i].policy), groupCapTokens(scored[j].policy); a != b { + return a > b + } + if a, b := groupCapBudgetUsd(scored[i].policy), groupCapBudgetUsd(scored[j].policy); a != b { + return a > b + } + if a, b := userCapTokens(scored[i].policy), userCapTokens(scored[j].policy); a != b { + return a > b + } + if a, b := userCapBudgetUsd(scored[i].policy), userCapBudgetUsd(scored[j].policy); a != b { + return a > b + } + return scored[i].policy.CreatedAt.Before(scored[j].policy.CreatedAt) + }) + + winner := scored[0] + return &PolicySelectionResult{ + Allow: true, + SelectedPolicyID: winner.policy.ID, + AttributionGroupID: winner.attributionGroup, + WindowSeconds: winner.windowSeconds, + }, nil +} + +// filterApplicablePolicies returns the enabled policies that target +// the requested provider and have at least one of the caller's groups +// in their source_groups. Caller's group set is matched +// case-sensitively against policy.SourceGroups. +func filterApplicablePolicies(policies []*types.Policy, in PolicySelectionInput) []*types.Policy { + if len(policies) == 0 { + return nil + } + groupSet := make(map[string]struct{}, len(in.GroupIDs)) + for _, g := range in.GroupIDs { + if g != "" { + groupSet[g] = struct{}{} + } + } + out := make([]*types.Policy, 0, len(policies)) + for _, p := range policies { + if p == nil || !p.Enabled { + continue + } + if !sliceContains(p.DestinationProviderIDs, in.ProviderID) { + continue + } + if !anyGroupMatches(p.SourceGroups, groupSet) { + continue + } + out = append(out, p) + } + return out +} + +// candidate is the per-policy intermediate the selector ranks. A +// policy that's been exhausted on any enabled cap never makes it +// into this slice; the selector's deny envelope carries the latest +// exhaustion's reason out separately. +type candidate struct { + policy *types.Policy + attributionGroup string + windowSeconds int64 +} + +// scoreCandidates evaluates every applicable policy against the +// caller's current consumption. Exhausted policies are filtered out +// of the returned slice; the most recent exhaustion's deny code + +// human reason is returned alongside so the caller can surface it +// when no candidate survives. +func scoreCandidates( + in PolicySelectionInput, + candidates []*types.Policy, + cache consumptionCache, + now time.Time, +) ([]candidate, string, string) { + out := make([]candidate, 0, len(candidates)) + var lastDenyCode, lastDenyReason string + + for _, p := range candidates { + c, exhausted, denyCode, denyReason := scoreOne(in, p, cache, now) + if exhausted { + lastDenyCode = denyCode + lastDenyReason = denyReason + continue + } + out = append(out, c) + } + return out, lastDenyCode, lastDenyReason +} + +// scoreOne checks a single policy for cap exhaustion. Returns the +// candidate envelope when the policy still has headroom on every +// enabled cap; reports exhausted=true with a deny code naming the +// offending cap kind otherwise. +func scoreOne( + in PolicySelectionInput, + p *types.Policy, + cache consumptionCache, + now time.Time, +) (candidate, bool, string, string) { + attrGroup := lowestIntersect(p.SourceGroups, in.GroupIDs) + c := candidate{ + policy: p, + attributionGroup: attrGroup, + windowSeconds: effectiveWindowSeconds(p), + } + + if p.Limits.TokenLimit.Enabled && p.Limits.TokenLimit.WindowSeconds > 0 { + if exhausted, reason := evalTokenCap(cache, in.AccountID, in.UserID, attrGroup, p.Limits.TokenLimit, now, "policy "+p.ID); exhausted { + return candidate{}, true, denyCodeTokenCapExceeded, reason + } + } + + if p.Limits.BudgetLimit.Enabled && p.Limits.BudgetLimit.WindowSeconds > 0 { + if exhausted, reason := evalBudgetCap(cache, in.AccountID, in.UserID, attrGroup, p.Limits.BudgetLimit, now, "policy "+p.ID); exhausted { + return candidate{}, true, denyCodeBudgetCapExceeded, reason + } + } + + return c, false, "", "" +} + +// evalTokenCap reports whether the token limit is already exhausted for the +// caller in its own window. attrGroup may be empty (no group dimension applies). +// label identifies the cap source ("policy " or "account rule ") for the +// deny reason. It is the shared primitive behind both policy and account-rule +// enforcement. +func evalTokenCap( + cache consumptionCache, + accountID, userID, attrGroup string, + tl types.PolicyTokenLimit, + now time.Time, + label string, +) (bool, string) { + windowStart := types.WindowStart(now, tl.WindowSeconds) + + if tl.UserCap > 0 && userID != "" { + row := cache.get(accountID, types.DimensionUser, userID, tl.WindowSeconds, windowStart) + used := row.TokensInput + row.TokensOutput + if used >= tl.UserCap { + return true, fmt.Sprintf("user token cap exhausted on %s (used %d of %d)", label, used, tl.UserCap) + } + } + + if tl.GroupCap > 0 && attrGroup != "" { + row := cache.get(accountID, types.DimensionGroup, attrGroup, tl.WindowSeconds, windowStart) + used := row.TokensInput + row.TokensOutput + if used >= tl.GroupCap { + return true, fmt.Sprintf("group token cap exhausted on %s (used %d of %d)", label, used, tl.GroupCap) + } + } + + return false, "" +} + +// evalBudgetCap is the budget (USD) counterpart of evalTokenCap. +func evalBudgetCap( + cache consumptionCache, + accountID, userID, attrGroup string, + bl types.PolicyBudgetLimit, + now time.Time, + label string, +) (bool, string) { + windowStart := types.WindowStart(now, bl.WindowSeconds) + + if bl.UserCapUsd > 0 && userID != "" { + row := cache.get(accountID, types.DimensionUser, userID, bl.WindowSeconds, windowStart) + if row.CostUSD >= bl.UserCapUsd { + return true, fmt.Sprintf("user budget cap exhausted on %s (used $%.4f of $%.4f)", label, row.CostUSD, bl.UserCapUsd) + } + } + + if bl.GroupCapUsd > 0 && attrGroup != "" { + row := cache.get(accountID, types.DimensionGroup, attrGroup, bl.WindowSeconds, windowStart) + if row.CostUSD >= bl.GroupCapUsd { + return true, fmt.Sprintf("group budget cap exhausted on %s (used $%.4f of $%.4f)", label, row.CostUSD, bl.GroupCapUsd) + } + } + + return false, "" +} + +// checkAccountBudget evaluates every applicable account-level budget rule as an +// all-must-pass ceiling. A rule applies when the caller is in its TargetUsers, +// one of its TargetGroups, or it has no targets at all (account-wide). Returns +// deny=true with an llm_account.* code on the first exhausted rule. Group caps +// attribute to the lowest intersecting group (the same model policies use), so +// multi-group behavior is unchanged. +func checkAccountBudget(in PolicySelectionInput, rules []*types.AccountBudgetRule, cache consumptionCache, now time.Time) (bool, string, string) { + for _, r := range rules { + if r == nil || !r.Enabled || !budgetRuleApplies(r, in) { + continue + } + attrGroup := lowestIntersect(r.TargetGroups, in.GroupIDs) + label := "account rule " + r.ID + + if r.Limits.TokenLimit.Enabled && r.Limits.TokenLimit.WindowSeconds > 0 { + if exhausted, reason := evalTokenCap(cache, in.AccountID, in.UserID, attrGroup, r.Limits.TokenLimit, now, label); exhausted { + return true, denyCodeAccountTokenCapExceeded, reason + } + } + + if r.Limits.BudgetLimit.Enabled && r.Limits.BudgetLimit.WindowSeconds > 0 { + if exhausted, reason := evalBudgetCap(cache, in.AccountID, in.UserID, attrGroup, r.Limits.BudgetLimit, now, label); exhausted { + return true, denyCodeAccountBudgetCapExceeded, reason + } + } + } + + return false, "", "" +} + +// budgetRuleApplies reports whether an account budget rule binds the caller: +// a direct user match, a group intersection, or an untargeted (account-wide) +// rule. +func budgetRuleApplies(r *types.AccountBudgetRule, in PolicySelectionInput) bool { + if len(r.TargetUsers) == 0 && len(r.TargetGroups) == 0 { + return true + } + if in.UserID != "" && sliceContains(r.TargetUsers, in.UserID) { + return true + } + groupSet := make(map[string]struct{}, len(in.GroupIDs)) + for _, g := range in.GroupIDs { + if g != "" { + groupSet[g] = struct{}{} + } + } + return anyGroupMatches(r.TargetGroups, groupSet) +} + +// RecordAccountBudgetUsage fans the served request's usage out to every +// applicable account budget rule's own (dimension, window) counter. The user +// dimension is always booked when a rule has a user-applicable cap; the group +// dimension books against the rule's lowest intersecting group. This runs +// alongside the policy-window record so account ceilings accumulate in their own +// windows (commonly monthly) independently of the per-policy window. +func (m *managerImpl) RecordAccountBudgetUsage(ctx context.Context, accountID, userID string, groupIDs []string, tokensIn, tokensOut int64, costUSD float64) error { + if accountID == "" { + return status.Errorf(status.InvalidArgument, "account_id is required") + } + if err := validateUsageDeltas(tokensIn, tokensOut, costUSD); err != nil { + return err + } + rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return fmt.Errorf("list account budget rules: %w", err) + } + set := make(map[types.ConsumptionKey]struct{}) + addAccountBudgetKeys(set, PolicySelectionInput{AccountID: accountID, UserID: userID, GroupIDs: groupIDs}, rules, time.Now().UTC()) + if len(set) == 0 { + return nil + } + return m.store.IncrementAgentNetworkConsumptionBatch(ctx, accountID, keysSlice(set), tokensIn, tokensOut, costUSD) +} + +// RecordUsageInput carries everything RecordUsage books for one served request. +type RecordUsageInput struct { + AccountID string + UserID string + AttributionGroupID string // selected policy's attribution group (policy window) + GroupIDs []string + WindowSeconds int64 // selected policy's window; 0 means no policy cap + TokensIn int64 + TokensOut int64 + CostUSD float64 +} + +// RecordUsage books a served request's usage against every counter it touches — +// the selected policy's per-(user, group) window plus every applicable account +// budget rule's own window — deduplicated and written in a single transaction. +// Two counters that collapse to the same (dimension, window) tuple are booked +// once, so a single request can never double-count against one cap. +func (m *managerImpl) RecordUsage(ctx context.Context, in RecordUsageInput) error { + if in.AccountID == "" { + return status.Errorf(status.InvalidArgument, "account_id is required") + } + if err := validateUsageDeltas(in.TokensIn, in.TokensOut, in.CostUSD); err != nil { + return err + } + now := time.Now().UTC() + set := make(map[types.ConsumptionKey]struct{}) + + // Policy-window dimensions are booked only when a policy cap bound this + // request (window > 0). A zero window means catch-all-allow / no policy cap; + // the account fan-out below still books against the budget rules' windows. + if in.WindowSeconds > 0 { + addLimitKeys(set, in.UserID, in.AttributionGroupID, in.WindowSeconds, now) + } + + rules, err := m.store.GetAccountAgentNetworkBudgetRules(ctx, store.LockingStrengthNone, in.AccountID) + if err != nil { + return fmt.Errorf("list account budget rules: %w", err) + } + addAccountBudgetKeys(set, PolicySelectionInput{AccountID: in.AccountID, UserID: in.UserID, GroupIDs: in.GroupIDs}, rules, now) + + if len(set) == 0 { + return nil + } + return m.store.IncrementAgentNetworkConsumptionBatch(ctx, in.AccountID, keysSlice(set), in.TokensIn, in.TokensOut, in.CostUSD) +} + +// addAccountBudgetKeys adds the (dimension, window) keys a served request books +// against every applicable account budget rule into the dedup set. +func addAccountBudgetKeys(set map[types.ConsumptionKey]struct{}, in PolicySelectionInput, rules []*types.AccountBudgetRule, now time.Time) { + for _, r := range rules { + if r == nil || !r.Enabled || !budgetRuleApplies(r, in) { + continue + } + attrGroup := lowestIntersect(r.TargetGroups, in.GroupIDs) + for _, window := range ruleWindows(r) { + addLimitKeys(set, in.UserID, attrGroup, window, now) + } + } +} + +// keysSlice flattens a ConsumptionKey set into a slice. +func keysSlice(set map[types.ConsumptionKey]struct{}) []types.ConsumptionKey { + keys := make([]types.ConsumptionKey, 0, len(set)) + for k := range set { + keys = append(keys, k) + } + return keys +} + +// ruleWindows returns the distinct enabled window lengths a budget rule books +// against (token window and/or budget window, deduplicated). +func ruleWindows(r *types.AccountBudgetRule) []int64 { + var windows []int64 + if r.Limits.TokenLimit.Enabled && r.Limits.TokenLimit.WindowSeconds > 0 { + windows = append(windows, r.Limits.TokenLimit.WindowSeconds) + } + if r.Limits.BudgetLimit.Enabled && r.Limits.BudgetLimit.WindowSeconds > 0 { + bw := r.Limits.BudgetLimit.WindowSeconds + if len(windows) == 0 || windows[0] != bw { + windows = append(windows, bw) + } + } + return windows +} + +// effectiveWindowSeconds returns the window length the proxy should +// hand back to RecordLLMUsage. When both halves are enabled with +// different windows, token_limit wins (the more common config); when +// only one is enabled that one wins; when neither is enabled the +// returned value is 0 — RecordLLMUsage treats 0 as "no limit +// tracking" and skips the increment, which is the right pass-through +// for catch-all-allow policies with no caps configured. +func effectiveWindowSeconds(p *types.Policy) int64 { + if p.Limits.TokenLimit.Enabled && p.Limits.TokenLimit.WindowSeconds > 0 { + return p.Limits.TokenLimit.WindowSeconds + } + if p.Limits.BudgetLimit.Enabled && p.Limits.BudgetLimit.WindowSeconds > 0 { + return p.Limits.BudgetLimit.WindowSeconds + } + return 0 +} + +// lowestIntersect returns the lowest-by-string-sort element of +// callerGroups ∩ sourceGroups. Empty when the intersection is empty. +// Lowest is deterministic so multi-node selection converges. +func lowestIntersect(sourceGroups, callerGroups []string) string { + if len(sourceGroups) == 0 || len(callerGroups) == 0 { + return "" + } + srcSet := make(map[string]struct{}, len(sourceGroups)) + for _, g := range sourceGroups { + srcSet[g] = struct{}{} + } + var best string + for _, g := range callerGroups { + if _, ok := srcSet[g]; !ok { + continue + } + if best == "" || g < best { + best = g + } + } + return best +} + +func anyGroupMatches(sourceGroups []string, callerSet map[string]struct{}) bool { + for _, g := range sourceGroups { + if _, ok := callerSet[g]; ok { + return true + } + } + return false +} + +// isUncapped reports whether a policy has any enabled cap with a +// positive limit value. Mirrors the eval functions' guards: a policy +// with token_limit.enabled=true but every cap value at 0 still +// counts as uncapped because the eval would query nothing and bind +// nothing. +func isUncapped(p *types.Policy) bool { + tl := p.Limits.TokenLimit + if tl.Enabled && tl.WindowSeconds > 0 && (tl.GroupCap > 0 || tl.UserCap > 0) { + return false + } + bl := p.Limits.BudgetLimit + if bl.Enabled && bl.WindowSeconds > 0 && (bl.GroupCapUsd > 0 || bl.UserCapUsd > 0) { + return false + } + return true +} + +// groupCapTokens returns the policy's group-token cap when the token +// limit is enabled, zero otherwise. Drives the primary "bigger pool +// first" sort. +func groupCapTokens(p *types.Policy) int64 { + if p.Limits.TokenLimit.Enabled { + return p.Limits.TokenLimit.GroupCap + } + return 0 +} + +// groupCapBudgetUsd returns the policy's group-budget cap in USD +// when the budget limit is enabled, zero otherwise. Secondary sort +// key after token group cap so budget-only policies still order +// predictably. +func groupCapBudgetUsd(p *types.Policy) float64 { + if p.Limits.BudgetLimit.Enabled { + return p.Limits.BudgetLimit.GroupCapUsd + } + return 0 +} + +// userCapTokens returns the policy's per-user token cap when the +// token limit is enabled, zero otherwise. Tertiary sort key, used +// when group caps tie or are absent. +func userCapTokens(p *types.Policy) int64 { + if p.Limits.TokenLimit.Enabled { + return p.Limits.TokenLimit.UserCap + } + return 0 +} + +// userCapBudgetUsd returns the policy's per-user budget cap in USD +// when the budget limit is enabled, zero otherwise. Quaternary sort +// key for budget-only policies whose group caps tie or are absent. +func userCapBudgetUsd(p *types.Policy) float64 { + if p.Limits.BudgetLimit.Enabled { + return p.Limits.BudgetLimit.UserCapUsd + } + return 0 +} + +func sliceContains(haystack []string, needle string) bool { + for _, v := range haystack { + if v == needle { + return true + } + } + return false +} + +// mockManager fallback so tests that don't care about selection still +// compile. +func (*mockManager) SelectPolicyForRequest(_ context.Context, _ PolicySelectionInput) (*PolicySelectionResult, error) { + return &PolicySelectionResult{Allow: true}, nil +} diff --git a/management/internals/modules/agentnetwork/policyselect_account_realstore_test.go b/management/internals/modules/agentnetwork/policyselect_account_realstore_test.go new file mode 100644 index 000000000..c3b13a6cc --- /dev/null +++ b/management/internals/modules/agentnetwork/policyselect_account_realstore_test.go @@ -0,0 +1,181 @@ +package agentnetwork + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/management/server/store" +) + +// GC-2 no-mock enforcement tests for the account-budget ceiling. They drive the +// real store + real consumption accounting through SelectPolicyForRequest and +// RecordAccountBudgetUsage, asserting min-wins (account binds independently of +// policy), targeting (groups + direct users), and the record fan-out. + +func accountWideUserTokenRule(id string, userCap, window int64) *types.AccountBudgetRule { + r := types.NewAccountBudgetRule(realSelectAccount) + r.ID = id + r.Limits.TokenLimit = types.PolicyTokenLimit{Enabled: true, UserCap: userCap, WindowSeconds: window} + return r +} + +// TestSelectPolicy_RealStore_AccountCeilingBindsEvenWithUncappedPolicy proves +// min-wins: the account user ceiling denies once exhausted even though a +// catch-all-allow (uncapped) policy would otherwise pass the request. The +// account gate runs independently of and ahead of policy selection. +func TestSelectPolicy_RealStore_AccountCeilingBindsEvenWithUncappedPolicy(t *testing.T) { + mgr, s := newRealSelectorMgr(t) + ctx := context.Background() + + // An uncapped (catch-all-allow) policy: enabled token limit, zero caps. + uncapped := capPolicy("pol-open", realSelectAccount, []string{"grp-eng"}, "prov-1", 0, 86_400) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, uncapped)) + + // Account-wide user ceiling of 100 tokens in an hourly window. + require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-1", 100, 3_600))) + + in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"} + + // Fresh: account ceiling has headroom, uncapped policy wins. + res, err := mgr.SelectPolicyForRequest(ctx, in) + require.NoError(t, err) + assert.True(t, res.Allow, "fresh account ceiling must allow") + + // Drain the account user ceiling via the fan-out path. + require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", []string{"grp-eng"}, 100, 0, 0)) + + res, err = mgr.SelectPolicyForRequest(ctx, in) + require.NoError(t, err) + assert.False(t, res.Allow, "account ceiling must deny even though the policy is uncapped (min-wins)") + assert.Equal(t, denyCodeAccountTokenCapExceeded, res.DenyCode, "deny must carry the llm_account.* code") +} + +// TestSelectPolicy_RealStore_AccountGroupCeiling proves a group-targeted rule +// binds the caller's group dimension. +func TestSelectPolicy_RealStore_AccountGroupCeiling(t *testing.T) { + mgr, s := newRealSelectorMgr(t) + ctx := context.Background() + + rule := types.NewAccountBudgetRule(realSelectAccount) + rule.ID = "ainbud-grp" + rule.TargetGroups = []string{"grp-eng"} + rule.Limits.BudgetLimit = types.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 5.0, WindowSeconds: 2_592_000} + require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, rule)) + + in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"} + + res, err := mgr.SelectPolicyForRequest(ctx, in) + require.NoError(t, err) + assert.True(t, res.Allow, "fresh group ceiling must allow") + + require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", []string{"grp-eng"}, 0, 0, 5.0)) + + res, err = mgr.SelectPolicyForRequest(ctx, in) + require.NoError(t, err) + assert.False(t, res.Allow, "group budget ceiling must deny once spent") + assert.Equal(t, denyCodeAccountBudgetCapExceeded, res.DenyCode, "account budget deny code") +} + +// TestSelectPolicy_RealStore_AccountTargetUsersBindsOnlyThatUser proves a +// TargetUsers rule tightens only the named user, leaving others unbound. +func TestSelectPolicy_RealStore_AccountTargetUsersBindsOnlyThatUser(t *testing.T) { + mgr, s := newRealSelectorMgr(t) + ctx := context.Background() + + rule := types.NewAccountBudgetRule(realSelectAccount) + rule.ID = "ainbud-alice" + rule.TargetUsers = []string{"alice"} + rule.Limits.TokenLimit = types.PolicyTokenLimit{Enabled: true, UserCap: 100, WindowSeconds: 3_600} + require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, rule)) + + // Record alice's usage to the rule window. + require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "alice", nil, 100, 0, 0)) + + aliceIn := PolicySelectionInput{AccountID: realSelectAccount, UserID: "alice", ProviderID: "prov-1"} + res, err := mgr.SelectPolicyForRequest(ctx, aliceIn) + require.NoError(t, err) + assert.False(t, res.Allow, "alice is bound by the TargetUsers rule and is exhausted") + + bobIn := PolicySelectionInput{AccountID: realSelectAccount, UserID: "bob", ProviderID: "prov-1"} + res, err = mgr.SelectPolicyForRequest(ctx, bobIn) + require.NoError(t, err) + assert.True(t, res.Allow, "bob is not in TargetUsers, so the rule must not bind him") +} + +// TestSelectPolicy_RealStore_AccountRuleRecordsToOwnWindow proves the record +// fan-out books usage in the rule's own window (distinct from any policy +// window), so the account ceiling accumulates independently. +func TestSelectPolicy_RealStore_AccountRuleRecordsToOwnWindow(t *testing.T) { + mgr, s := newRealSelectorMgr(t) + ctx := context.Background() + + require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-w", 100, 3_600))) + + require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", nil, 60, 0, 0)) + + // Same user, a policy-style daily window must NOT see the account-window + // usage — windows are independent counters. + dailyRow, err := s.GetAgentNetworkConsumption(ctx, store.LockingStrengthNone, realSelectAccount, types.DimensionUser, "user-1", 86_400, types.WindowStart(time.Now().UTC(), 86_400)) + require.NoError(t, err) + assert.Equal(t, int64(0), dailyRow.TokensInput+dailyRow.TokensOutput, "daily window must be untouched by the hourly account-rule record") + + // A second record pushes the hourly account window to its cap → deny. + require.NoError(t, mgr.RecordAccountBudgetUsage(ctx, realSelectAccount, "user-1", nil, 40, 0, 0)) + res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", ProviderID: "prov-1"}) + require.NoError(t, err) + assert.False(t, res.Allow, "100 tokens recorded in the rule's hourly window must exhaust the 100-token ceiling") + assert.Equal(t, denyCodeAccountTokenCapExceeded, res.DenyCode, "account token deny code") +} + +// TestRecordUsage_RealStore_BooksPolicyAndAccountWindows proves the batched +// post-flight write books the selected policy's window AND every applicable +// account rule's (independent) window in a single call — the #6 batched-write +// path the proxy's RecordLLMUsage RPC now uses. +func TestRecordUsage_RealStore_BooksPolicyAndAccountWindows(t *testing.T) { + mgr, s := newRealSelectorMgr(t) + ctx := context.Background() + + // Policy: 100-token group cap on a daily window. Account rule: 100-token + // user ceiling on an hourly window — an independent counter. + policy := capPolicy("pol-1", realSelectAccount, []string{"grp-eng"}, "prov-1", 100, 86_400) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, policy)) + require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, accountWideUserTokenRule("ainbud-1", 100, 3_600))) + + in := PolicySelectionInput{AccountID: realSelectAccount, UserID: "user-1", GroupIDs: []string{"grp-eng"}, ProviderID: "prov-1"} + res, err := mgr.SelectPolicyForRequest(ctx, in) + require.NoError(t, err) + require.True(t, res.Allow) + require.Equal(t, "pol-1", res.SelectedPolicyID) + + // One batched record books the policy window (group + user @86400) and the + // account rule window (user @3600) atomically. + require.NoError(t, mgr.RecordUsage(ctx, RecordUsageInput{ + AccountID: realSelectAccount, + UserID: "user-1", + AttributionGroupID: res.AttributionGroupID, + GroupIDs: []string{"grp-eng"}, + WindowSeconds: res.WindowSeconds, + TokensIn: 100, + })) + + // The next selection denies — the account hourly ceiling binds first. + res, err = mgr.SelectPolicyForRequest(ctx, in) + require.NoError(t, err) + assert.False(t, res.Allow, "usage booked by RecordUsage must enforce on the next request") + + // Prove BOTH windows were booked in the one call via a direct batch read. + now := time.Now().UTC() + userKey := types.ConsumptionKey{Kind: types.DimensionUser, DimID: "user-1", WindowSeconds: 3_600, WindowStartUTC: types.WindowStart(now, 3_600)} + groupKey := types.ConsumptionKey{Kind: types.DimensionGroup, DimID: "grp-eng", WindowSeconds: 86_400, WindowStartUTC: types.WindowStart(now, 86_400)} + rows, err := s.GetAgentNetworkConsumptionBatch(ctx, store.LockingStrengthNone, realSelectAccount, []types.ConsumptionKey{userKey, groupKey}) + require.NoError(t, err) + require.Contains(t, rows, userKey, "account rule user/hourly window booked") + require.Contains(t, rows, groupKey, "policy group/daily window booked") + assert.Equal(t, int64(100), rows[userKey].TokensInput, "account hourly user counter") + assert.Equal(t, int64(100), rows[groupKey].TokensInput, "policy daily group counter") +} diff --git a/management/internals/modules/agentnetwork/policyselect_realstore_test.go b/management/internals/modules/agentnetwork/policyselect_realstore_test.go new file mode 100644 index 000000000..cc8cfb1e7 --- /dev/null +++ b/management/internals/modules/agentnetwork/policyselect_realstore_test.go @@ -0,0 +1,214 @@ +package agentnetwork + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/management/server/store" +) + +// This file is the no-mock regression guard for policy limit enforcement. +// policyselect_test.go pins the same behavior through a gomock store with +// explicit call-sequence expectations — brittle precisely where the upcoming +// account-budget work (GC-2) refactors the cap-eval primitive and adds an +// account-level gate. These tests drive the REAL sqlite store + REAL +// consumption accounting and assert observable behavior (allow / deny / +// selection / attribution), not which store methods get called. They must keep +// passing unchanged after GC-2 lands, which is what proves "current behavior is +// not changed." + +const realSelectAccount = "acc-realselect-1" + +// newRealSelectorMgr builds a managerImpl backed by a real sqlite test store. +func newRealSelectorMgr(t *testing.T) (*managerImpl, store.Store) { + t.Helper() + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + t.Cleanup(cleanup) + return &managerImpl{store: s}, s +} + +// TestSelectPolicy_RealStore_NoApplicablePolicies pins the pass-through: +// nothing targets the (provider, groups) combination, so the selector allows +// without attribution or consumption tracking. +func TestSelectPolicy_RealStore_NoApplicablePolicies(t *testing.T) { + mgr, _ := newRealSelectorMgr(t) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: realSelectAccount, + UserID: "user-1", + GroupIDs: []string{"grp-x"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.True(t, res.Allow, "no applicable policy must pass through as allow") + assert.Empty(t, res.SelectedPolicyID, "no selection when nothing applies") +} + +// TestSelectPolicy_RealStore_AllowAndLowestGroupAttribution pins the v1 +// attribution rule (lowest intersecting group by string sort) through the +// real store, with a fresh (zero) consumption row. +func TestSelectPolicy_RealStore_AllowAndLowestGroupAttribution(t *testing.T) { + mgr, s := newRealSelectorMgr(t) + ctx := context.Background() + + p := capPolicy("pol-A", realSelectAccount, []string{"grp-zz", "grp-aa", "grp-mm"}, "prov-1", 10_000, 86_400) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p)) + + res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{ + AccountID: realSelectAccount, + UserID: "user-1", + GroupIDs: []string{"grp-zz", "grp-aa", "grp-mm"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.True(t, res.Allow, "fresh state under cap must allow") + assert.Equal(t, "pol-A", res.SelectedPolicyID, "only applicable policy must be selected") + assert.Equal(t, "grp-aa", res.AttributionGroupID, "lowest-by-sort intersecting group must win") + assert.Equal(t, int64(86_400), res.WindowSeconds, "selected policy's window must be returned") +} + +// TestSelectPolicy_RealStore_LargerPoolWins_FallsThroughWhenExhausted pins the +// core selection behavior end to end. The two policies bind DISTINCT groups so +// they read separate counters — the only shape where fall-through actually +// yields headroom (policies on the same group share one counter, as +// policyselect_test.go notes). Larger pool wins fresh; after real consumption +// drains the larger group, selection falls through to the smaller; once both +// counters are exhausted the request is denied. +func TestSelectPolicy_RealStore_LargerPoolWins_FallsThroughWhenExhausted(t *testing.T) { + mgr, s := newRealSelectorMgr(t) + ctx := context.Background() + + tight := capPolicy("pol-tight", realSelectAccount, []string{"grp-tight"}, "prov-1", 100, 86_400) + tight.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + wide := capPolicy("pol-wide", realSelectAccount, []string{"grp-wide"}, "prov-1", 10_000, 86_400) + wide.CreatedAt = time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, tight)) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, wide)) + + // Caller is in both groups, so both policies apply with independent counters. + in := PolicySelectionInput{ + AccountID: realSelectAccount, + UserID: "user-1", + GroupIDs: []string{"grp-tight", "grp-wide"}, + ProviderID: "prov-1", + } + + // Fresh: larger pool wins. + res, err := mgr.SelectPolicyForRequest(ctx, in) + require.NoError(t, err) + assert.Equal(t, "pol-wide", res.SelectedPolicyID, "larger pool drains first") + + // Drain only the wide group's counter to its cap. + require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-wide", 86_400, 10_000, 0, 0)) + + // Wide exhausted, tight's separate counter is fresh → fall through to tight. + res, err = mgr.SelectPolicyForRequest(ctx, in) + require.NoError(t, err) + assert.True(t, res.Allow, "tight pool has its own untouched counter") + assert.Equal(t, "pol-tight", res.SelectedPolicyID, "selection falls through to the smaller pool once the larger is exhausted") + + // Drain the tight group's counter too → both exhausted → deny. + require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-tight", 86_400, 100, 0, 0)) + res, err = mgr.SelectPolicyForRequest(ctx, in) + require.NoError(t, err) + assert.False(t, res.Allow, "both group counters exhausted must deny") + assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode, "deny code names the offending cap kind") +} + +// TestSelectPolicy_RealStore_BudgetCapDenies pins budget (USD) enforcement +// through the real store: once recorded cost reaches the cap, deny. +func TestSelectPolicy_RealStore_BudgetCapDenies(t *testing.T) { + mgr, s := newRealSelectorMgr(t) + ctx := context.Background() + + p := &types.Policy{ + ID: "pol-budget", + AccountID: realSelectAccount, + Enabled: true, + SourceGroups: []string{"grp-eng"}, + DestinationProviderIDs: []string{"prov-1"}, + Limits: types.PolicyLimits{ + BudgetLimit: types.PolicyBudgetLimit{ + Enabled: true, + GroupCapUsd: 5.0, + WindowSeconds: 86_400, + }, + }, + CreatedAt: time.Now().UTC(), + } + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p)) + + in := PolicySelectionInput{ + AccountID: realSelectAccount, + UserID: "user-1", + GroupIDs: []string{"grp-eng"}, + ProviderID: "prov-1", + } + + res, err := mgr.SelectPolicyForRequest(ctx, in) + require.NoError(t, err) + assert.True(t, res.Allow, "fresh budget must allow") + + require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-eng", 86_400, 0, 0, 5.0)) + + res, err = mgr.SelectPolicyForRequest(ctx, in) + require.NoError(t, err) + assert.False(t, res.Allow, "cost at the cap must deny") + assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode, "budget deny code must be surfaced") +} + +// TestSelectPolicy_RealStore_GroupCounterSharedAcrossPolicies pins that two +// policies on the same group+window read one shared consumption counter: usage +// recorded once is visible to both, so exhausting the group budget denies +// regardless of which policy would attribute. +func TestSelectPolicy_RealStore_GroupCounterSharedAcrossPolicies(t *testing.T) { + mgr, s := newRealSelectorMgr(t) + ctx := context.Background() + + a := capPolicy("pol-a", realSelectAccount, []string{"grp-eng"}, "prov-1", 1_000, 86_400) + b := capPolicy("pol-b", realSelectAccount, []string{"grp-eng"}, "prov-1", 1_000, 86_400) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, a)) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, b)) + + in := PolicySelectionInput{ + AccountID: realSelectAccount, + UserID: "user-1", + GroupIDs: []string{"grp-eng"}, + ProviderID: "prov-1", + } + + require.NoError(t, mgr.RecordConsumption(ctx, realSelectAccount, types.DimensionGroup, "grp-eng", 86_400, 1_000, 0, 0)) + + res, err := mgr.SelectPolicyForRequest(ctx, in) + require.NoError(t, err) + assert.False(t, res.Allow, "shared group counter at cap denies both equal policies") + assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode, "token deny code on the shared counter") +} + +// TestSelectPolicy_RealStore_DisabledPolicyIgnored pins that a disabled policy +// is invisible to selection even when it otherwise matches. +func TestSelectPolicy_RealStore_DisabledPolicyIgnored(t *testing.T) { + mgr, s := newRealSelectorMgr(t) + ctx := context.Background() + + p := capPolicy("pol-disabled", realSelectAccount, []string{"grp-eng"}, "prov-1", 10_000, 86_400) + p.Enabled = false + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, p)) + + res, err := mgr.SelectPolicyForRequest(ctx, PolicySelectionInput{ + AccountID: realSelectAccount, + UserID: "user-1", + GroupIDs: []string{"grp-eng"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.True(t, res.Allow, "no enabled policy applies → pass-through allow") + assert.Empty(t, res.SelectedPolicyID, "disabled policy must not be selected") +} diff --git a/management/internals/modules/agentnetwork/policyselect_test.go b/management/internals/modules/agentnetwork/policyselect_test.go new file mode 100644 index 000000000..dd7687fe1 --- /dev/null +++ b/management/internals/modules/agentnetwork/policyselect_test.go @@ -0,0 +1,641 @@ +package agentnetwork + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/management/server/store" + nbstatus "github.com/netbirdio/netbird/shared/management/status" +) + +func newSelectorMgr(t *testing.T, ctrl *gomock.Controller) (*managerImpl, *store.MockStore) { + t.Helper() + mockStore := store.NewMockStore(ctrl) + // SelectPolicyForRequest evaluates the account-budget ceiling before policy + // selection. These policy-selection tests don't exercise account rules, so + // default to "no rules" — the no-mock policyselect_realstore_test.go covers + // the account gate's behavior end to end. + mockStore.EXPECT(). + GetAccountAgentNetworkBudgetRules(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, nil). + AnyTimes() + return &managerImpl{store: mockStore}, mockStore +} + +type usedKey struct { + kind types.ConsumptionDimension + dimID string + window int64 +} + +// expectConsumptionBatch stubs the batched consumption read to return the +// supplied per-(kind, dim, window) counters, filling each row's window start +// from the actual request keys so it always matches what the selector computed. +// Keys absent from used resolve to zero counters. +func expectConsumptionBatch(mockStore *store.MockStore, used map[usedKey]*types.Consumption) { + mockStore.EXPECT(). + GetAgentNetworkConsumptionBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, _ store.LockingStrength, _ string, keys []types.ConsumptionKey) (map[types.ConsumptionKey]*types.Consumption, error) { + out := make(map[types.ConsumptionKey]*types.Consumption) + for _, k := range keys { + if row, ok := used[usedKey{k.Kind, k.DimID, k.WindowSeconds}]; ok { + rc := *row + rc.WindowStartUTC = k.WindowStartUTC + out[k] = &rc + } + } + return out, nil + }). + AnyTimes() +} + +func capPolicy(id, account string, sourceGroups []string, providerID string, tokenCap int64, windowSec int64) *types.Policy { + return &types.Policy{ + ID: id, + AccountID: account, + Enabled: true, + SourceGroups: sourceGroups, + DestinationProviderIDs: []string{providerID}, + Limits: types.PolicyLimits{ + TokenLimit: types.PolicyTokenLimit{ + Enabled: true, + GroupCap: tokenCap, + WindowSeconds: windowSec, + }, + }, + CreatedAt: time.Now().UTC(), + } +} + +// TestSelectPolicy_NoApplicablePolicies covers the pass-through path: +// llm_router authorisation is upstream of selection; when the +// selector finds no policy targeting the (provider, caller-groups) +// combination, it returns Allow with no attribution and lets the +// request continue without consumption tracking. +func TestSelectPolicy_NoApplicablePolicies(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{}, nil) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + UserID: "user-1", + GroupIDs: []string{"grp-x"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.True(t, res.Allow, "no applicable policies = pass-through allow") + assert.Empty(t, res.SelectedPolicyID, "no selection when nothing applies") +} + +// TestSelectPolicy_AllowWithLowestGroupAttribution proves the v1 +// attribution rule: when the caller's groups intersect a policy's +// source_groups in multiple positions, the selector picks the lowest +// group id by string sort so multi-node selection converges. +func TestSelectPolicy_AllowWithLowestGroupAttribution(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + policy := capPolicy("pol-A", "acc-1", []string{"grp-zz", "grp-aa", "grp-mm"}, "prov-1", 10_000, 86_400) + + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{policy}, nil) + // Fresh: zero consumption across the board. + expectConsumptionBatch(mockStore, nil) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + UserID: "user-1", + GroupIDs: []string{"grp-zz", "grp-aa", "grp-mm"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.True(t, res.Allow) + assert.Equal(t, "pol-A", res.SelectedPolicyID) + assert.Equal(t, "grp-aa", res.AttributionGroupID, + "lowest-by-sort intersection wins so multi-node selection converges") + assert.Equal(t, int64(86_400), res.WindowSeconds) +} + +// TestSelectPolicy_LargerPoolWinsAcrossUsageLevels proves the core +// selection rule: among multiple applicable policies with caps, the +// selector picks the one with the larger absolute pool — at every +// usage level, not just at fresh state. The smaller-pool policy is +// only reached when the larger one is exhausted. This is the +// "drain biggest first" semantic operators expect for layered +// tiers; a fraction-based score would flap between the two as +// soon as one is partially used. +func TestSelectPolicy_LargerPoolWinsAcrossUsageLevels(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + tight := capPolicy("pol-tight", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400) + tight.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + wide := capPolicy("pol-wide", "acc-1", []string{"grp-engineers"}, "prov-1", 10_000, 86_400) + wide.CreatedAt = time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC) + + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{tight, wide}, nil) + + // Both partially used. tight at 50/100 (50% used); wide at + // 50/10000 (0.5% used). Old fraction-based algo would pick wide + // here too — but for the wrong reason ("more relative slack"). + // New algo picks wide because its initial group cap is bigger + // (10000 > 100), and that decision is stable as wide drains. + expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{ + {types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 50}, + }) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + UserID: "user-1", + GroupIDs: []string{"grp-engineers"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.Equal(t, "pol-wide", res.SelectedPolicyID, + "the policy with the bigger initial pool wins — operators expect 'drain the privileged tier first', not load-balance across tiers") +} + +// TestSelectPolicy_StaysOnLargerPoolAfterPartialDrain locks the +// stickiness contract reported by operators: with two policies +// where A has a 200-token group cap and B has 150, the very first +// request goes to A AND every subsequent request continues to land +// on A until A's group cap is exhausted — at which point B becomes +// the only candidate. A fraction-based score would flap to B as +// soon as A had any consumption (B's 1.0 fraction beats A's 0.75) +// even though A still has more absolute headroom; that produced +// confusing per-policy attribution ledger entries and stranded +// A's remaining capacity behind B's exhaustion. +func TestSelectPolicy_StaysOnLargerPoolAfterPartialDrain(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + policyA := capPolicy("pol-A-200", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400) + policyB := capPolicy("pol-B-150", "acc-1", []string{"grp-engineers"}, "prov-1", 150, 86_400) + + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{policyA, policyB}, nil) + + // A is partially drained (50/200 used = 25% used; 75% headroom + // remaining). B is fresh (0/150). The old fraction-based score + // would pick B here (1.0 > 0.75 fraction); the new pool-size + // score sticks with A (200 > 150 absolute cap). + expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{ + {types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 50}, + }) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + UserID: "user-1", + GroupIDs: []string{"grp-engineers"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.Equal(t, "pol-A-200", res.SelectedPolicyID, + "once attribution lands on the bigger pool it must STAY there until exhausted — operators expect 'drain A then B', not 'flip to B as soon as A is touched'") +} + +// TestSelectPolicy_FallsThroughToSmallerPoolWhenLargerExhausted +// proves the second half of the stickiness contract: once the +// larger-pool policy IS exhausted, the smaller one takes over. +// Without this we'd deny on requests the smaller policy is fully +// equipped to serve. +func TestSelectPolicy_FallsThroughToSmallerPoolWhenLargerExhausted(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + policyA := capPolicy("pol-A-200", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400) + // B uses a different window length so it has an INDEPENDENT counter — the + // realistic shape for fall-through. On the SAME (group, window) tuple the + // counter is shared, so A's cap of 200 being reached would also exhaust B's + // 150; independent counters are what let A exhaust while B retains headroom. + policyB := capPolicy("pol-B-150", "acc-1", []string{"grp-engineers"}, "prov-1", 150, 3_600) + + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{policyA, policyB}, nil) + + expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{ + {types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 200}, // A: 200 >= 200 → exhausted + {types.DimensionGroup, "grp-engineers", 3_600}: {TokensInput: 100}, // B: 100 < 150 → headroom + }) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + UserID: "user-1", + GroupIDs: []string{"grp-engineers"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.Equal(t, "pol-B-150", res.SelectedPolicyID, + "once the bigger pool is exhausted, the smaller one must take over — denying when capacity remains would strand B's allowance") +} + +// TestSelectPolicy_TiebreakByLargerGroupPool covers the user-reported +// bug: an admin in two groups (Users + Admins) where Users is bound +// by a smaller-group-cap policy (50 group, 100 user) and Admins is +// bound by a bigger-group-cap policy (100 group, 20 user) MUST get +// attributed to the Admins policy on the first request. +// +// Without this rule, the fresh-state fraction is 1.0 for both and +// the older policy wins by created_at. The first 24-token request +// then drains the shared user counter past Admins's tight 20-token +// user cap, locking Admins out of selection forever. The 100-token +// Admins group pool ends up stranded while requests pile onto the +// 50-token Users pool — the opposite of what the operator intended +// when they put the bigger pool on the privileged group. +func TestSelectPolicy_TiebreakByLargerGroupPool(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + // Policy A: Users group, smaller group pool, looser per-user cap. + policyA := &types.Policy{ + ID: "pol-Users", + AccountID: "acc-1", + Enabled: true, + SourceGroups: []string{"grp-Users"}, + DestinationProviderIDs: []string{"prov-1"}, + Limits: types.PolicyLimits{ + TokenLimit: types.PolicyTokenLimit{ + Enabled: true, GroupCap: 50, UserCap: 100, WindowSeconds: 86_400, + }, + }, + // Older — would win the legacy created_at tiebreak. + CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC), + } + // Policy B: Admins group, bigger group pool, tighter per-user cap. + policyB := &types.Policy{ + ID: "pol-Admins", + AccountID: "acc-1", + Enabled: true, + SourceGroups: []string{"grp-Admins"}, + DestinationProviderIDs: []string{"prov-1"}, + Limits: types.PolicyLimits{ + TokenLimit: types.PolicyTokenLimit{ + Enabled: true, GroupCap: 100, UserCap: 20, WindowSeconds: 86_400, + }, + }, + CreatedAt: time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC), + } + + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{policyA, policyB}, nil) + // Fresh state: every cap evaluation reads zero usage. + expectConsumptionBatch(mockStore, nil) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + UserID: "user-1", + GroupIDs: []string{"grp-Users", "grp-Admins"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.Equal(t, "pol-Admins", res.SelectedPolicyID, + "the bigger group pool wins the fresh-state tiebreak — picking Users first would burn the shared user counter past Admins's tight user cap on the very first request and strand the bigger Admins pool") + assert.Equal(t, "grp-Admins", res.AttributionGroupID) +} + +// TestSelectPolicy_TiebreakByCreatedAt proves the deterministic +// final tiebreak: when two applicable policies have the same +// headroom fraction AND the same group cap (so the larger-pool rule +// can't differentiate either), the older policy wins so attribution +// is stable across replays. +func TestSelectPolicy_TiebreakByCreatedAt(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + older := capPolicy("pol-old", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400) + older.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + newer := capPolicy("pol-new", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400) + newer.CreatedAt = time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC) + + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{newer, older}, nil) + // Both at zero consumption → identical headroom fraction. + expectConsumptionBatch(mockStore, nil) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + GroupIDs: []string{"grp-engineers"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.Equal(t, "pol-old", res.SelectedPolicyID, + "older policy wins on equal-headroom tiebreak so attribution is stable across replays") +} + +// TestSelectPolicy_DeniesWhenAllExhausted proves the deny envelope: +// when every applicable policy has at least one cap fully exhausted, +// the selector returns Allow=false with the most-recent exhaustion's +// deny code + human reason. The proxy's middleware surfaces this as +// a 403 with the canonical llm_policy.* code. +func TestSelectPolicy_DeniesWhenAllExhausted(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + a := capPolicy("pol-a", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400) + b := capPolicy("pol-b", "acc-1", []string{"grp-engineers"}, "prov-1", 200, 86_400) + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{a, b}, nil) + + // Shared group counter at 200: A (cap 100) and B (cap 200) both exhausted. + expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{ + {types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 200}, + }) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + GroupIDs: []string{"grp-engineers"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.False(t, res.Allow, "every applicable policy exhausted = deny") + assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode) + assert.Contains(t, res.DenyReason, "token cap exhausted", + "deny reason must name the exhausted cap kind for operator debugging") +} + +// TestSelectPolicy_UncappedPolicyAlwaysWinsAgainstCapped proves the +// catch-all-allow contract: a policy with NO enabled caps wins +// against any capped policy regardless of how much headroom the +// capped one has, because operators who configure unlimited access +// expect requests to attribute there until they explicitly add caps. +func TestSelectPolicy_UncappedPolicyAlwaysWinsAgainstCapped(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + uncapped := &types.Policy{ + ID: "pol-uncapped", + AccountID: "acc-1", + Enabled: true, + SourceGroups: []string{"grp-engineers"}, + DestinationProviderIDs: []string{"prov-1"}, + // All Limits.*.Enabled = false (zero-value). + CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC), + } + wide := capPolicy("pol-wide", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000_000, 86_400) + wide.CreatedAt = time.Date(2025, 12, 1, 0, 0, 0, 0, time.UTC) // older than uncapped + + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{uncapped, wide}, nil) + // Only the wide policy reads consumption; uncapped doesn't query + // because it has no enabled caps. + expectConsumptionBatch(mockStore, nil) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + GroupIDs: []string{"grp-engineers"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.Equal(t, "pol-uncapped", res.SelectedPolicyID, + "a no-caps policy must always win selection — that's how operators express 'unlimited access through this path'") + assert.Equal(t, int64(0), res.WindowSeconds, "no caps configured = WindowSeconds=0 so RecordLLMUsage skips counter writes") +} + +// TestSelectPolicy_DisabledPolicyIgnored proves disabled policies +// don't count toward selection — even when they'd otherwise be the +// best match. Operators disable a policy to take it offline; the +// selector must respect that and route through whatever's left. +func TestSelectPolicy_DisabledPolicyIgnored(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + disabled := capPolicy("pol-disabled", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000_000, 86_400) + disabled.Enabled = false + enabled := capPolicy("pol-enabled", "acc-1", []string{"grp-engineers"}, "prov-1", 100, 86_400) + + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{disabled, enabled}, nil) + expectConsumptionBatch(mockStore, nil) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + GroupIDs: []string{"grp-engineers"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.Equal(t, "pol-enabled", res.SelectedPolicyID, + "disabled policies must be ignored at selection time") +} + +// TestSelectPolicy_StoreErrorPropagates locks the no-fail-open +// contract: a transient store error must surface to the caller, not +// be silently treated as "no policies = allow". A false allow on the +// hot path would let a request slip past every cap. +func TestSelectPolicy_StoreErrorPropagates(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return(nil, errors.New("boom")) + + _, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + }) + require.Error(t, err, "store errors must surface — never fail open on the hot path") +} + +// TestSelectPolicy_RejectsEmptyAccount is the input-validation guard: +// empty account_id is a programmer error and must surface as +// InvalidArgument, not as a silent zero-result lookup. +func TestSelectPolicy_RejectsEmptyAccount(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, _ := newSelectorMgr(t, ctrl) + + _, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{}) + require.Error(t, err) + var sErr *nbstatus.Error + require.True(t, errors.As(err, &sErr)) + assert.Equal(t, nbstatus.InvalidArgument, sErr.Type()) +} + +// TestSelectPolicy_SharesGroupCounterAcrossPolicies locks the +// counter-keying design fork: counters are keyed on (account, +// dim_kind, dim_id, window_hours, window_start) — NOT on policy_id. +// Two policies that target the same group with the SAME window length +// share one bucket: spend booked under policy A is visible to policy +// B's headroom calculation and counts toward B's cap. +// +// This is what makes "operator's per-group enforcement" sane — caps +// describe how much a GROUP can use, not how much each policy owes. +func TestSelectPolicy_SharesGroupCounterAcrossPolicies(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + // Two policies, both targeting grp-engineers + prov-1, same 24h + // window length. Different cap sizes. + policyA := capPolicy("pol-A", "acc-1", []string{"grp-engineers"}, "prov-1", 1_000, 86_400) + policyB := capPolicy("pol-B", "acc-1", []string{"grp-engineers"}, "prov-1", 5_000, 86_400) + + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{policyA, policyB}, nil) + // Both policies query the SAME consumption row — same dim_id, + // same window_hours, same window_start. The mock returns the + // same row for both calls, simulating the shared counter. + expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{ + {types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 800}, + }) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + GroupIDs: []string{"grp-engineers"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + // 800 used → policy A has 200 tokens left of 1000 (20% headroom); + // policy B has 4200 left of 5000 (84% headroom). B wins. + assert.Equal(t, "pol-B", res.SelectedPolicyID, + "the SAME 800 tokens count toward both policies — counters share the (group, window) key, caps differ per policy") +} + +// TestSelectPolicy_AntiFallThroughOnLowestGroup locks the no-fall- +// through behaviour: when a caller is in multiple of a policy's +// source_groups and the lowest-by-sort group is exhausted, we DENY +// rather than fall through to a less-loaded sibling. Per-group caps +// are independent (each group has its own bucket), but attribution +// is one-shot — operators wanting fall-through must split into +// separate policies. +// +// This nails down semantics future contributors might "improve" into +// fall-through behaviour by accident. +func TestSelectPolicy_AntiFallThroughOnLowestGroup(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + // Policy targets two groups; caller is in both. + policy := capPolicy("pol-1", "acc-1", []string{"grp-aaa", "grp-bbb"}, "prov-1", 100, 86_400) + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{policy}, nil) + + // grp-aaa is the lowest by sort → attribution picks it, and the + // prefetch only collects the attribution group's key. We exhaust + // grp-aaa (100/100); grp-bbb's counter is never requested because the + // selector attributes one-shot to the lowest group, so it can't fall + // through to a less-loaded sibling. + expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{ + {types.DimensionGroup, "grp-aaa", 86_400}: {TokensInput: 100}, + }) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + GroupIDs: []string{"grp-aaa", "grp-bbb"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.False(t, res.Allow, + "lowest-group-by-sort attribution does NOT fall through to a less-loaded sibling — operators wanting fall-through must split into separate policies") + assert.Equal(t, denyCodeTokenCapExceeded, res.DenyCode) + assert.Contains(t, res.DenyReason, "pol-1", + "deny reason names the exhausted policy id so operators can grep it from the access log") +} + +// TestSelectPolicy_BudgetOnlyExhaustionDenies covers the symmetric +// path to TestSelectPolicy_DeniesWhenAllExhausted but for the budget +// cap: a policy with token_limit DISABLED and budget_limit at-cap +// must deny with llm_policy.budget_cap_exceeded (not the token code). +// +// Without this, the budget evaluation path in evalBudgetCap could +// silently regress and we'd still pass DeniesWhenAllExhausted (which +// only exercises tokens). +func TestSelectPolicy_BudgetOnlyExhaustionDenies(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + policy := &types.Policy{ + ID: "pol-budget", + AccountID: "acc-1", + Enabled: true, + SourceGroups: []string{"grp-engineers"}, + DestinationProviderIDs: []string{"prov-1"}, + Limits: types.PolicyLimits{ + TokenLimit: types.PolicyTokenLimit{Enabled: false}, + BudgetLimit: types.PolicyBudgetLimit{ + Enabled: true, + GroupCapUsd: 10.00, + WindowSeconds: 86_400, + }, + }, + CreatedAt: time.Now().UTC(), + } + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{policy}, nil) + expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{ + {types.DimensionGroup, "grp-engineers", 86_400}: {CostUSD: 10.50}, // over the $10 cap + }) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + GroupIDs: []string{"grp-engineers"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.False(t, res.Allow, "budget cap exhausted must deny independently of any token cap state") + assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode, + "deny code must be the budget code — token-only deny would silently regress the budget evaluation path") + assert.Contains(t, res.DenyReason, "budget", "deny reason names the budget cap kind for operator debugging") +} + +// TestSelectPolicy_BudgetTighterThanTokenWins is the dual-cap headroom +// fork: when both Token and Budget are enabled on the same policy, +// the SMALLER remaining ratio gates the policy. A policy with +// abundant token headroom but near-zero budget headroom must deny on +// budget, not pass on tokens. +func TestSelectPolicy_BudgetTighterThanTokenWins(t *testing.T) { + ctrl := gomock.NewController(t) + mgr, mockStore := newSelectorMgr(t, ctrl) + + policy := &types.Policy{ + ID: "pol-dual", + AccountID: "acc-1", + Enabled: true, + SourceGroups: []string{"grp-engineers"}, + DestinationProviderIDs: []string{"prov-1"}, + Limits: types.PolicyLimits{ + TokenLimit: types.PolicyTokenLimit{Enabled: true, GroupCap: 10_000_000, WindowSeconds: 86_400}, + BudgetLimit: types.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 1.00, WindowSeconds: 86_400}, + }, + CreatedAt: time.Now().UTC(), + } + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(gomock.Any(), gomock.Any(), "acc-1"). + Return([]*types.Policy{policy}, nil) + // One shared counter carries both token usage (ample headroom) and cost + // (at the $1 budget cap); the tighter budget cap gates the policy. + expectConsumptionBatch(mockStore, map[usedKey]*types.Consumption{ + {types.DimensionGroup, "grp-engineers", 86_400}: {TokensInput: 100, CostUSD: 1.00}, + }) + + res, err := mgr.SelectPolicyForRequest(context.Background(), PolicySelectionInput{ + AccountID: "acc-1", + GroupIDs: []string{"grp-engineers"}, + ProviderID: "prov-1", + }) + require.NoError(t, err) + assert.False(t, res.Allow, + "the tighter of (token, budget) wins — abundant token headroom must NOT mask an exhausted budget") + assert.Equal(t, denyCodeBudgetCapExceeded, res.DenyCode) +} diff --git a/management/internals/modules/agentnetwork/reconcile.go b/management/internals/modules/agentnetwork/reconcile.go new file mode 100644 index 000000000..319553ebc --- /dev/null +++ b/management/internals/modules/agentnetwork/reconcile.go @@ -0,0 +1,131 @@ +package agentnetwork + +import ( + "context" + + log "github.com/sirupsen/logrus" + + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// reconcile recomputes the synthesised reverse-proxy services for an +// account, diffs them against the previously-synthesised set in the +// in-memory cache, and emits Create / Update / Delete proxy mappings +// to the affected clusters. Also triggers a peer-side network-map +// recompute via accountManager.UpdateAccountPeers so the +// private-service ACL injection picks up the new state immediately. +// +// Reconcile failures are logged and swallowed — the underlying CRUD +// has already completed, and the next mutation (or proxy reconnect) +// will re-converge the cluster's view. +func (m *managerImpl) reconcile(ctx context.Context, accountID string) { + if accountID == "" { + return + } + + defer func() { + if m.accountManager != nil { + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{ + Resource: types.UpdateResourceService, + Operation: types.UpdateOperationUpdate, + }) + } + }() + + if m.proxyController == nil { + return + } + + services, err := SynthesizeServices(ctx, m.store, accountID) + if err != nil { + log.WithContext(ctx).WithError(err).Warnf("agent-network reconcile: synthesise services for account %s", accountID) + return + } + + oidcCfg := m.proxyController.GetOIDCValidationConfig() + current := make(map[string]*proto.ProxyMapping, len(services)) + for _, svc := range services { + if svc == nil || svc.ID == "" { + continue + } + current[svc.ID] = svc.ToProtoMapping(rpservice.Update, "", oidcCfg) + } + + m.reconcileMu.Lock() + previous := m.reconcileCache[accountID] + if previous == nil { + previous = make(map[string]*proto.ProxyMapping) + } + + creates, updates, deletes := diffMappings(previous, current) + if len(current) == 0 { + delete(m.reconcileCache, accountID) + } else { + m.reconcileCache[accountID] = current + } + m.reconcileMu.Unlock() + + for _, mapping := range creates { + mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping)) + } + for _, mapping := range updates { + mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping)) + } + for _, mapping := range deletes { + mapping.Type = proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, mapping, clusterFromMapping(mapping)) + } +} + +// diffMappings classifies the previous→current transition for a +// single account into Create / Update / Delete sets. +// +// Cluster moves (current.cluster != previous.cluster) are surfaced as +// a Delete on the old cluster + Create on the new — handled by +// emitting both a delete (on previous mapping) and a create (on the +// current mapping) for that service ID. +func diffMappings(previous, current map[string]*proto.ProxyMapping) (creates, updates, deletes []*proto.ProxyMapping) { + for id, cur := range current { + prev, existed := previous[id] + switch { + case !existed: + creates = append(creates, cur) + case prev.GetDomain() == "" || cur.GetAccountId() == prev.GetAccountId() && currentClusterChanged(prev, cur): + deletes = append(deletes, prev) + creates = append(creates, cur) + default: + updates = append(updates, cur) + } + } + for id, prev := range previous { + if _, stillThere := current[id]; !stillThere { + deletes = append(deletes, prev) + } + } + return creates, updates, deletes +} + +func currentClusterChanged(prev, cur *proto.ProxyMapping) bool { + return clusterFromMapping(prev) != clusterFromMapping(cur) +} + +// clusterFromMapping returns the cluster the mapping should be sent +// to. ProxyMapping doesn't carry the cluster directly, so we rely on +// the synthesised service's domain (`.`) and split on +// the first '.'. +func clusterFromMapping(m *proto.ProxyMapping) string { + if m == nil { + return "" + } + domain := m.GetDomain() + for i := 0; i < len(domain); i++ { + if domain[i] == '.' { + return domain[i+1:] + } + } + return "" +} diff --git a/management/internals/modules/agentnetwork/reconcile_test.go b/management/internals/modules/agentnetwork/reconcile_test.go new file mode 100644 index 000000000..0855f0dc1 --- /dev/null +++ b/management/internals/modules/agentnetwork/reconcile_test.go @@ -0,0 +1,232 @@ +package agentnetwork + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/proto" +) + +func newReconcileMgr(t *testing.T, ctrl *gomock.Controller) (*managerImpl, *store.MockStore, *proxy.MockController) { + t.Helper() + mockStore := store.NewMockStore(ctrl) + mockProxy := proxy.NewMockController(ctrl) + return &managerImpl{ + store: mockStore, + proxyController: mockProxy, + reconcileCache: make(map[string]map[string]*proto.ProxyMapping), + }, mockStore, mockProxy +} + +func newReconcileTestProvider() *types.Provider { + return &types.Provider{ + ID: "prov-1", + AccountID: "acct-1", + ProviderID: "openai_api", + Name: "OpenAI", + UpstreamURL: "https://api.openai.com", + APIKey: "sk-test-key", + Enabled: true, + SessionPrivateKey: "test-priv-key", + SessionPublicKey: "test-pub-key", + } +} + +func newReconcileTestPolicy(providerID, sourceGroupID string) *types.Policy { + return &types.Policy{ + ID: "pol-1", + AccountID: "acct-1", + Name: "engineers", + Enabled: true, + SourceGroups: []string{sourceGroupID}, + DestinationProviderIDs: []string{providerID}, + } +} + +func newReconcileTestSettings() *types.Settings { + return &types.Settings{ + AccountID: "acct-1", + Cluster: "eu.proxy.netbird.io", + Subdomain: "violet", + } +} + +func expectReconcileSynthInputs(mockStore *store.MockStore, ctx context.Context, providers []*types.Provider, policies []*types.Policy, guardrails []*types.Guardrail) { + mockStore.EXPECT(). + GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1"). + Return(newReconcileTestSettings(), nil) + mockStore.EXPECT(). + GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1"). + Return(providers, nil) + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1"). + Return(policies, nil) + mockStore.EXPECT(). + GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1"). + Return(guardrails, nil) +} + +func TestReconcile_FirstSynth_EmitsCreate(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl) + provider := newReconcileTestProvider() + policy := newReconcileTestPolicy(provider.ID, "grp-eng") + + expectReconcileSynthInputs(mockStore, ctx, []*types.Provider{provider}, []*types.Policy{policy}, []*types.Guardrail{}) + mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}) + + var sentMappings []*proto.ProxyMapping + mockProxy.EXPECT(). + SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), "eu.proxy.netbird.io"). + Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) { + sentMappings = append(sentMappings, m) + }) + + mgr.reconcile(ctx, "acct-1") + + require.Len(t, sentMappings, 1, "first synth must emit one mapping") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, sentMappings[0].Type, "first synth is a Create") + assert.Equal(t, "agent-net-svc-acct-1", sentMappings[0].Id, "stable account-scoped virtual service id") + assert.Equal(t, "violet.eu.proxy.netbird.io", sentMappings[0].Domain, "domain comes from settings (subdomain.cluster)") + + mgr.reconcileMu.Lock() + cached := mgr.reconcileCache["acct-1"] + mgr.reconcileMu.Unlock() + require.Len(t, cached, 1, "cache must hold the synth result for next diff") +} + +func TestReconcile_NoChange_EmitsNothingExtra(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl) + provider := newReconcileTestProvider() + policy := newReconcileTestPolicy(provider.ID, "grp-eng") + + // Two identical synth runs. + mockStore.EXPECT(). + GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1"). + Return(newReconcileTestSettings(), nil).Times(2) + mockStore.EXPECT(). + GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1"). + Return([]*types.Provider{provider}, nil).Times(2) + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1"). + Return([]*types.Policy{policy}, nil).Times(2) + mockStore.EXPECT(). + GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1"). + Return([]*types.Guardrail{}, nil).Times(2) + mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).Times(2) + + createCalls := 0 + updateCalls := 0 + mockProxy.EXPECT(). + SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), gomock.Any()). + Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) { + switch m.Type { + case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED: + createCalls++ + case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED: + updateCalls++ + } + }). + AnyTimes() + + mgr.reconcile(ctx, "acct-1") + mgr.reconcile(ctx, "acct-1") + + assert.Equal(t, 1, createCalls, "first reconcile creates") + assert.Equal(t, 1, updateCalls, "second reconcile re-pushes as Modified (no semantic change but mapping fields refresh)") +} + +func TestReconcile_PolicyRemoved_EmitsDelete(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mgr, mockStore, mockProxy := newReconcileMgr(t, ctrl) + provider := newReconcileTestProvider() + policy := newReconcileTestPolicy(provider.ID, "grp-eng") + + gomock.InOrder( + // First reconcile: provider + policy, synthesised. + mockStore.EXPECT().GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").Return(newReconcileTestSettings(), nil), + mockStore.EXPECT().GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Provider{provider}, nil), + mockStore.EXPECT().GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Policy{policy}, nil), + mockStore.EXPECT().GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Guardrail{}, nil), + // Second reconcile: policy gone, provider stays but no longer referenced. + mockStore.EXPECT().GetAgentNetworkSettings(ctx, store.LockingStrengthNone, "acct-1").Return(newReconcileTestSettings(), nil), + mockStore.EXPECT().GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Provider{provider}, nil), + mockStore.EXPECT().GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, "acct-1").Return([]*types.Policy{}, nil), + ) + mockProxy.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes() + + var seenTypes []proto.ProxyMappingUpdateType + mockProxy.EXPECT(). + SendServiceUpdateToCluster(ctx, "acct-1", gomock.Any(), "eu.proxy.netbird.io"). + Do(func(_ context.Context, _ string, m *proto.ProxyMapping, _ string) { + seenTypes = append(seenTypes, m.Type) + }). + AnyTimes() + + mgr.reconcile(ctx, "acct-1") + mgr.reconcile(ctx, "acct-1") + + require.Len(t, seenTypes, 2, "create then delete") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, seenTypes[0]) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, seenTypes[1]) + + mgr.reconcileMu.Lock() + _, present := mgr.reconcileCache["acct-1"] + mgr.reconcileMu.Unlock() + assert.False(t, present, "cache for the account must be cleared once nothing is synthesised") +} + +func TestReconcile_NilProxyController_NoOp(t *testing.T) { + ctx := context.Background() + mgr := &managerImpl{ + reconcileCache: make(map[string]map[string]*proto.ProxyMapping), + } + // Must not panic; must not query the store. + mgr.reconcile(ctx, "acct-1") +} + +func TestReconcile_EmptyAccountID_NoOp(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mgr, _, _ := newReconcileMgr(t, ctrl) + // Empty accountID short-circuits before any store call. + mgr.reconcile(ctx, "") +} + +func TestClusterFromMapping(t *testing.T) { + tests := []struct { + name string + domain string + want string + }{ + {"simple", "openai.eu.proxy.netbird.io", "eu.proxy.netbird.io"}, + {"deeply nested", "a.b.c.d", "b.c.d"}, + {"no dot", "openai", ""}, + {"empty", "", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := clusterFromMapping(&proto.ProxyMapping{Domain: tt.domain}) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/management/internals/modules/agentnetwork/synthesizer.go b/management/internals/modules/agentnetwork/synthesizer.go new file mode 100644 index 000000000..9814d1a11 --- /dev/null +++ b/management/internals/modules/agentnetwork/synthesizer.go @@ -0,0 +1,1083 @@ +package agentnetwork + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/url" + "sort" + "strings" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/catalog" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" +) + +// apiKeyPlaceholder is the literal substituted with the provider's +// decrypted API key in catalog AuthHeaderTemplate strings. +const apiKeyPlaceholder = "${API_KEY}" //nolint:gosec // template marker, not a credential + +// gcpKeyfilePrefix marks an api_key that holds a base64-encoded GCP +// service-account JSON key ("keyfile::") rather than a static bearer +// token; the proxy mints OAuth tokens from it. Mirrors Aperture's convention. +const gcpKeyfilePrefix = "keyfile::" + +// SynthesizedServiceIDPrefix prefixes the in-memory ID of every +// reverse-proxy service synthesised from Agent Network state. One +// synthesised service exists per (account, cluster); the suffix is the +// account ID so the proxy can dedup mappings cleanly. +const SynthesizedServiceIDPrefix = "agent-net-svc-" + +// agentNetworkRequestCaptureBytes is the request-side body capture cap. +// Kept modest: oversized requests (a long conversation's context can be +// many MB) have their routing fields recovered by the proxy's tolerant +// scan rather than buffered here, so there's no need to size this to the +// largest possible request. +const agentNetworkRequestCaptureBytes = 1 << 20 + +// agentNetworkResponseCaptureBytes is the response-side body capture cap. +// Token usage lives in the trailing SSE message_delta event, so the +// captured prefix must reach the end of the stream. Unlike a request's +// unbounded context, a single response is hard-capped by the model's max +// output tokens (~128K on Opus → a few hundred KB of gzipped SSE even +// with thinking), so 8 MiB is comfortably above any real response and is +// effectively unlimited here — not a moving ceiling. The proxy clamps to +// its own MaxBodyCapBytes at apply time. +const agentNetworkResponseCaptureBytes = 8 << 20 + +// agentNetworkCaptureContentTypes is the set of content types whose +// bodies the proxy buffers for the LLM middlewares. JSON covers +// buffered request and response bodies; SSE covers streaming +// responses (the response parser sums delta tokens across chunks). +var agentNetworkCaptureContentTypes = []string{ + "application/json", + "text/event-stream", +} + +// Middleware IDs the synthesised target chain registers, mirroring the +// proxy-side built-in registry. Order matters: on_request runs in the +// order they're listed; on_response runs in reverse, so cost_meter must +// come BEFORE llm_response_parser in the slice so the parser populates +// tokens before the cost meter reads them. +const ( + middlewareIDLLMRequestParser = "llm_request_parser" + middlewareIDLLMRouter = "llm_router" + middlewareIDLLMIdentityInject = "llm_identity_inject" + middlewareIDLLMLimitCheck = "llm_limit_check" + middlewareIDLLMGuardrail = "llm_guardrail" + middlewareIDCostMeter = "cost_meter" + middlewareIDLLMResponseParser = "llm_response_parser" + middlewareIDLLMLimitRecord = "llm_limit_record" +) + +// SynthesizeServicesForCluster walks every account's agent-network +// settings row pinned to clusterAddr and synthesises the per-account +// gateway service. Used by the proxy-mapping snapshot path where the +// connecting proxy has a specific cluster address and cares about every +// account that routes through it. +// +// Returns nil (no error) when no settings row references the cluster. +// Per-account synthesis failures are skipped rather than dropping every +// account on the cluster. +func SynthesizeServicesForCluster(ctx context.Context, s store.Store, clusterAddr string) ([]*rpservice.Service, error) { + clusterAddr = strings.TrimSpace(clusterAddr) + if clusterAddr == "" { + return nil, nil + } + + settingsRows, err := s.GetAgentNetworkSettingsByCluster(ctx, store.LockingStrengthNone, clusterAddr) + if err != nil { + return nil, fmt.Errorf("list agent network settings on cluster: %w", err) + } + if len(settingsRows) == 0 { + return nil, nil + } + + var out []*rpservice.Service + for _, settings := range settingsRows { + if settings == nil { + continue + } + services, serr := SynthesizeServices(ctx, s, settings.AccountID) + if serr != nil { + continue + } + for _, svc := range services { + if svc != nil && svc.ProxyCluster == clusterAddr { + out = append(out, svc) + } + } + } + return out, nil +} + +// SynthesizeServiceForDomain resolves a single agent-network service by its +// public endpoint domain. It lists the (few) settings rows on the domain's +// cluster, matches the one whose endpoint equals the domain, and synthesises +// only that account — avoiding full per-account synthesis for every tenant on +// the cluster, which is what auth/session paths previously paid. Returns nil +// (no error) when no account owns the domain. +func SynthesizeServiceForDomain(ctx context.Context, s store.Store, domain string) (*rpservice.Service, error) { + domain = strings.TrimSpace(domain) + cluster := clusterFromDomain(domain) + if domain != "" && cluster != "" { + settingsRows, err := s.GetAgentNetworkSettingsByCluster(ctx, store.LockingStrengthNone, cluster) + if err != nil { + return nil, fmt.Errorf("list agent network settings on cluster: %w", err) + } + for _, settings := range settingsRows { + if settings == nil || settings.Endpoint() != domain { + continue + } + services, serr := SynthesizeServices(ctx, s, settings.AccountID) + if serr != nil { + return nil, serr + } + for _, svc := range services { + if svc != nil && svc.Domain == domain { + return svc, nil + } + } + break + } + } + return nil, nil //nolint:nilnil // optional lookup: no account owns the domain +} + +// clusterFromDomain returns the cluster portion of an endpoint domain (every +// label after the first). +func clusterFromDomain(domain string) string { + if i := strings.IndexByte(domain, '.'); i >= 0 { + return domain[i+1:] + } + return "" +} + +// SynthesizeServices builds the in-memory reverse-proxy service that +// fronts the account's agent-network gateway. Returns nil when the +// account has no settings row, no enabled providers, or no enabled +// policies — in any of those cases there's nothing useful to expose. +// +// One service per (account, settings.Cluster) is emitted. The router +// middleware encodes a denormalised model→provider routing table +// (auth headers + decrypted API keys baked in); the policy_check +// middleware encodes per-provider authorised group IDs derived from +// the account's enabled policies. +// +// Services are NEVER persisted — callers regenerate them on every +// network-map / proxy-mapping cycle from current state. +func SynthesizeServices(ctx context.Context, s store.Store, accountID string) ([]*rpservice.Service, error) { + settings, ok, err := loadSettings(ctx, s, accountID) + if err != nil { + return nil, err + } + if !ok || strings.TrimSpace(settings.Cluster) == "" { + return nil, nil + } + + providers, err := s.GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, fmt.Errorf("list agent network providers: %w", err) + } + enabledProviders := filterEnabledProviders(providers) + if len(enabledProviders) == 0 { + return nil, nil + } + + policies, err := s.GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, fmt.Errorf("list agent network policies: %w", err) + } + enabledPolicies := filterEnabledPolicies(policies) + if len(enabledPolicies) == 0 { + return nil, nil + } + + // Backfill any missing session keypairs before deriving a + // service-level keypair. Old rows pre-date the column; treating + // the gap as a no-op produces an immediate dial failure, so we + // fix it once here and persist for future cycles. + for _, p := range enabledProviders { + if p.SessionPrivateKey != "" && p.SessionPublicKey != "" { + continue + } + if err := backfillProviderSessionKeys(ctx, s, p); err != nil { + return nil, fmt.Errorf("backfill session keys for provider %s: %w", p.ID, err) + } + } + + guardrails, err := s.GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, fmt.Errorf("list agent network guardrails: %w", err) + } + guardrailsByID := make(map[string]*types.Guardrail, len(guardrails)) + for _, g := range guardrails { + if g != nil { + guardrailsByID[g.ID] = g + } + } + + groupIndex := indexProviderGroups(enabledPolicies) + + routerCfgJSON, err := buildRouterConfigJSON(enabledProviders, groupIndex) + if err != nil { + return nil, err + } + + identityInjectJSON, err := buildIdentityInjectConfigJSON(enabledProviders, groupIndex) + if err != nil { + return nil, err + } + + mergedGuardrails := mergeGuardrails(enabledPolicies, guardrailsByID) + applyAccountCollectionControls(&mergedGuardrails, settings) + guardrailJSON, err := marshalGuardrailConfig(mergedGuardrails) + if err != nil { + return nil, err + } + + // Use the merged decision (account settings OR policy-required redaction), + // not the raw account flag, so a policy that mandates PII redaction is + // honored by the capture parsers even when the account toggle is off. + middlewares := buildMiddlewareChain(routerCfgJSON, identityInjectJSON, guardrailJSON, mergedGuardrails.PromptCapture.RedactPii, mergedGuardrails.PromptCapture.Enabled) + + priv, pub, err := pickServiceSessionKeys(enabledProviders) + if err != nil { + return nil, err + } + + svc := buildAccountService(accountID, settings, enabledPolicies, middlewares, priv, pub) + return []*rpservice.Service{svc}, nil +} + +// loadSettings returns the account's agent-network settings row. The +// boolean reports whether a row exists; a status.NotFound surfaces as +// (nil, false, nil) so callers can treat "no settings" as "no +// synthesis" without inspecting error types themselves. +func loadSettings(ctx context.Context, s store.Store, accountID string) (*types.Settings, bool, error) { + settings, err := s.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID) + if err == nil { + return settings, true, nil + } + var sErr *status.Error + if errors.As(err, &sErr) && sErr.Type() == status.NotFound { + return nil, false, nil + } + return nil, false, fmt.Errorf("get agent network settings: %w", err) +} + +// filterEnabledProviders returns the subset of enabled providers, sorted +// by created_at ascending so the router config is deterministic and +// first-match-wins is stable across synthesis cycles. +func filterEnabledProviders(providers []*types.Provider) []*types.Provider { + out := make([]*types.Provider, 0, len(providers)) + for _, p := range providers { + if p == nil || !p.Enabled { + continue + } + out = append(out, p) + } + sort.SliceStable(out, func(i, j int) bool { + if !out[i].CreatedAt.Equal(out[j].CreatedAt) { + return out[i].CreatedAt.Before(out[j].CreatedAt) + } + return out[i].ID < out[j].ID + }) + return out +} + +// filterEnabledPolicies returns the subset of enabled policies. +func filterEnabledPolicies(policies []*types.Policy) []*types.Policy { + out := make([]*types.Policy, 0, len(policies)) + for _, p := range policies { + if p == nil || !p.Enabled { + continue + } + out = append(out, p) + } + return out +} + +// backfillProviderSessionKeys mints an ed25519 session keypair on a +// provider row that doesn't have one yet (rows created before the +// keys were persistent fields) and persists it via the store so +// subsequent cycles get stable keys. +func backfillProviderSessionKeys(ctx context.Context, s store.Store, p *types.Provider) error { + pair, err := sessionkey.GenerateKeyPair() + if err != nil { + return fmt.Errorf("generate session keys for provider %s: %w", p.ID, err) + } + p.SessionPrivateKey = pair.PrivateKey + p.SessionPublicKey = pair.PublicKey + if err := s.SaveAgentNetworkProvider(ctx, p); err != nil { + return fmt.Errorf("persist backfilled session keys for provider %s: %w", p.ID, err) + } + return nil +} + +// pickServiceSessionKeys returns the keypair the synthesised gateway +// service signs / verifies session JWTs with. The PoC reuses the first +// enabled provider's keypair so existing session cookies survive +// provider edits as long as the first-by-created_at provider stays in +// place. Returns an error when no provider has a usable keypair after +// backfill — that surfaces a misconfigured account loudly instead of +// emitting a service the proxy will reject as "invalid session public +// key size". +func pickServiceSessionKeys(providers []*types.Provider) (priv, pub string, err error) { + for _, p := range providers { + if p.SessionPrivateKey != "" && p.SessionPublicKey != "" { + return p.SessionPrivateKey, p.SessionPublicKey, nil + } + } + return "", "", fmt.Errorf("no provider with session keypair; update one provider to backfill") +} + +// routerConfig mirrors the on-wire shape llm_router accepts. Kept +// private so the synthesiser owns the contract; the proxy-side factory +// JSON-decodes the same shape. +type routerConfig struct { + Providers []routerProviderRoute `json:"providers"` +} + +type routerProviderRoute struct { + ID string `json:"id"` + Vendor string `json:"vendor,omitempty"` + Models []string `json:"models"` + UpstreamScheme string `json:"upstream_scheme"` + UpstreamHost string `json:"upstream_host"` + UpstreamPath string `json:"upstream_path,omitempty"` + AuthHeaderName string `json:"auth_header_name"` + AuthHeaderValue string `json:"auth_header_value"` + AllowedGroupIDs []string `json:"allowed_group_ids,omitempty"` + // Vertex marks a Google Vertex AI provider, whose requests carry the + // model in the URL path. The router selects it by path, bypassing the + // model/vendor table. + Vertex bool `json:"vertex,omitempty"` + // Bedrock marks an AWS Bedrock provider, whose requests carry the model in + // the URL path (/model/{id}/{action}). The router selects it by path, + // bypassing the model/vendor table; auth is a static bearer token. + Bedrock bool `json:"bedrock,omitempty"` + // GCPServiceAccountKeyB64 carries a base64-encoded GCP service-account + // JSON key (from a "keyfile::" api_key). When set, the proxy mints + // + refreshes the OAuth token at request time instead of injecting a static + // AuthHeaderValue. + GCPServiceAccountKeyB64 string `json:"gcp_sa_key_b64,omitempty"` +} + +// indexProviderGroups walks the enabled policies and returns, per +// provider id, the sorted union of source group ids across every +// policy that authorises the provider. Providers with no authorising +// policy are absent from the map. The router consumes this to filter +// candidate routes by the caller's group memberships before the +// path-prefix tiebreak runs. +func indexProviderGroups(policies []*types.Policy) map[string][]string { + sets := make(map[string]map[string]struct{}) + for _, policy := range policies { + if policy == nil { + continue + } + for _, providerID := range policy.DestinationProviderIDs { + if providerID == "" { + continue + } + set, ok := sets[providerID] + if !ok { + set = make(map[string]struct{}) + sets[providerID] = set + } + for _, group := range policy.SourceGroups { + if group != "" { + set[group] = struct{}{} + } + } + } + } + out := make(map[string][]string, len(sets)) + for providerID, set := range sets { + groups := make([]string, 0, len(set)) + for g := range set { + groups = append(groups, g) + } + sort.Strings(groups) + out[providerID] = groups + } + return out +} + +// buildRouterConfigJSON denormalises the account's enabled providers +// into the router middleware's first-match-wins routing table. +// Providers are listed in created_at order so the table is +// deterministic and stable across synth cycles. +// +// AllowedGroupIDs is the union of source group ids across every enabled +// policy that authorises the provider. The router uses it as a hard +// filter — a route whose AllowedGroupIDs has no intersection with the +// caller's user groups is removed from the candidate list before the +// path-prefix tiebreak. Providers no enabled policy authorises +// (orphans) are intentionally OMITTED so the router never observes a +// route with an empty ACL. +func buildRouterConfigJSON(providers []*types.Provider, groupIndex map[string][]string) ([]byte, error) { + cfg := routerConfig{Providers: make([]routerProviderRoute, 0, len(providers))} + for _, p := range providers { + groups, hasPolicy := groupIndex[p.ID] + if !hasPolicy { + // Orphan: skip. No enabled policy authorises this + // provider, so it must not be reachable. + continue + } + scheme, host, path, err := parseUpstreamHost(p.UpstreamURL) + if err != nil { + return nil, fmt.Errorf("router config for provider %s: %w", p.ID, err) + } + headerName, headerValue, gcpSAKeyB64, err := providerAuthHeader(p) + if err != nil { + return nil, err + } + cfg.Providers = append(cfg.Providers, routerProviderRoute{ + ID: p.ID, + Vendor: providerVendor(p), + Models: providerModelIDs(p), + UpstreamScheme: scheme, + UpstreamHost: host, + UpstreamPath: path, + AuthHeaderName: headerName, + AuthHeaderValue: headerValue, + AllowedGroupIDs: groups, + Vertex: catalog.IsVertexPathStyle(p.ProviderID), + Bedrock: catalog.IsBedrockPathStyle(p.ProviderID), + GCPServiceAccountKeyB64: gcpSAKeyB64, + }) + } + out, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal llm_router middleware config: %w", err) + } + return out, nil +} + +// providerVendor returns the parser surface ("openai", "anthropic", …) +// the provider speaks, sourced from its catalog entry's ParserID. The +// router uses it to keep a request the parser tagged with a vendor on a +// route of the same vendor — so e.g. an Anthropic /v1/messages call is +// never sent to an OpenAI-compatible gateway that also claims the model. +// Empty when the catalog entry is unknown or declares no parser surface; +// the router then falls back to model / path routing. +func providerVendor(p *types.Provider) string { + entry, ok := catalog.Lookup(p.ProviderID) + if !ok { + return "" + } + return entry.ParserID +} + +// providerModelIDs returns the model identifiers exposed by the +// provider, deduplicated and in the operator's declared order. Empty +// slice when no models are configured — the router treats that as +// "claim every model" so gateway-style providers (LiteLLM, custom +// OpenAI-compatible endpoints) work without the operator enumerating +// the upstream's full model catalog in NetBird. +func providerModelIDs(p *types.Provider) []string { + if len(p.Models) == 0 { + return []string{} + } + seen := make(map[string]struct{}, len(p.Models)) + out := make([]string, 0, len(p.Models)) + for _, m := range p.Models { + if m.ID == "" { + continue + } + if _, dup := seen[m.ID]; dup { + continue + } + seen[m.ID] = struct{}{} + out = append(out, m.ID) + } + return out +} + +// identityInjectConfig mirrors the on-wire shape llm_identity_inject +// accepts. +type identityInjectConfig struct { + Providers []identityInjectProvider `json:"providers"` +} + +// identityInjectProvider carries one provider's injection rule. +// Identity-stamping uses one of HeaderPair / JSONMetadata (mutually +// exclusive). ExtraHeaders is independent — a list of extra +// per-provider routing/config headers (catalog-declared, value lives +// on the provider record) the middleware stamps with anti-spoof +// (Remove + Add) on every matching request. +type identityInjectProvider struct { + ProviderID string `json:"provider_id"` + HeaderPair *identityInjectHeaderPair `json:"header_pair,omitempty"` + JSONMetadata *identityInjectJSONMetadata `json:"json_metadata,omitempty"` + ExtraHeaders []identityInjectExtraHeader `json:"extra_headers,omitempty"` +} + +type identityInjectExtraHeader struct { + Name string `json:"name"` + Value string `json:"value"` +} + +type identityInjectHeaderPair struct { + EndUserIDHeader string `json:"end_user_id_header,omitempty"` + TagsHeader string `json:"tags_header,omitempty"` + TagsInBody bool `json:"tags_in_body,omitempty"` + EndUserIDInBody bool `json:"end_user_id_in_body,omitempty"` +} + +type identityInjectJSONMetadata struct { + Header string `json:"header"` + UserKey string `json:"user_key,omitempty"` + GroupsKey string `json:"groups_key,omitempty"` + MaxValueLength int `json:"max_value_length,omitempty"` +} + +// buildIdentityInjectConfigJSON walks the enabled providers and emits +// one entry per provider whose catalog entry declares an +// IdentityInjection block. The middleware no-ops for any provider not +// in this list, so the chain is safe to ship to all targets even when +// no identity-stamping provider is configured. +// +// The caller passes groupIndex so we can mirror the synthesiser's own +// "drop orphans" rule — providers no enabled policy authorises don't +// reach the router, so injecting identity for them would never fire. +// We could leave them in for symmetry, but skipping is cheaper and +// clearer. +func buildIdentityInjectConfigJSON(providers []*types.Provider, groupIndex map[string][]string) ([]byte, error) { + cfg := identityInjectConfig{Providers: make([]identityInjectProvider, 0)} + for _, p := range providers { + if _, hasPolicy := groupIndex[p.ID]; !hasPolicy { + continue + } + entry, ok := catalog.Lookup(p.ProviderID) + if !ok { + continue + } + rule, ok := buildIdentityInjectRule(p, entry) + if !ok { + continue + } + cfg.Providers = append(cfg.Providers, rule) + } + out, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal llm_identity_inject middleware config: %w", err) + } + return out, nil +} + +// buildIdentityInjectRule assembles the injection rule for one provider +// from its record and catalog entry. The second return is false when the +// provider would emit nothing, so the caller can skip it entirely rather +// than carry an inert rule for it. +func buildIdentityInjectRule(p *types.Provider, entry catalog.Provider) (identityInjectProvider, bool) { + rule := identityInjectProvider{ProviderID: p.ID} + // Identity-stamping shape (one of HeaderPair / JSONMetadata). Skip the + // shape silently when the catalog entry doesn't declare one — extras + // can still apply, see below. + if entry.IdentityInjection != nil { + switch { + case entry.IdentityInjection.HeaderPair != nil: + rule.HeaderPair = buildIdentityHeaderPair(p, entry.IdentityInjection.HeaderPair) + case entry.IdentityInjection.JSONMetadata != nil: + rule.JSONMetadata = buildIdentityJSONMetadata(p, entry.IdentityInjection.JSONMetadata) + } + } + rule.ExtraHeaders = buildIdentityExtraHeaders(p, entry.ExtraHeaders) + if rule.HeaderPair == nil && rule.JSONMetadata == nil && len(rule.ExtraHeaders) == 0 { + return identityInjectProvider{}, false + } + return rule, true +} + +// buildIdentityHeaderPair resolves the header-pair injection shape, +// returning nil when nothing would be stamped. For Customizable shapes +// (Bifrost today) the wire header names come from the provider record +// verbatim; the catalog values are placeholder defaults shown by the +// dashboard, not authoritative. Empty operator value disables stamping +// for that dimension — applyHeaderPair already no-ops on empty header +// names. The body-inject flags stay catalog-owned because Customizable +// today only applies to gateways that read identity from headers (the +// flags would be no-ops anyway). +func buildIdentityHeaderPair(p *types.Provider, hp *catalog.HeaderPairInjection) *identityInjectHeaderPair { + userHeader := hp.EndUserIDHeader + tagsHeader := hp.TagsHeader + if hp.Customizable { + userHeader = p.IdentityHeaderUserID + tagsHeader = p.IdentityHeaderGroups + } + if userHeader == "" && tagsHeader == "" && !hp.TagsInBody && !hp.EndUserIDInBody { + return nil + } + return &identityInjectHeaderPair{ + EndUserIDHeader: userHeader, + TagsHeader: tagsHeader, + TagsInBody: hp.TagsInBody, + EndUserIDInBody: hp.EndUserIDInBody, + } +} + +// buildIdentityJSONMetadata resolves the JSON-metadata injection shape, +// returning nil when the catalog entry carries no header. Customizable +// JSONMetadata reuses the same provider-record fields HeaderPair uses — +// IdentityHeaderUserID becomes the JSON key for the user dimension, and +// IdentityHeaderGroups becomes the JSON key for groups. Empty operator +// value is honored as "skip this key"; applyJSONMetadata already drops +// keys with empty names. Header itself is catalog-owned (e.g. +// cf-aig-metadata) — operators only override the keys inside the JSON, +// not the wire header that carries it. +func buildIdentityJSONMetadata(p *types.Provider, jm *catalog.JSONMetadataInjection) *identityInjectJSONMetadata { + if jm.Header == "" { + return nil + } + userKey := jm.UserKey + groupsKey := jm.GroupsKey + if jm.Customizable { + userKey = p.IdentityHeaderUserID + groupsKey = p.IdentityHeaderGroups + } + return &identityInjectJSONMetadata{ + Header: jm.Header, + UserKey: userKey, + GroupsKey: groupsKey, + MaxValueLength: jm.MaxValueLength, + } +} + +// buildIdentityExtraHeaders collects catalog-declared static headers (e.g. +// Portkey config id), emitting only entries whose value the operator has +// filled in on the provider record; missing/empty values are no-ops. +func buildIdentityExtraHeaders(p *types.Provider, extras []catalog.ExtraHeader) []identityInjectExtraHeader { + var out []identityInjectExtraHeader + for _, h := range extras { + if h.Name == "" { + continue + } + v := strings.TrimSpace(p.ExtraValues[h.Name]) + if v == "" { + continue + } + out = append(out, identityInjectExtraHeader{Name: h.Name, Value: v}) + } + return out +} + +// buildMiddlewareChain assembles the per-target middleware chain that +// implements the Agent Network behaviour at the proxy. Slot order on +// the request leg is the slice order; on the response leg it runs in +// reverse, so cost_meter must come BEFORE llm_response_parser so the +// parser populates token counts before the cost meter reads them. +// +// Authorisation is fused into llm_router: the router carries +// AllowedGroupIDs per provider and filters candidates by the caller's +// user-groups before the path-prefix tiebreak. Per-policy +// enforcement (token / budget caps) lives in llm_limit_check, which +// runs after the router so it can read the resolved provider id; +// llm_limit_record on the response leg posts deltas back to +// management to keep the consumption counters fresh. +// +// llm_identity_inject runs immediately after the router so the +// resolved provider id is available; it stamps NetBird identity onto +// requests bound for gateways like LiteLLM that key budgets and +// attribution off request headers. CanMutate is required so its +// HeadersAdd / HeadersRemove pass the framework's mutation gate. +func buildMiddlewareChain(routerCfgJSON, identityInjectJSON, guardrailJSON []byte, redactPii, capturePromptContent bool) []rpservice.MiddlewareConfig { + // Both parsers receive an explicit capture flag derived from the account's + // enable_prompt_collection toggle; nil/unset would default to the legacy + // "always emit" behavior in the middleware, which is precisely what we + // must suppress when the operator hasn't opted in. The flag is duplicated + // across both parsers under distinct field names (capture_prompt / + // capture_completion) to keep each parser's config independently + // auditable. + requestParserCfg := buildParserConfigJSON("capture_prompt", redactPii, capturePromptContent) + responseParserCfg := buildParserConfigJSON("capture_completion", redactPii, capturePromptContent) + return []rpservice.MiddlewareConfig{ + { + ID: middlewareIDLLMRequestParser, + Enabled: true, + Slot: rpservice.MiddlewareSlotOnRequest, + ConfigJSON: requestParserCfg, + }, + { + ID: middlewareIDLLMRouter, + Enabled: true, + Slot: rpservice.MiddlewareSlotOnRequest, + ConfigJSON: routerCfgJSON, + // llm_router rewrites the request's headers (strip + // client auth + inject provider auth) and the upstream + // target via Mutations.RewriteUpstream. Both gated on + // CanMutate; without this flag the chain framework + // drops every mutation and the reverse proxy dials the + // placeholder noop.invalid host (502). + CanMutate: true, + }, + { + // llm_limit_check runs after the router so it knows the + // resolved provider id, but before identity_inject so a + // cap-deny doesn't pay the cost of stamping headers + // we'll never use. + ID: middlewareIDLLMLimitCheck, + Enabled: true, + Slot: rpservice.MiddlewareSlotOnRequest, + ConfigJSON: []byte("{}"), + }, + { + ID: middlewareIDLLMIdentityInject, + Enabled: true, + Slot: rpservice.MiddlewareSlotOnRequest, + ConfigJSON: identityInjectJSON, + // CanMutate is required so HeadersAdd / HeadersRemove + // emitted to stamp NetBird identity onto the upstream + // request actually land — without it the framework + // drops every header mutation. + CanMutate: true, + }, + { + ID: middlewareIDLLMGuardrail, + Enabled: true, + Slot: rpservice.MiddlewareSlotOnRequest, + ConfigJSON: guardrailJSON, + }, + { + // Response slot runs in reverse slice order at runtime: + // limit_record sits FIRST in the response section so it + // runs LAST, after llm_response_parser stamped tokens + // and cost_meter computed cost — both of which the + // recorder reads from the metadata bag. + ID: middlewareIDLLMLimitRecord, + Enabled: true, + Slot: rpservice.MiddlewareSlotOnResponse, + ConfigJSON: []byte("{}"), + }, + { + ID: middlewareIDCostMeter, + Enabled: true, + Slot: rpservice.MiddlewareSlotOnResponse, + ConfigJSON: []byte("{}"), + }, + { + ID: middlewareIDLLMResponseParser, + Enabled: true, + Slot: rpservice.MiddlewareSlotOnResponse, + ConfigJSON: responseParserCfg, + }, + } +} + +// guardrailConfig is the JSON shape the proxy-side llm_guardrail +// middleware expects. Mirrors the proxy registration documented in +// the management→proxy contract. +type guardrailConfig struct { + ModelAllowlist []string `json:"model_allowlist,omitempty"` + PromptCapture guardrailPromptCapture `json:"prompt_capture"` +} + +type guardrailPromptCapture struct { + Enabled bool `json:"enabled"` + RedactPii bool `json:"redact_pii"` +} + +// buildParserConfigJSON assembles the request- or response-parser config JSON. +// captureField names the parser-specific gate (capture_prompt for the request +// parser, capture_completion for the response parser); both are sourced from +// settings.EnablePromptCollection. redact_pii is only meaningful when capture +// is on (no content → nothing to redact) but we forward it verbatim so the +// proxy-side parser stays the only place that interprets the combination. +func buildParserConfigJSON(captureField string, redactPii, capture bool) []byte { + payload := map[string]any{ + captureField: capture, + } + if redactPii { + payload["redact_pii"] = true + } + out, err := json.Marshal(payload) + if err != nil { + // json.Marshal on a map[string]any of bools cannot fail; if it + // somehow does, ship the static minimal config so synth keeps + // working instead of panicking. + return []byte(`{}`) + } + return out +} + +// applyAccountCollectionControls folds the account-level collection master +// switches into the merged guardrail set. Prompt capture enablement is sourced +// SOLELY from the account toggle — the account-network setting is the master +// enable, and policies don't need to attach a capture-enabled guardrail to opt +// in. PII redaction is safe-additive: it applies when either the account or a +// policy guardrail enables it (OR). +func applyAccountCollectionControls(merged *MergedGuardrails, settings *types.Settings) { + if settings == nil { + return + } + merged.PromptCapture.Enabled = settings.EnablePromptCollection + merged.PromptCapture.RedactPii = settings.RedactPii || merged.PromptCapture.RedactPii +} + +func marshalGuardrailConfig(merged MergedGuardrails) ([]byte, error) { + cfg := guardrailConfig{ + ModelAllowlist: merged.ModelAllowlist, + PromptCapture: guardrailPromptCapture{ + Enabled: merged.PromptCapture.Enabled, + RedactPii: merged.PromptCapture.RedactPii, + }, + } + out, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal guardrail middleware config: %w", err) + } + return out, nil +} + +// buildAccountService composes the per-account gateway Service. The +// target carries the noop placeholder URL — the router middleware +// rewrites every request to the matched provider's upstream before the +// proxy dials — alongside the full middleware chain and capture caps. +func buildAccountService( + accountID string, + settings *types.Settings, + enabledPolicies []*types.Policy, + middlewares []rpservice.MiddlewareConfig, + sessionPriv, sessionPub string, +) *rpservice.Service { + cluster := settings.Cluster + domain := settings.Endpoint() + serviceID := SynthesizedServiceIDPrefix + accountID + + return &rpservice.Service{ + ID: serviceID, + AccountID: accountID, + Name: "agent-network-" + accountID, + Domain: domain, + ProxyCluster: cluster, + Mode: rpservice.ModeHTTP, + Enabled: true, + Private: true, + // AccessGroups gates tunnel-peer access (ValidateTunnelPeer) to the + // synthesised agent-network endpoint. Agents reach the gateway over + // the WireGuard tunnel and are authorised by their peer→user group + // membership — the union of every enabled policy's source groups. + AccessGroups: unionSourceGroups(enabledPolicies), + PassHostHeader: false, + RewriteRedirects: false, + SessionPrivateKey: sessionPriv, + SessionPublicKey: sessionPub, + Targets: []*rpservice.Target{ + { + AccountID: accountID, + ServiceID: serviceID, + TargetType: rpservice.TargetTypeCluster, + TargetId: cluster, + Host: noopUpstreamHost, + Port: noopUpstreamPort, + Protocol: noopUpstreamScheme, + Enabled: true, + Options: rpservice.TargetOptions{ + DirectUpstream: true, + AgentNetwork: true, + DisableAccessLog: !settings.EnableLogCollection, + Middlewares: middlewares, + CaptureMaxRequestBytes: agentNetworkRequestCaptureBytes, + CaptureMaxResponseBytes: agentNetworkResponseCaptureBytes, + CaptureContentTypes: append([]string(nil), agentNetworkCaptureContentTypes...), + }, + }, + }, + } +} + +const ( + noopUpstreamScheme = "https" + noopUpstreamHost = "noop.invalid" + noopUpstreamPort = uint16(443) +) + +// providerAuthHeader builds the upstream auth header pair for a +// provider from its catalog entry. The catalog declares which header +// name and template a provider's API expects; the synthesiser +// substitutes the provider's decrypted API key into the template and +// returns the (name, value) pair the router middleware injects after +// stripping the inbound vendor auth headers. +func providerAuthHeader(p *types.Provider) (name, value, gcpSAKeyB64 string, err error) { + entry, ok := catalog.Lookup(p.ProviderID) + if !ok { + return "", "", "", fmt.Errorf("provider %s references unknown catalog id %q", p.ID, p.ProviderID) + } + if entry.AuthHeaderName == "" || entry.AuthHeaderTemplate == "" { + return "", "", "", fmt.Errorf("catalog entry %q has no auth header configured", p.ProviderID) + } + if p.APIKey == "" { + return "", "", "", fmt.Errorf("provider %s has no api key", p.ID) + } + // A "keyfile::" api_key is a GCP service-account key, not a + // static bearer. The proxy mints + refreshes a short-lived OAuth token from + // it at request time, so carry the key material on the route and emit no + // static value. + if rest, isKeyfile := strings.CutPrefix(p.APIKey, gcpKeyfilePrefix); isKeyfile { + return entry.AuthHeaderName, "", strings.TrimSpace(rest), nil + } + value = strings.ReplaceAll(entry.AuthHeaderTemplate, apiKeyPlaceholder, p.APIKey) + return entry.AuthHeaderName, value, "", nil +} + +// parseUpstreamHost splits provider.UpstreamURL into (scheme, host, path) +// where host carries an explicit ":port" suffix when the URL set one +// and path is the URL's path component normalised by stripping a +// trailing slash. The router uses path to disambiguate providers that +// claim the same model. Used by the router config so the rewrite +// carries an authority the reverse proxy can dial verbatim. +func parseUpstreamHost(raw string) (scheme, host, path string, err error) { + parsed, perr := url.Parse(strings.TrimSpace(raw)) + if perr != nil { + return "", "", "", fmt.Errorf("parse upstream_url %q: %w", raw, perr) + } + switch strings.ToLower(parsed.Scheme) { + case "http": + scheme = "http" + case "https": + scheme = "https" + default: + return "", "", "", fmt.Errorf("upstream_url scheme must be http or https, got %q", parsed.Scheme) + } + hostname := parsed.Hostname() + if hostname == "" { + return "", "", "", fmt.Errorf("upstream_url %q has no host", raw) + } + if port := parsed.Port(); port != "" { + host = hostname + ":" + port + } else { + host = hostname + } + path = strings.TrimRight(parsed.Path, "/") + return scheme, host, path, nil +} + +// unionSourceGroups deduplicates source-group IDs across the policies +// pointing at any provider, in deterministic order. +func unionSourceGroups(policies []*types.Policy) []string { + seen := make(map[string]struct{}) + for _, policy := range policies { + for _, group := range policy.SourceGroups { + if group == "" { + continue + } + seen[group] = struct{}{} + } + } + out := make([]string, 0, len(seen)) + for group := range seen { + out = append(out, group) + } + sort.Strings(out) + return out +} + +// MergedGuardrails is the JSON shape passed to the proxy via the +// guardrail middleware's config_json. Mirrors the proxy-side +// expectations and is intentionally distinct from +// types.GuardrailChecks so we can evolve either side independently. +type MergedGuardrails struct { + ModelAllowlist []string `json:"model_allowlist,omitempty"` + TokenLimits MergedTokenLimits `json:"token_limits"` + Budget MergedBudget `json:"budget"` + PromptCapture MergedPromptCapture `json:"prompt_capture"` + Retention MergedRetention `json:"retention"` +} + +type MergedTokenLimits struct { + Hourly *MergedTokenWindow `json:"hourly,omitempty"` + Daily *MergedTokenWindow `json:"daily,omitempty"` + Monthly *MergedTokenWindow `json:"monthly,omitempty"` +} + +type MergedTokenWindow struct { + MaxInputTokens int `json:"max_input_tokens,omitempty"` + MaxOutputTokens int `json:"max_output_tokens,omitempty"` +} + +type MergedBudget struct { + Hourly *MergedBudgetWindow `json:"hourly,omitempty"` + Daily *MergedBudgetWindow `json:"daily,omitempty"` + Monthly *MergedBudgetWindow `json:"monthly,omitempty"` +} + +type MergedBudgetWindow struct { + SoftCapUSD float64 `json:"soft_cap_usd,omitempty"` + HardCapUSD float64 `json:"hard_cap_usd,omitempty"` +} + +type MergedPromptCapture struct { + Enabled bool `json:"enabled"` + RedactPii bool `json:"redact_pii"` +} + +type MergedRetention struct { + Enabled bool `json:"enabled"` + Days int `json:"days"` +} + +// mergeGuardrails computes the effective guardrail spec applied at the +// proxy, given the referencing policies and the account's guardrail +// catalogue. Policy enabled-ness is the caller's responsibility — only +// enabled policies should be passed in. +// +// Merge rules: +// - Model allowlist: union of allowlists across policies that enable it. +// - Token / Budget: most-restrictive (min of non-zero caps) per window. +// - Prompt capture: enabled if any policy enables it; redact_pii sticks +// if any enabling policy turns it on. +// - Retention: enabled if any enables it; smallest non-zero days wins. +func mergeGuardrails(policies []*types.Policy, byID map[string]*types.Guardrail) MergedGuardrails { + merged := MergedGuardrails{} + allowlist := make(map[string]struct{}) + allowlistEnabled := false + + for _, policy := range policies { + for _, gID := range policy.GuardrailIDs { + g, ok := byID[gID] + if !ok || g == nil { + continue + } + mergeGuardrail(g, &merged, allowlist, &allowlistEnabled) + } + } + + if allowlistEnabled { + merged.ModelAllowlist = make([]string, 0, len(allowlist)) + for m := range allowlist { + merged.ModelAllowlist = append(merged.ModelAllowlist, m) + } + sort.Strings(merged.ModelAllowlist) + } + return merged +} + +// mergeGuardrail folds a single guardrail's enabled checks into the +// running merge: model-allowlist models join the shared set (and flip +// allowlistEnabled), and prompt-capture / redact-pii stick once any +// enabling guardrail turns them on. +// +// TokenLimits, Budget, and Retention have moved off guardrails — token +// and budget caps now live on the Policy itself (Policy.Limits) and +// retention moves to account-level Settings — so they are not merged here. +func mergeGuardrail(g *types.Guardrail, merged *MergedGuardrails, allowlist map[string]struct{}, allowlistEnabled *bool) { + if g.Checks.ModelAllowlist.Enabled { + *allowlistEnabled = true + for _, m := range g.Checks.ModelAllowlist.Models { + if m != "" { + allowlist[m] = struct{}{} + } + } + } + if g.Checks.PromptCapture.Enabled { + merged.PromptCapture.Enabled = true + if g.Checks.PromptCapture.RedactPii { + merged.PromptCapture.RedactPii = true + } + } +} diff --git a/management/internals/modules/agentnetwork/synthesizer_guardrail_realstore_test.go b/management/internals/modules/agentnetwork/synthesizer_guardrail_realstore_test.go new file mode 100644 index 000000000..8ed4910da --- /dev/null +++ b/management/internals/modules/agentnetwork/synthesizer_guardrail_realstore_test.go @@ -0,0 +1,178 @@ +package agentnetwork + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/store" +) + +// decodeServiceGuardrailConfig pulls the llm_guardrail middleware config off the +// synthesised service's single target. +func decodeServiceGuardrailConfig(t *testing.T, svc *rpservice.Service) guardrailConfig { + t.Helper() + require.NotEmpty(t, svc.Targets, "synth service must carry a target") + for _, mw := range svc.Targets[0].Options.Middlewares { + if mw.ID == middlewareIDLLMGuardrail { + var cfg guardrailConfig + require.NoError(t, json.Unmarshal(mw.ConfigJSON, &cfg), "guardrail config must decode") + return cfg + } + } + t.Fatal("llm_guardrail middleware not present on synthesised service") + return guardrailConfig{} +} + +// decodeMiddlewareRawConfig returns the raw ConfigJSON bytes for the named +// middleware on the synth service's target, or fails the test. +func decodeMiddlewareRawConfig(t *testing.T, svc *rpservice.Service, id string) []byte { + t.Helper() + require.NotEmpty(t, svc.Targets, "synth service must carry a target") + for _, mw := range svc.Targets[0].Options.Middlewares { + if mw.ID == id { + return mw.ConfigJSON + } + } + t.Fatalf("middleware %q not present on synthesised service", id) + return nil +} + +// saveGuardrailAndPolicy persists a guardrail with prompt capture + redact + a +// model allowlist, referenced by one enabled policy. Shared by the GC-3 tests. +func saveGuardrailAndPolicy(t *testing.T, ctx context.Context, s store.Store, provider *types.Provider) { + t.Helper() + guardrail := &types.Guardrail{ + ID: "ainguard-1", + AccountID: testAccountID, + Name: "strict", + Checks: types.GuardrailChecks{ + ModelAllowlist: types.GuardrailModelAllowlist{Enabled: true, Models: []string{"gpt-5.4"}}, + PromptCapture: types.GuardrailPromptCapture{Enabled: true, RedactPii: true}, + }, + } + require.NoError(t, s.SaveAgentNetworkGuardrail(ctx, guardrail)) + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", guardrail.ID))) +} + +// TestSynthesizeServices_RealStore_PromptCaptureAccountIsSoleControl is the +// GC-3 contract: the account master switch (EnablePromptCollection) is the +// SOLE control for capture enablement. Policy-level guardrail prompt_capture is +// ignored for enablement — operators don't need to attach a capture guardrail +// to a policy just to turn capture on for the account. Off by default. +func TestSynthesizeServices_RealStore_PromptCaptureAccountIsSoleControl(t *testing.T) { + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + // Account collection master switch OFF (default). + require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings())) + saveGuardrailAndPolicy(t, ctx, s, newSynthTestProvider()) + + services, err := SynthesizeServices(ctx, s, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + cfg := decodeServiceGuardrailConfig(t, services[0]) + assert.Equal(t, []string{"gpt-5.4"}, cfg.ModelAllowlist, + "model allowlist is a pure policy guardrail and must always reach the config") + assert.False(t, cfg.PromptCapture.Enabled, + "prompt capture must be off when the account toggle is off, even with a capture-enabled guardrail") +} + +// TestSynthesizeServices_RealStore_PromptCaptureFlowsWhenAccountOptsIn proves +// the account toggle is sufficient on its own — even with NO guardrail +// attached to the policy, capture fires when the account opts in. Redact is +// the OR of account + guardrail. +func TestSynthesizeServices_RealStore_PromptCaptureFlowsWhenAccountOptsIn(t *testing.T) { + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + settings := newSynthTestSettings() + settings.EnablePromptCollection = true + require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings)) + + // Save a provider and a policy with NO guardrails attached — proves the + // account toggle is sufficient on its own. + provider := newSynthTestProvider() + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", ""))) + + services, err := SynthesizeServices(ctx, s, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + cfg := decodeServiceGuardrailConfig(t, services[0]) + assert.True(t, cfg.PromptCapture.Enabled, + "account toggle alone must enable capture; no guardrail attachment required") +} + +// TestSynthesizeServices_RealStore_AccountRedactWithoutGuardrailRedact proves +// the redact OR-merge from the account side: account RedactPii on, guardrail +// redact off, capture on at both levels. +func TestSynthesizeServices_RealStore_AccountRedactWithoutGuardrailRedact(t *testing.T) { + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + settings := newSynthTestSettings() + settings.EnablePromptCollection = true + settings.RedactPii = true + require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings)) + + provider := newSynthTestProvider() + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + guardrail := &types.Guardrail{ + ID: "ainguard-noredact", + AccountID: testAccountID, + Name: "capture-only", + Checks: types.GuardrailChecks{ + PromptCapture: types.GuardrailPromptCapture{Enabled: true, RedactPii: false}, + }, + } + require.NoError(t, s.SaveAgentNetworkGuardrail(ctx, guardrail)) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", guardrail.ID))) + + services, err := SynthesizeServices(ctx, s, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + cfg := decodeServiceGuardrailConfig(t, services[0]) + assert.True(t, cfg.PromptCapture.Enabled, "capture on (account + guardrail)") + assert.True(t, cfg.PromptCapture.RedactPii, "account RedactPii must apply even when the guardrail leaves it off (OR)") +} + +// TestSynthesizeServices_RealStore_NoGuardrail_CaptureOff pins the default: +// with no guardrail referenced, the synth service's guardrail config has prompt +// capture disabled and an empty allowlist. This is the "off by default" baseline +// the account switch must preserve. +func TestSynthesizeServices_RealStore_NoGuardrail_CaptureOff(t *testing.T) { + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings())) + provider := newSynthTestProvider() + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", ""))) + + services, err := SynthesizeServices(ctx, s, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1, "exactly one synth service expected") + + cfg := decodeServiceGuardrailConfig(t, services[0]) + assert.Empty(t, cfg.ModelAllowlist, "no guardrail → no allowlist") + assert.False(t, cfg.PromptCapture.Enabled, "no guardrail → prompt capture off by default") + assert.False(t, cfg.PromptCapture.RedactPii, "no guardrail → redact off by default") +} diff --git a/management/internals/modules/agentnetwork/synthesizer_log_collection_realstore_test.go b/management/internals/modules/agentnetwork/synthesizer_log_collection_realstore_test.go new file mode 100644 index 000000000..9aa2a0abe --- /dev/null +++ b/management/internals/modules/agentnetwork/synthesizer_log_collection_realstore_test.go @@ -0,0 +1,70 @@ +package agentnetwork + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/store" +) + +// TestSynthesizeServices_RealStore_LogCollectionOff_SuppressesAccessLog drives the +// happy default: account settings ship with EnableLogCollection=false, so the +// synthesised target opts out of access-log emission (DisableAccessLog=true) and +// the proto mapping the proxy receives reflects that. +func TestSynthesizeServices_RealStore_LogCollectionOff_SuppressesAccessLog(t *testing.T) { + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings())) + provider := newSynthTestProvider() + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", ""))) + + services, err := SynthesizeServices(ctx, s, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1, "exactly one synth service expected") + require.NotEmpty(t, services[0].Targets, "synth service must carry a target") + assert.True(t, services[0].Targets[0].Options.DisableAccessLog, + "EnableLogCollection=false (default) must produce DisableAccessLog=true on the synth target") + + mapping := services[0].ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{}) + require.NotEmpty(t, mapping.GetPath(), "proto mapping must carry a path") + assert.True(t, mapping.GetPath()[0].GetOptions().GetDisableAccessLog(), + "proto mapping must propagate DisableAccessLog=true so the proxy suppresses access-log emission") +} + +// TestSynthesizeServices_RealStore_LogCollectionOn_PermitsAccessLog asserts the +// inverse: once the account opts in, the synth target leaves DisableAccessLog +// at its default false and the proto wire stays unset. +func TestSynthesizeServices_RealStore_LogCollectionOn_PermitsAccessLog(t *testing.T) { + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + settings := newSynthTestSettings() + settings.EnableLogCollection = true + require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings)) + provider := newSynthTestProvider() + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", ""))) + + services, err := SynthesizeServices(ctx, s, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1, "exactly one synth service expected") + require.NotEmpty(t, services[0].Targets, "synth service must carry a target") + assert.False(t, services[0].Targets[0].Options.DisableAccessLog, + "EnableLogCollection=true must leave DisableAccessLog=false on the synth target") + + mapping := services[0].ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{}) + require.NotEmpty(t, mapping.GetPath(), "proto mapping must carry a path") + assert.False(t, mapping.GetPath()[0].GetOptions().GetDisableAccessLog(), + "proto mapping must propagate DisableAccessLog=false so access-log emission stays on") +} diff --git a/management/internals/modules/agentnetwork/synthesizer_parser_redact_realstore_test.go b/management/internals/modules/agentnetwork/synthesizer_parser_redact_realstore_test.go new file mode 100644 index 000000000..5f42c3b39 --- /dev/null +++ b/management/internals/modules/agentnetwork/synthesizer_parser_redact_realstore_test.go @@ -0,0 +1,145 @@ +package agentnetwork + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/store" +) + +// parserRedactConfig mirrors the on-wire shape of the redact + capture knobs +// that both llm_request_parser and llm_response_parser unmarshal. We don't +// import the proxy-side packages from a management test (cross-module), so we +// decode the JSON directly and assert on the fields that are part of the +// synth contract. +type parserRedactConfig struct { + RedactPii bool `json:"redact_pii,omitempty"` + CapturePrompt *bool `json:"capture_prompt,omitempty"` // present only on the request parser + CaptureCompletion *bool `json:"capture_completion,omitempty"` // present only on the response parser +} + +// TestSynthesizeServices_RealStore_ParserConfigsCarryRedactPii is the +// management-side contract test for the request/response parser redaction +// wiring. When settings.RedactPii is true, the synthesised middleware chain +// MUST stamp redact_pii=true on both llm_request_parser and llm_response_parser +// configs — otherwise the parsers ship raw prompts / completions to the +// access log even though the account has opted in. This is exactly the live +// leak path that motivated the parser-side redaction in the first place. +func TestSynthesizeServices_RealStore_ParserConfigsCarryRedactPii(t *testing.T) { + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + settings := newSynthTestSettings() + settings.RedactPii = true + settings.EnablePromptCollection = true + require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings)) + + provider := newSynthTestProvider() + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", ""))) + + services, err := SynthesizeServices(ctx, s, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1, "exactly one synth service expected") + + for _, parserID := range []string{middlewareIDLLMRequestParser, middlewareIDLLMResponseParser} { + raw := decodeMiddlewareRawConfig(t, services[0], parserID) + var cfg parserRedactConfig + require.NoError(t, json.Unmarshal(raw, &cfg), "%s config must be valid JSON", parserID) + assert.True(t, cfg.RedactPii, "%s config must carry redact_pii=true when settings.RedactPii is on (otherwise the parser ships raw prompts/completions to the access log)", parserID) + } + // The capture flag is set explicitly to enable_prompt_collection on each + // parser. With it on here, both must allow emission. + reqCfg := decodeParserConfig(t, services[0], middlewareIDLLMRequestParser) + require.NotNil(t, reqCfg.CapturePrompt, "request parser must carry an explicit capture_prompt") + assert.True(t, *reqCfg.CapturePrompt, "capture_prompt=true when EnablePromptCollection=true") + respCfg := decodeParserConfig(t, services[0], middlewareIDLLMResponseParser) + require.NotNil(t, respCfg.CaptureCompletion, "response parser must carry an explicit capture_completion") + assert.True(t, *respCfg.CaptureCompletion, "capture_completion=true when EnablePromptCollection=true") +} + +// decodeParserConfig is a small helper around decodeMiddlewareRawConfig that +// also unmarshals into parserRedactConfig. +func decodeParserConfig(t *testing.T, svc *rpservice.Service, parserID string) parserRedactConfig { + t.Helper() + raw := decodeMiddlewareRawConfig(t, svc, parserID) + var cfg parserRedactConfig + require.NoError(t, json.Unmarshal(raw, &cfg), "%s config must be valid JSON", parserID) + return cfg +} + +// TestSynthesizeServices_RealStore_ParserConfigsSuppressCaptureWhenLogCollectionOnly +// is the contract test for the bug: enable_log_collection=true with +// enable_prompt_collection=false MUST result in capture_prompt=false on the +// request parser AND capture_completion=false on the response parser, so the +// access-log row stays metadata-only (provider, model, tokens, cost) and +// carries NO prompt input nor response output. Without this, operators who +// want billing-style logs end up with raw user prompts and model outputs in +// every access-log entry. +func TestSynthesizeServices_RealStore_ParserConfigsSuppressCaptureWhenLogCollectionOnly(t *testing.T) { + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + settings := newSynthTestSettings() + settings.EnableLogCollection = true // operator wants logs ON + settings.EnablePromptCollection = false // but NOT content capture + require.NoError(t, s.SaveAgentNetworkSettings(ctx, settings)) + + provider := newSynthTestProvider() + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", ""))) + + services, err := SynthesizeServices(ctx, s, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + reqCfg := decodeParserConfig(t, services[0], middlewareIDLLMRequestParser) + require.NotNil(t, reqCfg.CapturePrompt, "request parser must carry an explicit capture_prompt gate") + assert.False(t, *reqCfg.CapturePrompt, "capture_prompt MUST be false when EnablePromptCollection is off — otherwise llm.request_prompt_raw leaks user input into the access log") + + respCfg := decodeParserConfig(t, services[0], middlewareIDLLMResponseParser) + require.NotNil(t, respCfg.CaptureCompletion, "response parser must carry an explicit capture_completion gate") + assert.False(t, *respCfg.CaptureCompletion, "capture_completion MUST be false when EnablePromptCollection is off — otherwise llm.response_completion leaks model output into the access log") +} + +// TestSynthesizeServices_RealStore_ParserConfigsOmitRedactPiiWhenOff proves +// the inverse: with the account toggle off, the parser configs stay clean (no +// redact_pii field, which the parsers treat as zero / no redaction). This is +// the operator-opt-out path — the access log keeps raw prompts/completions +// for debugging until the operator opts in. +func TestSynthesizeServices_RealStore_ParserConfigsOmitRedactPiiWhenOff(t *testing.T) { + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + defer cleanup() + + // Default settings: RedactPii = false. + require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings())) + provider := newSynthTestProvider() + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", ""))) + + services, err := SynthesizeServices(ctx, s, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + for _, parserID := range []string{middlewareIDLLMRequestParser, middlewareIDLLMResponseParser} { + raw := decodeMiddlewareRawConfig(t, services[0], parserID) + // Inspect the decoded JSON directly: a struct decode would also pass + // if redact_pii were present-but-false. The contract is that the key + // is omitted entirely while the account toggle is off. + var rawCfg map[string]json.RawMessage + require.NoError(t, json.Unmarshal(raw, &rawCfg), "%s config must be valid JSON", parserID) + assert.NotContains(t, rawCfg, "redact_pii", + "%s config must omit redact_pii entirely while the account toggle is off", parserID) + } +} diff --git a/management/internals/modules/agentnetwork/synthesizer_realstore_test.go b/management/internals/modules/agentnetwork/synthesizer_realstore_test.go new file mode 100644 index 000000000..1e07c0e81 --- /dev/null +++ b/management/internals/modules/agentnetwork/synthesizer_realstore_test.go @@ -0,0 +1,174 @@ +package agentnetwork + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/store" + nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// decodeServiceRouterConfig finds the llm_router middleware on the synthesised +// service's single target and decodes its config — the model→provider routing +// table the proxy authorises against. +func decodeServiceRouterConfig(t *testing.T, svc *rpservice.Service) routerConfig { + t.Helper() + require.NotEmpty(t, svc.Targets, "synth service must carry a target") + for _, mw := range svc.Targets[0].Options.Middlewares { + if mw.ID == middlewareIDLLMRouter { + var cfg routerConfig + require.NoError(t, json.Unmarshal(mw.ConfigJSON, &cfg), "router config must decode") + return cfg + } + } + t.Fatal("llm_router middleware not present on synthesised service") + return routerConfig{} +} + +// decodeMappingRouterConfig is the proto-wire equivalent: it pulls the +// llm_router config off the ProxyMapping the proxy actually receives. +func decodeMappingRouterConfig(t *testing.T, m *proto.ProxyMapping) routerConfig { + t.Helper() + require.NotEmpty(t, m.GetPath(), "mapping must carry a path") + for _, mw := range m.GetPath()[0].GetOptions().GetMiddlewares() { + if mw.GetId() == middlewareIDLLMRouter { + var cfg routerConfig + require.NoError(t, json.Unmarshal(mw.GetConfigJson(), &cfg), "wire router config must decode") + return cfg + } + } + t.Fatal("llm_router middleware not present on proxy mapping") + return routerConfig{} +} + +// TestSynthesizeServices_RealStore_SurvivesStatusToggle drives synthesis through +// a REAL sqlite store (Save → gorm/JSON serialize → reload → decrypt) instead of +// a MockStore, so it exercises the field round-trip that a provider/policy edit +// actually hits. Mock-based tests can't catch a field that dies in persistence; +// this one can. It then performs the exact operation that reproduced the live +// 403 — disable then re-enable the provider — and asserts the re-enabled state +// is fully routable again. +func TestSynthesizeServices_RealStore_SurvivesStatusToggle(t *testing.T) { + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings())) + provider := newSynthTestProvider() + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", ""))) + + assertRoutable := func(t *testing.T, stage string) { + services, err := SynthesizeServices(ctx, s, testAccountID) + require.NoError(t, err, stage) + require.Len(t, services, 1, "%s: exactly one synth service expected", stage) + svc := services[0] + + assert.True(t, svc.Private, "%s: synth service must be Private after store round-trip", stage) + assert.Equal(t, []string{"grp-eng"}, svc.AccessGroups, "%s: AccessGroups must survive the round-trip", stage) + + m := svc.ToProtoMapping(rpservice.Update, "", rpproxy.OIDCValidationConfig{}) + assert.True(t, m.GetPrivate(), "%s: proto mapping Private must be true (proxy gates tunnel-peer auth on it)", stage) + + cfg := decodeServiceRouterConfig(t, svc) + require.Len(t, cfg.Providers, 1, "%s: the enabled+linked provider must appear in the router config", stage) + assert.Equal(t, []string{"gpt-5.4"}, cfg.Providers[0].Models, "%s: provider models must reach the route", stage) + assert.Equal(t, []string{"grp-eng"}, cfg.Providers[0].AllowedGroupIDs, "%s: policy source groups must reach the route", stage) + } + + assertRoutable(t, "initial") + + provider.Enabled = false + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + disabled, err := SynthesizeServices(ctx, s, testAccountID) + require.NoError(t, err, "synthesis must not error with a disabled provider") + for _, svc := range disabled { + assert.Empty(t, decodeServiceRouterConfig(t, svc).Providers, + "a disabled provider must not appear in the router config (otherwise it would route while off)") + } + + provider.Enabled = true + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + assertRoutable(t, "after disable->enable") +} + +// captureController is a proxy.Controller that records the mappings reconcile +// pushes, so the test can inspect the exact wire payload — Private flag and +// router config included. +type captureController struct { + rpproxy.Controller + pushed []*proto.ProxyMapping +} + +func (c *captureController) GetOIDCValidationConfig() rpproxy.OIDCValidationConfig { + return rpproxy.OIDCValidationConfig{} +} + +func (c *captureController) SendServiceUpdateToCluster(_ context.Context, _ string, update *proto.ProxyMapping, _ string) { + c.pushed = append(c.pushed, update) +} + +// noopAccountManager satisfies the reconcile path's accountManager dependency. +type noopAccountManager struct { + account.Manager +} + +func (noopAccountManager) UpdateAccountPeers(context.Context, string, nbtypes.UpdateReason) {} + +// TestReconcile_RealStore_PushesPrivateAfterStatusToggle reproduces the live +// path end-to-end below the gRPC boundary: a real store + the real +// managerImpl.reconcile + a capturing proxy controller. It runs the operation +// that broke in production — provider disable then re-enable — and asserts the +// mapping reconcile pushes to the cluster after re-enable is Private=true and +// carries the routable provider. If reconcile ever pushes private=false (the +// symptom that left UserGroups empty → no_authorised_provider), this fails. +func TestReconcile_RealStore_PushesPrivateAfterStatusToggle(t *testing.T) { + ctx := context.Background() + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + defer cleanup() + + require.NoError(t, s.SaveAgentNetworkSettings(ctx, newSynthTestSettings())) + provider := newSynthTestProvider() + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + require.NoError(t, s.SaveAgentNetworkPolicy(ctx, newSynthTestPolicy(provider.ID, "grp-eng", ""))) + + ctrl := &captureController{} + m := &managerImpl{ + store: s, + accountManager: noopAccountManager{}, + proxyController: ctrl, + reconcileCache: make(map[string]map[string]*proto.ProxyMapping), + } + + m.reconcile(ctx, testAccountID) // initial, provider enabled + + provider.Enabled = false + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + m.reconcile(ctx, testAccountID) // disabled + + provider.Enabled = true + require.NoError(t, s.SaveAgentNetworkProvider(ctx, provider)) + m.reconcile(ctx, testAccountID) // re-enabled — the reproduction step + + require.NotEmpty(t, ctrl.pushed, "reconcile must push at least one mapping") + last := ctrl.pushed[len(ctrl.pushed)-1] + + assert.Equal(t, newSynthTestSettings().Endpoint(), last.GetDomain(), "synth domain on the wire") + assert.True(t, last.GetPrivate(), + "reconcile-pushed mapping after re-enable MUST be Private=true; a false here is the exact bug — the proxy skips ValidateTunnelPeer, UserGroups stays empty, and llm_router denies no_authorised_provider") + + cfg := decodeMappingRouterConfig(t, last) + require.Len(t, cfg.Providers, 1, "re-enabled provider must be back in the pushed router config") + assert.Equal(t, []string{"gpt-5.4"}, cfg.Providers[0].Models, "model must be routable again after re-enable") + assert.Equal(t, []string{"grp-eng"}, cfg.Providers[0].AllowedGroupIDs, "authorised groups must be present after re-enable") +} diff --git a/management/internals/modules/agentnetwork/synthesizer_test.go b/management/internals/modules/agentnetwork/synthesizer_test.go new file mode 100644 index 000000000..0b07f27b3 --- /dev/null +++ b/management/internals/modules/agentnetwork/synthesizer_test.go @@ -0,0 +1,1098 @@ +package agentnetwork + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" +) + +const ( + testAccountID = "acct-1" + testCluster = "eu.proxy.netbird.io" + testSubdomain = "violet" + testEndpoint = "violet.eu.proxy.netbird.io" +) + +func newSynthTestSettings() *types.Settings { + return &types.Settings{ + AccountID: testAccountID, + Cluster: testCluster, + Subdomain: testSubdomain, + } +} + +func newSynthTestProvider() *types.Provider { + return &types.Provider{ + ID: "prov-1", + AccountID: testAccountID, + ProviderID: "openai_api", + Name: "OpenAI", + UpstreamURL: "https://api.openai.com", + APIKey: "sk-test-key", + Enabled: true, + Models: []types.ProviderModel{{ID: "gpt-5.4", InputPer1k: 0.0025, OutputPer1k: 0.015}}, + SessionPrivateKey: "test-priv-key", + SessionPublicKey: "test-pub-key", + CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC), + } +} + +func newSynthTestPolicy(providerID, sourceGroupID, guardrailID string) *types.Policy { + policy := &types.Policy{ + ID: "pol-1", + AccountID: testAccountID, + Name: "engineers", + Enabled: true, + SourceGroups: []string{sourceGroupID}, + DestinationProviderIDs: []string{providerID}, + } + if guardrailID != "" { + policy.GuardrailIDs = []string{guardrailID} + } + return policy +} + +// expectSynthBaseInputs wires the four reads the new synthesiser issues +// in the happy path: settings, providers, policies, guardrails. +func expectSynthBaseInputs(mockStore *store.MockStore, ctx context.Context, settings *types.Settings, providers []*types.Provider, policies []*types.Policy, guardrails []*types.Guardrail) { + if settings == nil { + mockStore.EXPECT(). + GetAgentNetworkSettings(ctx, store.LockingStrengthNone, testAccountID). + Return(nil, status.Errorf(status.NotFound, "agent network settings not found")) + return + } + mockStore.EXPECT(). + GetAgentNetworkSettings(ctx, store.LockingStrengthNone, testAccountID). + Return(settings, nil) + mockStore.EXPECT(). + GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, testAccountID). + Return(providers, nil) + if hasEnabled(providers) { + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, testAccountID). + Return(policies, nil) + if hasEnabledPolicy(policies) { + mockStore.EXPECT(). + GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, testAccountID). + Return(guardrails, nil) + } + } +} + +func hasEnabled(providers []*types.Provider) bool { + for _, p := range providers { + if p != nil && p.Enabled { + return true + } + } + return false +} + +func hasEnabledPolicy(policies []*types.Policy) bool { + for _, p := range policies { + if p != nil && p.Enabled { + return true + } + } + return false +} + +func TestSynthesizeServices_HappyPath(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + openai := newSynthTestProvider() + anthropic := &types.Provider{ + ID: "prov-2", + AccountID: testAccountID, + ProviderID: "anthropic_api", + Name: "Anthropic", + UpstreamURL: "https://api.anthropic.com", + APIKey: "sk-ant-secret", + Enabled: true, + Models: []types.ProviderModel{{ID: "claude-opus-4-7"}}, + SessionPrivateKey: "ant-priv", + SessionPublicKey: "ant-pub", + CreatedAt: time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC), + } + + policyEng := newSynthTestPolicy(openai.ID, "grp-eng", "") + policyEng.ID = "pol-eng" + policyOps := newSynthTestPolicy(anthropic.ID, "grp-ops", "") + policyOps.ID = "pol-ops" + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{openai, anthropic}, + []*types.Policy{policyEng, policyOps}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err, "synthesis must succeed") + require.Len(t, services, 1, "exactly one service per account") + + svc := services[0] + assert.Equal(t, "agent-net-svc-acct-1", svc.ID, "service id is account-scoped") + assert.Equal(t, testAccountID, svc.AccountID, "service inherits account ID") + assert.Equal(t, testEndpoint, svc.Domain, "domain is settings.Endpoint() (subdomain.cluster)") + assert.Equal(t, testCluster, svc.ProxyCluster, "proxy cluster comes from settings") + assert.Equal(t, rpservice.ModeHTTP, svc.Mode, "synthesised services are HTTP mode") + assert.True(t, svc.Private, "synthesised services are always private") + assert.True(t, svc.Enabled, "synthesised services are enabled when emitted") + assert.Equal(t, []string{"grp-eng", "grp-ops"}, svc.AccessGroups, + "access groups union both policies' source groups (tunnel-peer auth)") + + require.Len(t, svc.Targets, 1, "single cluster target") + target := svc.Targets[0] + assert.Equal(t, rpservice.TargetTypeCluster, target.TargetType, "target type is cluster") + assert.Equal(t, testCluster, target.TargetId, "target id is the cluster address") + assert.Equal(t, "noop.invalid", target.Host, "host is the placeholder; router rewrites at request time") + assert.Equal(t, uint16(443), target.Port, "placeholder port") + assert.Equal(t, "https", target.Protocol, "placeholder scheme") + assert.True(t, target.Options.DirectUpstream, "synth targets imply direct upstream") + assert.True(t, target.Options.AgentNetwork, "synth targets must be flagged as agent_network") + + mws := target.Options.Middlewares + require.Len(t, mws, 8, "eight middlewares: request_parser, router, limit_check, identity_inject, guardrail, limit_record, cost_meter, response_parser") + assert.Equal(t, middlewareIDLLMRequestParser, mws[0].ID, "first middleware is the request parser") + assert.Equal(t, rpservice.MiddlewareSlotOnRequest, mws[0].Slot, "request parser runs on_request") + // Request parser carries the capture_prompt gate sourced from + // settings.EnablePromptCollection. The synth-test settings default + // EnablePromptCollection=false, so capture is off and the access-log row + // will not carry prompt content. + assert.JSONEq(t, `{"capture_prompt":false}`, string(mws[0].ConfigJSON), "request parser config must carry capture_prompt from synth") + + assert.Equal(t, middlewareIDLLMRouter, mws[1].ID, "second middleware is the router") + assert.Equal(t, rpservice.MiddlewareSlotOnRequest, mws[1].Slot, "router runs on_request") + assert.True(t, mws[1].CanMutate, "router must carry CanMutate=true; without it the framework drops the auth-header strip/inject AND the upstream rewrite, leaving the proxy to dial the placeholder noop.invalid") + require.NotEmpty(t, mws[1].ConfigJSON, "router config JSON must be populated") + + var routerCfg routerConfig + require.NoError(t, json.Unmarshal(mws[1].ConfigJSON, &routerCfg), "router config must unmarshal") + require.Len(t, routerCfg.Providers, 2, "both providers must reach the router") + assert.Equal(t, openai.ID, routerCfg.Providers[0].ID, "openai is first by created_at") + assert.Equal(t, "Bearer sk-test-key", routerCfg.Providers[0].AuthHeaderValue, "openai auth header value substitutes the API key") + assert.Equal(t, "Authorization", routerCfg.Providers[0].AuthHeaderName, "openai uses Authorization header") + assert.Equal(t, "https", routerCfg.Providers[0].UpstreamScheme, "openai scheme") + assert.Equal(t, "api.openai.com", routerCfg.Providers[0].UpstreamHost, "openai host") + assert.Equal(t, []string{"grp-eng"}, routerCfg.Providers[0].AllowedGroupIDs, "openai inherits policyEng's source groups") + assert.Equal(t, []string{"gpt-5.4"}, routerCfg.Providers[0].Models, + "the provider's configured model IDs must reach the router route — otherwise the model never matches and llm_router denies model_not_routable") + assert.Equal(t, anthropic.ID, routerCfg.Providers[1].ID, "anthropic follows openai by created_at") + assert.Equal(t, "sk-ant-secret", routerCfg.Providers[1].AuthHeaderValue, "anthropic value is the raw API key") + assert.Equal(t, "x-api-key", routerCfg.Providers[1].AuthHeaderName, "anthropic uses x-api-key header") + assert.Equal(t, []string{"grp-ops"}, routerCfg.Providers[1].AllowedGroupIDs, "anthropic inherits policyOps' source groups") + assert.Equal(t, []string{"claude-opus-4-7"}, routerCfg.Providers[1].Models, "anthropic's configured model ID must reach its route") + + assert.Equal(t, middlewareIDLLMLimitCheck, mws[2].ID, + "limit_check sits between router and identity_inject so deny paths skip header-stamp work") + assert.Equal(t, rpservice.MiddlewareSlotOnRequest, mws[2].Slot, "limit_check runs on_request") + + assert.Equal(t, middlewareIDLLMIdentityInject, mws[3].ID, "fourth middleware is identity inject") + assert.Equal(t, rpservice.MiddlewareSlotOnRequest, mws[3].Slot, "identity inject runs on_request") + assert.True(t, mws[3].CanMutate, "identity inject must carry CanMutate=true so its HeadersAdd / HeadersRemove pass the framework's mutation gate") + require.NotEmpty(t, mws[3].ConfigJSON, "identity inject config JSON must be populated even when no provider needs injection") + + assert.Equal(t, middlewareIDLLMGuardrail, mws[4].ID, "fifth middleware is the guardrail") + assert.Equal(t, rpservice.MiddlewareSlotOnRequest, mws[4].Slot, "guardrail runs on_request") + require.NotEmpty(t, mws[4].ConfigJSON, "guardrail config JSON must be populated") + + assert.Equal(t, middlewareIDLLMLimitRecord, mws[5].ID, + "limit_record sits FIRST in the response section so it RUNS LAST at runtime — needs cost_meter + response_parser to have stamped tokens / cost first") + assert.Equal(t, rpservice.MiddlewareSlotOnResponse, mws[5].Slot, "limit_record runs on_response") + + assert.Equal(t, middlewareIDCostMeter, mws[6].ID, "seventh middleware is the cost meter") + assert.Equal(t, rpservice.MiddlewareSlotOnResponse, mws[6].Slot, "cost meter runs on_response") + assert.Equal(t, []byte("{}"), mws[6].ConfigJSON, "cost meter carries an explicit empty config") + + assert.Equal(t, middlewareIDLLMResponseParser, mws[7].ID, "eighth middleware is the response parser") + assert.Equal(t, rpservice.MiddlewareSlotOnResponse, mws[7].Slot, "response parser runs on_response") +} + +func TestSynthesizeServices_NoSettings_ReturnsNil(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + expectSynthBaseInputs(mockStore, ctx, nil, nil, nil, nil) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + assert.Empty(t, services, "missing settings row must yield no synth") +} + +func TestSynthesizeServices_NoProviders_ReturnsNil(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), []*types.Provider{}, nil, nil) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + assert.Empty(t, services, "settings present but no providers must yield no synth") +} + +func TestSynthesizeServices_DisabledProvider_NoService(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + provider := newSynthTestProvider() + provider.Enabled = false + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{provider}, nil, nil) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + assert.Empty(t, services, "disabled provider must not synthesise a service") +} + +func TestSynthesizeServices_DisabledPolicy_NoService(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + provider := newSynthTestProvider() + policy := newSynthTestPolicy(provider.ID, "grp-eng", "") + policy.Enabled = false + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{provider}, []*types.Policy{policy}, nil) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + assert.Empty(t, services, "disabled policy must not trigger synthesis") +} + +func TestSynthesizeServices_RouterConfigOrdering(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + first := newSynthTestProvider() + first.ID = "prov-first" + first.CreatedAt = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + + second := newSynthTestProvider() + second.ID = "prov-second" + second.ProviderID = "anthropic_api" + second.UpstreamURL = "https://api.anthropic.com" + second.APIKey = "sk-ant" + second.CreatedAt = time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC) + + third := newSynthTestProvider() + third.ID = "prov-third" + third.ProviderID = "mistral_api" + third.UpstreamURL = "https://api.mistral.ai" + third.APIKey = "sk-mistral" + third.CreatedAt = time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC) + + policy := newSynthTestPolicy(first.ID, "grp-eng", "") + policy.DestinationProviderIDs = []string{first.ID, second.ID, third.ID} + + // Pass providers in shuffled order to confirm the synth sorts them. + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{second, first, third}, + []*types.Policy{policy}, []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var routerCfg routerConfig + for _, m := range mws { + if m.ID == middlewareIDLLMRouter { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &routerCfg)) + break + } + } + require.Len(t, routerCfg.Providers, 3, "all three providers must be in the router config") + assert.Equal(t, first.ID, routerCfg.Providers[0].ID, "providers ordered by created_at; first is earliest") + assert.Equal(t, third.ID, routerCfg.Providers[1].ID, "second is mid") + assert.Equal(t, second.ID, routerCfg.Providers[2].ID, "third is latest") +} + +func TestSynthesizeServices_PolicyCheckConfig_UnionsSourceGroups(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + provider := newSynthTestProvider() + + policyA := newSynthTestPolicy(provider.ID, "grp-eng", "") + policyA.ID = "pol-a" + policyA.SourceGroups = []string{"grp-eng", "grp-shared"} + policyB := newSynthTestPolicy(provider.ID, "grp-ops", "") + policyB.ID = "pol-b" + policyB.SourceGroups = []string{"grp-ops", "grp-shared"} + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{provider}, + []*types.Policy{policyA, policyB}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var routerCfg routerConfig + for _, m := range mws { + if m.ID == middlewareIDLLMRouter { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &routerCfg)) + break + } + } + require.Len(t, routerCfg.Providers, 1, "single provider authorised by both policies") + assert.Equal(t, provider.ID, routerCfg.Providers[0].ID) + assert.Equal(t, []string{"grp-eng", "grp-ops", "grp-shared"}, routerCfg.Providers[0].AllowedGroupIDs, + "source groups must be unioned and sorted; the duplicate grp-shared collapses") +} + +func TestSynthesizeServices_OrphanProvider_HasEmptyAllowedGroups(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + authorised := newSynthTestProvider() + authorised.ID = "prov-authed" + + orphan := newSynthTestProvider() + orphan.ID = "prov-orphan" + orphan.ProviderID = "anthropic_api" + orphan.UpstreamURL = "https://api.anthropic.com" + orphan.APIKey = "sk-ant" + orphan.CreatedAt = time.Date(2026, 4, 1, 0, 0, 0, 0, time.UTC) + + // Policy authorises the first provider only. + policy := newSynthTestPolicy(authorised.ID, "grp-eng", "") + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{authorised, orphan}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var routerCfg routerConfig + for _, m := range mws { + if m.ID == middlewareIDLLMRouter { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &routerCfg)) + break + } + } + + // Orphan providers are dropped from the router config entirely. + // The router treats an empty AllowedGroupIDs as a catch-all (right + // default for non-agent-network targets, wrong default here), so + // we don't ship them at all. Peers attempting to call models only + // the orphan claims see model_not_routable; peers calling models + // shared with the authorised provider get routed there. + require.Len(t, routerCfg.Providers, 1, "only the authorised provider reaches the router") + assert.Equal(t, authorised.ID, routerCfg.Providers[0].ID, + "authorised provider must be in router config") + assert.Equal(t, []string{"grp-eng"}, routerCfg.Providers[0].AllowedGroupIDs, + "authorised provider inherits the policy's source groups") +} + +// TestSynthesizeServices_IdentityInject_LiteLLM pins that a LiteLLM +// provider lands in the identity-inject middleware's config with the +// catalog-defined LiteLLM headers, while a non-LiteLLM provider does +// not. Together they prove the middleware is a no-op for accounts that +// don't use LiteLLM and stamps identity for those that do. +func TestSynthesizeServices_IdentityInject_LiteLLM(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + openai := newSynthTestProvider() + openai.ID = "prov-openai" + + litellm := newSynthTestProvider() + litellm.ID = "prov-litellm" + litellm.ProviderID = "litellm_proxy" + litellm.UpstreamURL = "https://litellm.acme.example.com" + litellm.APIKey = "sk-llm-master" + litellm.CreatedAt = time.Date(2026, 4, 2, 0, 0, 0, 0, time.UTC) + + policyOpenAI := newSynthTestPolicy(openai.ID, "grp-eng", "") + policyOpenAI.ID = "pol-openai" + policyLiteLLM := newSynthTestPolicy(litellm.ID, "grp-eng", "") + policyLiteLLM.ID = "pol-litellm" + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{openai, litellm}, + []*types.Policy{policyOpenAI, policyLiteLLM}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var injectCfg identityInjectConfig + for _, m := range mws { + if m.ID == middlewareIDLLMIdentityInject { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &injectCfg)) + break + } + } + require.Len(t, injectCfg.Providers, 1, + "only providers whose catalog entry declares IdentityInjection should appear in the inject config") + entry := injectCfg.Providers[0] + assert.Equal(t, litellm.ID, entry.ProviderID, + "the LiteLLM provider must be the one identity-stamped, not the OpenAI direct provider") + require.NotNil(t, entry.HeaderPair, "LiteLLM uses the HeaderPair shape") + assert.Nil(t, entry.JSONMetadata, "shapes are mutually exclusive — JSONMetadata must be nil for HeaderPair providers") + assert.Equal(t, "x-litellm-end-user-id", entry.HeaderPair.EndUserIDHeader, + "end-user-id header must come from the catalog entry's IdentityInjection block") + assert.Equal(t, "x-litellm-tags", entry.HeaderPair.TagsHeader) +} + +// TestSynthesizeServices_IdentityInject_Bifrost_OperatorOverrides +// covers the customizable HeaderPair contract. The Bifrost catalog +// entry sets HeaderPair.Customizable=true with x-bf-dim-* defaults +// (placeholders surfaced by the dashboard, NOT authoritative at +// synth time). The wire header names that actually land on the +// inject middleware config come from the provider record's +// IdentityHeaderUserID / IdentityHeaderGroups fields verbatim. This +// lets operators pick between Bifrost's two attribution paths +// (always-on x-bf-lh-* logs metadata vs. label-declared x-bf-dim-* +// telemetry) per provider record without code changes. +// +// Three sub-cases under one fixture: full override, partial +// override (user kept, groups disabled), and ParserID empty so the +// proxy falls back to URL sniffing. +func TestSynthesizeServices_IdentityInject_Bifrost_OperatorOverrides(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + bifrost := newSynthTestProvider() + bifrost.ID = "prov-bifrost" + bifrost.ProviderID = "bifrost" + bifrost.UpstreamURL = "https://bifrost.acme.example.com/openai/v1" + bifrost.APIKey = "sk-bf-key" + bifrost.IdentityHeaderUserID = "x-bf-lh-netbird_user_id" + bifrost.IdentityHeaderGroups = "x-bf-lh-netbird_groups" + bifrost.CreatedAt = time.Date(2026, 4, 2, 0, 0, 0, 0, time.UTC) + + policy := newSynthTestPolicy(bifrost.ID, "grp-eng", "") + policy.ID = "pol-bifrost" + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{bifrost}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var injectCfg identityInjectConfig + for _, m := range mws { + if m.ID == middlewareIDLLMIdentityInject { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &injectCfg)) + break + } + } + require.Len(t, injectCfg.Providers, 1, + "single bifrost catalog entry → one inject config target — operator's URL path picks the parser, not the catalog id") + + entry := injectCfg.Providers[0] + assert.Equal(t, bifrost.ID, entry.ProviderID) + require.NotNil(t, entry.HeaderPair, "Bifrost uses HeaderPair shape") + assert.Equal(t, "x-bf-lh-netbird_user_id", entry.HeaderPair.EndUserIDHeader, + "operator-set IdentityHeaderUserID overrides the catalog's x-bf-dim- placeholder — proves the Customizable flag actually swaps the source of truth") + assert.Equal(t, "x-bf-lh-netbird_groups", entry.HeaderPair.TagsHeader, + "operator-set IdentityHeaderGroups overrides the catalog's x-bf-dim- placeholder") + assert.False(t, entry.HeaderPair.TagsInBody, + "body-inject flags stay catalog-owned — Bifrost reads identity from headers, body inject would be a no-op") + assert.False(t, entry.HeaderPair.EndUserIDInBody) +} + +// TestSynthesizeServices_IdentityInject_Bifrost_PartialDisable proves +// that clearing one of the IdentityHeader* fields disables stamping +// for THAT dimension only, leaving the other dimension active. +// Critical because the customizable contract says "empty = disabled +// for that dimension"; if the synth path silently fell back to the +// catalog default for an empty operator value, operators couldn't +// turn off groups while keeping user id (or vice versa). +func TestSynthesizeServices_IdentityInject_Bifrost_PartialDisable(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + bifrost := newSynthTestProvider() + bifrost.ID = "prov-bifrost" + bifrost.ProviderID = "bifrost" + bifrost.UpstreamURL = "https://bifrost.acme.example.com/openai/v1" + bifrost.APIKey = "sk-bf-key" + bifrost.IdentityHeaderUserID = "x-bf-lh-netbird_user_id" + bifrost.IdentityHeaderGroups = "" // operator explicitly disabled groups + bifrost.CreatedAt = time.Date(2026, 4, 2, 0, 0, 0, 0, time.UTC) + + policy := newSynthTestPolicy(bifrost.ID, "grp-eng", "") + policy.ID = "pol-bifrost" + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{bifrost}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var injectCfg identityInjectConfig + for _, m := range mws { + if m.ID == middlewareIDLLMIdentityInject { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &injectCfg)) + break + } + } + require.Len(t, injectCfg.Providers, 1) + entry := injectCfg.Providers[0] + require.NotNil(t, entry.HeaderPair, "user-id header is still set so the rule fires") + assert.Equal(t, "x-bf-lh-netbird_user_id", entry.HeaderPair.EndUserIDHeader) + assert.Empty(t, entry.HeaderPair.TagsHeader, + "groups header must be empty — operator cleared it; the inject middleware no-ops on empty header names so groups are NOT stamped") +} + +// TestSynthesizeServices_IdentityInject_Cloudflare_OperatorOverrides +// covers the JSONMetadata customizable contract: Cloudflare's +// catalog entry sets JSONMetadata.Customizable=true with +// netbird_user_id / netbird_groups defaults that the dashboard +// surfaces as placeholders. The actual JSON keys that land inside +// the cf-aig-metadata header come from the provider record's +// IdentityHeaderUserID / IdentityHeaderGroups fields. Reuses the +// same fields HeaderPair customizable does — the dimensions +// (user identity, groups) match; only the wire encoding (JSON key +// vs HTTP header name) differs. +func TestSynthesizeServices_IdentityInject_Cloudflare_OperatorOverrides(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + cf := newSynthTestProvider() + cf.ID = "prov-cf" + cf.ProviderID = "cloudflare_ai_gateway" + cf.UpstreamURL = "https://gateway.ai.cloudflare.com/v1/acct-xyz/my-gateway/openai" + cf.APIKey = "cf-aig-token" + cf.IdentityHeaderUserID = "team_member" + cf.IdentityHeaderGroups = "team_groups" + cf.CreatedAt = time.Date(2026, 4, 2, 0, 0, 0, 0, time.UTC) + + policy := newSynthTestPolicy(cf.ID, "grp-eng", "") + policy.ID = "pol-cf" + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{cf}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var injectCfg identityInjectConfig + for _, m := range mws { + if m.ID == middlewareIDLLMIdentityInject { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &injectCfg)) + break + } + } + require.Len(t, injectCfg.Providers, 1) + entry := injectCfg.Providers[0] + require.NotNil(t, entry.JSONMetadata, "Cloudflare uses JSONMetadata shape — single header carrying a JSON object") + assert.Nil(t, entry.HeaderPair, "shapes are mutually exclusive") + assert.Equal(t, "cf-aig-metadata", entry.JSONMetadata.Header, + "the wire header is catalog-owned (cf-aig-metadata) — operator can rename the JSON keys but not the header itself") + assert.Equal(t, "team_member", entry.JSONMetadata.UserKey, + "operator-set IdentityHeaderUserID overrides the catalog's netbird_user_id default — proves the JSONMetadata Customizable flag swaps the source of truth like HeaderPair already does") + assert.Equal(t, "team_groups", entry.JSONMetadata.GroupsKey, + "operator-set IdentityHeaderGroups overrides the catalog's netbird_groups default") +} + +// TestSynthesizeServices_IdentityInject_Portkey_NotCustomizable +// is the JSONMetadata negative case: Portkey's catalog entry leaves +// Customizable=false because Portkey's analytics dashboard reserves +// "_user" and "groups" as fixed JSON keys. An operator-set +// IdentityHeader* on a Portkey provider record must NOT override +// those keys, or Portkey's per-user filters silently break. +func TestSynthesizeServices_IdentityInject_Portkey_NotCustomizable(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + portkey := newSynthTestProvider() + portkey.ID = "prov-portkey" + portkey.ProviderID = "portkey" + portkey.UpstreamURL = "https://api.portkey.ai/v1" + portkey.APIKey = "portkey-account-key" + // Operator set these — but portkey's catalog entry has + // JSONMetadata.Customizable=false, so synth must IGNORE them + // and stick with the catalog's _user / groups defaults. + portkey.IdentityHeaderUserID = "should-be-ignored" + portkey.IdentityHeaderGroups = "should-be-ignored-too" + portkey.CreatedAt = time.Date(2026, 4, 2, 0, 0, 0, 0, time.UTC) + + policy := newSynthTestPolicy(portkey.ID, "grp-eng", "") + policy.ID = "pol-portkey" + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{portkey}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var injectCfg identityInjectConfig + for _, m := range mws { + if m.ID == middlewareIDLLMIdentityInject { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &injectCfg)) + break + } + } + require.Len(t, injectCfg.Providers, 1) + entry := injectCfg.Providers[0] + require.NotNil(t, entry.JSONMetadata) + assert.Equal(t, "_user", entry.JSONMetadata.UserKey, + "Portkey's reserved JSON key must hold — Customizable=false on the catalog blocks the operator's override fields") + assert.Equal(t, "groups", entry.JSONMetadata.GroupsKey, + "same fixed-schema guarantee for the groups dimension") +} + +// TestSynthesizeServices_IdentityInject_Vercel pins Vercel AI +// Gateway's wiring: HeaderPair shape with fixed wire names dictated +// by Vercel's Custom Reporting API (ai-reporting-user / +// ai-reporting-tags). Customizable=false on the catalog entry, so +// the synth path takes the catalog values verbatim and ignores any +// IdentityHeader* fields the operator might have set. Renaming +// these headers would just silently disable attribution — Vercel's +// reporting endpoint only matches the canonical names — so the +// fixed contract is the right semantic. +func TestSynthesizeServices_IdentityInject_Vercel(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + vercel := newSynthTestProvider() + vercel.ID = "prov-vercel" + vercel.ProviderID = "vercel_ai_gateway" + vercel.UpstreamURL = "https://ai-gateway.vercel.sh/v1" + vercel.APIKey = "vrc-team-key" + // Operator set these — they MUST be ignored because Vercel's + // catalog entry is non-customizable. Renaming the headers on + // the wire would defeat Vercel's reporting endpoint. + vercel.IdentityHeaderUserID = "should-be-ignored" + vercel.IdentityHeaderGroups = "should-be-ignored-too" + vercel.CreatedAt = time.Date(2026, 4, 2, 0, 0, 0, 0, time.UTC) + + policy := newSynthTestPolicy(vercel.ID, "grp-eng", "") + policy.ID = "pol-vercel" + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{vercel}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var injectCfg identityInjectConfig + for _, m := range mws { + if m.ID == middlewareIDLLMIdentityInject { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &injectCfg)) + break + } + } + require.Len(t, injectCfg.Providers, 1) + entry := injectCfg.Providers[0] + require.NotNil(t, entry.HeaderPair, "Vercel uses HeaderPair shape — separate ai-reporting-user / ai-reporting-tags headers, not a JSON blob") + assert.Nil(t, entry.JSONMetadata, "shapes are mutually exclusive") + assert.Equal(t, "ai-reporting-user", entry.HeaderPair.EndUserIDHeader, + "end-user-id header must be Vercel's canonical ai-reporting-user — renaming would silently disable attribution at Vercel's Custom Reporting endpoint") + assert.Equal(t, "ai-reporting-tags", entry.HeaderPair.TagsHeader, + "tags header must be Vercel's canonical ai-reporting-tags for the same reason") + assert.False(t, entry.HeaderPair.TagsInBody, + "Vercel reads from headers — body inject would be a LiteLLM-specific belt-and-suspenders unneeded here") + assert.False(t, entry.HeaderPair.EndUserIDInBody) +} + +// TestSynthesizeServices_IdentityInject_OpenRouter pins OpenRouter's +// wiring: HeaderPair shape with body-only injection. OpenRouter's +// per-user attribution is the OpenAI-standard `user` body field — +// there's no header path and no groups dimension at all. The catalog +// entry sets EndUserIDInBody=true with empty header names; the inject +// middleware writes user identity into the request body but stamps +// nothing on the header surface. Customizable=false so any operator +// IdentityHeader* fields are ignored. +// +// Also asserts the static ExtraHeaders surface: operators provide +// their app URL and display name on the provider record (HTTP-Referer +// and X-OpenRouter-Title), and these land on every upstream request +// so OpenRouter's app rankings / analytics attribute correctly. +func TestSynthesizeServices_IdentityInject_OpenRouter(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + openrouter := newSynthTestProvider() + openrouter.ID = "prov-openrouter" + openrouter.ProviderID = "openrouter" + openrouter.UpstreamURL = "https://openrouter.ai/api/v1" + openrouter.APIKey = "sk-or-v1-acme" + // These would only apply if the catalog entry was Customizable; + // it isn't, so they must be IGNORED. + openrouter.IdentityHeaderUserID = "should-be-ignored" + openrouter.IdentityHeaderGroups = "should-be-ignored-too" + openrouter.ExtraValues = map[string]string{ + "HTTP-Referer": "https://acme.example/agents", + "X-OpenRouter-Title": "Acme Agents", + } + openrouter.CreatedAt = time.Date(2026, 4, 2, 0, 0, 0, 0, time.UTC) + + policy := newSynthTestPolicy(openrouter.ID, "grp-eng", "") + policy.ID = "pol-openrouter" + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{openrouter}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var injectCfg identityInjectConfig + for _, m := range mws { + if m.ID == middlewareIDLLMIdentityInject { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &injectCfg)) + break + } + } + require.Len(t, injectCfg.Providers, 1) + entry := injectCfg.Providers[0] + require.NotNil(t, entry.HeaderPair, "OpenRouter uses HeaderPair shape — body-inject is the only branch active") + assert.Empty(t, entry.HeaderPair.EndUserIDHeader, + "OpenRouter does not document a header path for per-user identity; the inject must NOT stamp a header here. Customizable=false means operator IdentityHeader* fields are ignored.") + assert.Empty(t, entry.HeaderPair.TagsHeader, + "OpenRouter has no per-request groups / tags dimension — the tags header MUST stay empty") + assert.True(t, entry.HeaderPair.EndUserIDInBody, + "OpenRouter's only per-user attribution path is the OpenAI-standard `user` body field — body inject is the load-bearing piece for this provider") + assert.False(t, entry.HeaderPair.TagsInBody, + "no tags dimension at all → no tags-in-body either") + + // ExtraHeaders carry the operator-typed app URL + display name to + // OpenRouter's app rankings. The synth must echo BOTH static + // header values with the operator's typed strings. + require.Len(t, entry.ExtraHeaders, 2, + "both ExtraHeaders the catalog declares should land on the inject config when the operator filled in values") + byName := map[string]string{} + for _, h := range entry.ExtraHeaders { + byName[h.Name] = h.Value + } + assert.Equal(t, "https://acme.example/agents", byName["HTTP-Referer"], + "HTTP-Referer is OpenRouter's primary app identifier — must round-trip the operator-typed URL verbatim") + assert.Equal(t, "Acme Agents", byName["X-OpenRouter-Title"], + "X-OpenRouter-Title surfaces as the app's display name in OpenRouter's rankings — must round-trip operator's chosen string") +} + +// TestSynthesizeServices_IdentityInject_NonCustomizable_UsesCatalog +// is the LiteLLM-style negative case: when the catalog entry does +// NOT flag HeaderPair as Customizable, the catalog defaults are +// authoritative and any IdentityHeader* values on the provider +// record are ignored. Without this guard, an operator who set those +// fields on a non-Bifrost provider could accidentally break the +// gateway's wire protocol (LiteLLM only honours x-litellm-end-user- +// id; renaming it would silently drop spend tracking). +func TestSynthesizeServices_IdentityInject_NonCustomizable_UsesCatalog(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + litellm := newSynthTestProvider() + litellm.ID = "prov-litellm" + litellm.ProviderID = "litellm_proxy" + litellm.UpstreamURL = "https://litellm.acme.example.com" + litellm.APIKey = "sk-llm-master" + // Operator set these — but litellm_proxy's catalog entry has + // HeaderPair.Customizable=false, so the synth path must IGNORE + // these and fall back to the catalog defaults. + litellm.IdentityHeaderUserID = "x-bf-lh-should-be-ignored" + litellm.IdentityHeaderGroups = "x-bf-lh-should-be-ignored-too" + litellm.CreatedAt = time.Date(2026, 4, 2, 0, 0, 0, 0, time.UTC) + + policy := newSynthTestPolicy(litellm.ID, "grp-eng", "") + policy.ID = "pol-litellm" + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{litellm}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var injectCfg identityInjectConfig + for _, m := range mws { + if m.ID == middlewareIDLLMIdentityInject { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &injectCfg)) + break + } + } + require.Len(t, injectCfg.Providers, 1) + entry := injectCfg.Providers[0] + require.NotNil(t, entry.HeaderPair) + assert.Equal(t, "x-litellm-end-user-id", entry.HeaderPair.EndUserIDHeader, + "Customizable=false on the catalog entry must hold — operator IdentityHeader* fields cannot rename a fixed wire protocol's headers") + assert.Equal(t, "x-litellm-tags", entry.HeaderPair.TagsHeader, + "Customizable=false on the catalog entry must hold for tags too") +} + +func TestSynthesizeServices_GuardrailMerge_AllowlistUnion_LimitsRestrictive(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + provider := newSynthTestProvider() + + guardrailA := &types.Guardrail{ + ID: "g-a", + AccountID: testAccountID, + Checks: types.GuardrailChecks{ + ModelAllowlist: types.GuardrailModelAllowlist{Enabled: true, Models: []string{"gpt-5.4-mini"}}, + }, + } + guardrailB := &types.Guardrail{ + ID: "g-b", + AccountID: testAccountID, + Checks: types.GuardrailChecks{ + ModelAllowlist: types.GuardrailModelAllowlist{Enabled: true, Models: []string{"gpt-5.4-pro"}}, + }, + } + + policyA := newSynthTestPolicy(provider.ID, "grp-a", guardrailA.ID) + policyA.ID = "pol-a" + policyB := newSynthTestPolicy(provider.ID, "grp-b", guardrailB.ID) + policyB.ID = "pol-b" + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{provider}, + []*types.Policy{policyA, policyB}, + []*types.Guardrail{guardrailA, guardrailB}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var guardrailJSON []byte + for _, m := range mws { + if m.ID == middlewareIDLLMGuardrail { + guardrailJSON = m.ConfigJSON + break + } + } + require.NotEmpty(t, guardrailJSON, "guardrail middleware config JSON must be present") + + var cfg guardrailConfig + require.NoError(t, json.Unmarshal(guardrailJSON, &cfg), "guardrail config must unmarshal cleanly") + assert.ElementsMatch(t, []string{"gpt-5.4-mini", "gpt-5.4-pro"}, cfg.ModelAllowlist, + "model allowlist union must keep both models") +} + +func TestSynthesizeServices_BackfillsMissingSessionKeys(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + provider := newSynthTestProvider() + provider.SessionPrivateKey = "" + provider.SessionPublicKey = "" + policy := newSynthTestPolicy(provider.ID, "grp-eng", "") + + mockStore.EXPECT(). + GetAgentNetworkSettings(ctx, store.LockingStrengthNone, testAccountID). + Return(newSynthTestSettings(), nil) + mockStore.EXPECT(). + GetAccountAgentNetworkProviders(ctx, store.LockingStrengthNone, testAccountID). + Return([]*types.Provider{provider}, nil) + mockStore.EXPECT(). + GetAccountAgentNetworkPolicies(ctx, store.LockingStrengthNone, testAccountID). + Return([]*types.Policy{policy}, nil) + // Backfill must persist the new keys before synthesising. + mockStore.EXPECT(). + SaveAgentNetworkProvider(ctx, gomock.Any()). + DoAndReturn(func(_ context.Context, p *types.Provider) error { + require.NotEmpty(t, p.SessionPrivateKey, "backfill must populate private key") + require.NotEmpty(t, p.SessionPublicKey, "backfill must populate public key") + return nil + }) + mockStore.EXPECT(). + GetAccountAgentNetworkGuardrails(ctx, store.LockingStrengthNone, testAccountID). + Return([]*types.Guardrail{}, nil) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1, "synthesis must complete after backfill") + assert.NotEmpty(t, services[0].SessionPrivateKey, "synthesised service inherits the freshly-minted private key") + assert.NotEmpty(t, services[0].SessionPublicKey, "synthesised service inherits the freshly-minted public key") +} + +func TestSynthesizeServices_HTTPUpstream_KeepsExplicitPort(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + provider := newSynthTestProvider() + provider.UpstreamURL = "http://internal-llm.lan:8080" + policy := newSynthTestPolicy(provider.ID, "grp-eng", "") + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{provider}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var routerCfg routerConfig + for _, m := range mws { + if m.ID == middlewareIDLLMRouter { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &routerCfg)) + break + } + } + require.Len(t, routerCfg.Providers, 1) + assert.Equal(t, "http", routerCfg.Providers[0].UpstreamScheme, "scheme follows the upstream URL") + assert.Equal(t, "internal-llm.lan:8080", routerCfg.Providers[0].UpstreamHost, + "explicit port travels with host so the router rewrite carries an authority the proxy can dial") +} + +func TestSynthesizeServices_UpstreamURLPath_FlowsToRouter(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + // Provider configured with a path-prefixed upstream — common for + // OpenAI-compatible endpoints behind corporate gateways. The path + // is the router's disambiguator when two providers claim the same + // model, so it must round-trip through buildRouterConfigJSON with + // the trailing slash trimmed. + provider := newSynthTestProvider() + provider.UpstreamURL = "https://corp.example.com/openai/" + policy := newSynthTestPolicy(provider.ID, "grp-eng", "") + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{provider}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var routerCfg routerConfig + for _, m := range mws { + if m.ID == middlewareIDLLMRouter { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &routerCfg)) + break + } + } + require.Len(t, routerCfg.Providers, 1) + assert.Equal(t, "corp.example.com", routerCfg.Providers[0].UpstreamHost, "host should drop the path") + assert.Equal(t, "/openai", routerCfg.Providers[0].UpstreamPath, + "upstream path must be carried so the router can disambiguate same-model providers; trailing slash trimmed for stable string-prefix matching") +} + +func TestSynthesizeServices_UnknownProviderID_FailsClosed(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + provider := newSynthTestProvider() + provider.ProviderID = "nonexistent_provider" + policy := newSynthTestPolicy(provider.ID, "grp-eng", "") + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{provider}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + _, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.Error(t, err, "synthesis must fail when the catalog can't resolve the provider id") + assert.Contains(t, err.Error(), "unknown catalog id", "error must surface the misconfiguration") +} + +func TestSynthesizeServices_EmptyAPIKey_FailsClosed(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + provider := newSynthTestProvider() + provider.APIKey = "" + policy := newSynthTestPolicy(provider.ID, "grp-eng", "") + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{provider}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + _, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.Error(t, err, "synthesis must refuse a provider with no api key") + assert.Contains(t, err.Error(), "no api key", "error must surface the missing credential") +} diff --git a/management/internals/modules/agentnetwork/types/accesslog.go b/management/internals/modules/agentnetwork/types/accesslog.go new file mode 100644 index 000000000..92b8bc358 --- /dev/null +++ b/management/internals/modules/agentnetwork/types/accesslog.go @@ -0,0 +1,289 @@ +package types + +import ( + "time" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// AgentNetworkAccessLog is the dedicated, flattened agent-network access-log +// row. Unlike the shared reverse-proxy AccessLogEntry (which kept LLM data in +// an opaque metadata JSON blob), the LLM dimensions live in first-class, +// indexed columns so the access-log surface can filter server-side by +// user / group / provider / model / decision. +type AgentNetworkAccessLog struct { + // The composite index idx_anal_acct_session_ts backs the session-grouped + // listing (GROUP BY session_id ORDER BY MAX(timestamp) within an account); + // the single-column indexes still back the flat filters/sorts. + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index;index:idx_anal_acct_session_ts,priority:1"` + ServiceID string `gorm:"index"` + Timestamp time.Time `gorm:"index;index:idx_anal_acct_session_ts,priority:3"` + UserID string `gorm:"index"` + SourceIP string + Method string + Host string + Path string `gorm:"type:text"` + Duration time.Duration + StatusCode int `gorm:"index"` + AuthMethod string + BytesUpload int64 + BytesDownload int64 + + // Flattened LLM dimensions (queryable). Sourced from proxy metadata keys. + Provider string `gorm:"index"` // vendor, e.g. "openai" (llm.provider) + Model string `gorm:"index"` // llm.model + SessionID string `gorm:"index;index:idx_anal_acct_session_ts,priority:2"` // llm.session_id — groups a conversation / coding session + ResolvedProviderID string `gorm:"index"` // llm.resolved_provider_id + SelectedPolicyID string `gorm:"index"` // llm.selected_policy_id + Decision string `gorm:"index"` // llm_policy.decision (allow/deny) + DenyReason string // llm_policy.reason (raw code, mapped in the UI) + InputTokens int64 + OutputTokens int64 + TotalTokens int64 + CostUSD float64 + Stream bool + + // Prompt capture. Only populated when prompt collection is enabled + // (account master switch AND policy guardrail). Heavy free text. + RequestPrompt string `gorm:"type:text"` + ResponseCompletion string `gorm:"type:text"` + + CreatedAt time.Time + + // GroupIDs is the authorising group ids for this entry, hydrated from the + // group child table on read. Not a column. + GroupIDs []string `gorm:"-"` +} + +// TableName keeps agent-network access logs in their own table, separate from +// the reverse-proxy AccessLogEntry table. +func (AgentNetworkAccessLog) TableName() string { return "agent_network_access_log" } + +// ToAPIResponse renders the flattened entry as the API representation. +func (a *AgentNetworkAccessLog) ToAPIResponse() api.AgentNetworkAccessLog { + out := api.AgentNetworkAccessLog{ + Id: a.ID, + ServiceId: a.ServiceID, + Timestamp: a.Timestamp, + StatusCode: a.StatusCode, + DurationMs: int(a.Duration.Milliseconds()), + InputTokens: a.InputTokens, + OutputTokens: a.OutputTokens, + TotalTokens: a.TotalTokens, + CostUsd: a.CostUSD, + Stream: &a.Stream, + } + + out.UserId = strPtr(a.UserID) + out.SourceIp = strPtr(a.SourceIP) + out.Method = strPtr(a.Method) + out.Host = strPtr(a.Host) + out.Path = strPtr(a.Path) + out.Provider = strPtr(a.Provider) + out.Model = strPtr(a.Model) + out.SessionId = strPtr(a.SessionID) + out.ResolvedProviderId = strPtr(a.ResolvedProviderID) + out.SelectedPolicyId = strPtr(a.SelectedPolicyID) + out.Decision = strPtr(a.Decision) + out.DenyReason = strPtr(a.DenyReason) + out.RequestPrompt = strPtr(a.RequestPrompt) + out.ResponseCompletion = strPtr(a.ResponseCompletion) + + if len(a.GroupIDs) > 0 { + groups := a.GroupIDs + out.GroupIds = &groups + } + return out +} + +// strPtr returns a pointer to s, or nil when s is empty — so empty optional +// fields are omitted from the JSON rather than serialised as "". +func strPtr(s string) *string { + if s == "" { + return nil + } + return &s +} + +// AgentNetworkAccessLogSession is a session-grouped view of access-log entries: +// all requests sharing a session id (or, for a request the client sent no +// session id for, that single request keyed by its own row id) folded into one +// summary plus its ordered entries. Assembled in Go from a page of entries — it +// is not a stored table. +type AgentNetworkAccessLogSession struct { + SessionID string // empty for a session-less (singleton) request + UserID string + GroupIDs []string // union of the entries' authorising groups + StartedAt time.Time + EndedAt time.Time + RequestCount int + InputTokens int64 + OutputTokens int64 + TotalTokens int64 + CostUSD float64 + Providers []string // distinct vendors seen in the session + Models []string // distinct models seen in the session + Decision string // "deny" if any entry was denied, else "allow" + Entries []*AgentNetworkAccessLog +} + +// sessionKey is the grouping key for an entry: its session id, or — when the +// client sent none — its own row id, so session-less requests each form their +// own singleton group. Must match the SQL group key +// COALESCE(NULLIF(session_id, ”), id). +func sessionKey(e *AgentNetworkAccessLog) string { + if e.SessionID != "" { + return e.SessionID + } + return e.ID +} + +// FoldAccessLogSessions folds a page of entries into per-session summaries, +// preserving the order of orderedKeys (the already-sorted, already-paginated +// session keys from the store). Entries are expected pre-sorted by timestamp +// within each key. Aggregation (sums, distinct providers/models, deny rollup) +// happens here in Go rather than in SQL so the query stays engine-portable. +func FoldAccessLogSessions(orderedKeys []string, entries []*AgentNetworkAccessLog) []*AgentNetworkAccessLogSession { + byKey := make(map[string]*AgentNetworkAccessLogSession, len(orderedKeys)) + order := make([]*AgentNetworkAccessLogSession, 0, len(orderedKeys)) + for _, k := range orderedKeys { + if _, ok := byKey[k]; ok { + continue + } + sess := &AgentNetworkAccessLogSession{Decision: "allow"} + byKey[k] = sess + order = append(order, sess) + } + + seenBy := make(map[string]*sessionSeen, len(orderedKeys)) + + for _, e := range entries { + k := sessionKey(e) + sess, ok := byKey[k] + if !ok { + continue // entry outside the paged set; defensive + } + sk := seenBy[k] + if sk == nil { + sk = newSessionSeen() + seenBy[k] = sk + sess.SessionID = e.SessionID + sess.UserID = e.UserID + sess.StartedAt = e.Timestamp + sess.EndedAt = e.Timestamp + } + sess.foldEntry(sk, e) + } + + out := make([]*AgentNetworkAccessLogSession, 0, len(order)) + for _, sess := range order { + if sess.RequestCount > 0 { + out = append(out, sess) + } + } + return out +} + +// sessionSeen tracks the distinct provider / model / group values already +// recorded for a session so foldEntry can dedupe as it accumulates. +type sessionSeen struct{ providers, models, groups map[string]struct{} } + +func newSessionSeen() *sessionSeen { + return &sessionSeen{ + providers: map[string]struct{}{}, + models: map[string]struct{}{}, + groups: map[string]struct{}{}, + } +} + +// foldEntry accumulates a single entry into the session summary: sums, time +// bounds, first-seen user, deny rollup, distinct provider / model / group +// lists, and the entry itself. +func (sess *AgentNetworkAccessLogSession) foldEntry(sk *sessionSeen, e *AgentNetworkAccessLog) { + sess.RequestCount++ + sess.InputTokens += e.InputTokens + sess.OutputTokens += e.OutputTokens + sess.TotalTokens += e.TotalTokens + sess.CostUSD += e.CostUSD + if e.Timestamp.Before(sess.StartedAt) { + sess.StartedAt = e.Timestamp + } + if e.Timestamp.After(sess.EndedAt) { + sess.EndedAt = e.Timestamp + } + if sess.UserID == "" { + sess.UserID = e.UserID + } + if e.Decision == "deny" { + sess.Decision = "deny" + } + sess.Providers = appendDistinct(sk.providers, sess.Providers, e.Provider) + sess.Models = appendDistinct(sk.models, sess.Models, e.Model) + for _, g := range e.GroupIDs { + sess.GroupIDs = appendDistinct(sk.groups, sess.GroupIDs, g) + } + sess.Entries = append(sess.Entries, e) +} + +// appendDistinct appends v to list when v is non-empty and not already recorded +// in seen, returning the possibly-extended list. +func appendDistinct(seen map[string]struct{}, list []string, v string) []string { + if v == "" { + return list + } + if _, dup := seen[v]; dup { + return list + } + seen[v] = struct{}{} + return append(list, v) +} + +// ToAPIResponse renders the session summary (and its entries) as the API +// representation. +func (sess *AgentNetworkAccessLogSession) ToAPIResponse() api.AgentNetworkAccessLogSession { + entries := make([]api.AgentNetworkAccessLog, 0, len(sess.Entries)) + for _, e := range sess.Entries { + entries = append(entries, e.ToAPIResponse()) + } + + out := api.AgentNetworkAccessLogSession{ + StartedAt: sess.StartedAt, + EndedAt: sess.EndedAt, + RequestCount: sess.RequestCount, + InputTokens: sess.InputTokens, + OutputTokens: sess.OutputTokens, + TotalTokens: sess.TotalTokens, + CostUsd: sess.CostUSD, + Decision: sess.Decision, + Entries: entries, + } + out.SessionId = strPtr(sess.SessionID) + out.UserId = strPtr(sess.UserID) + if len(sess.Providers) > 0 { + providers := sess.Providers + out.Providers = &providers + } + if len(sess.Models) > 0 { + models := sess.Models + out.Models = &models + } + if len(sess.GroupIDs) > 0 { + groups := sess.GroupIDs + out.GroupIds = &groups + } + return out +} + +// AgentNetworkAccessLogGroup is the normalised many-to-many row linking a log +// entry to one authorising group, so the access-log endpoint can filter by +// group with a simple `group_id IN (...)` join instead of substring-matching a +// CSV column. +type AgentNetworkAccessLogGroup struct { + LogID string `gorm:"primaryKey"` + GroupID string `gorm:"primaryKey;index"` + AccountID string `gorm:"index"` +} + +// TableName names the access-log group child table. +func (AgentNetworkAccessLogGroup) TableName() string { return "agent_network_access_log_group" } diff --git a/management/internals/modules/agentnetwork/types/accesslogfilter.go b/management/internals/modules/agentnetwork/types/accesslogfilter.go new file mode 100644 index 000000000..d571a87b6 --- /dev/null +++ b/management/internals/modules/agentnetwork/types/accesslogfilter.go @@ -0,0 +1,249 @@ +package types + +import ( + "math" + "net/http" + "strconv" + "strings" + "time" + + "github.com/netbirdio/netbird/shared/management/status" +) + +const ( + // AccessLogDefaultPageSize is the default number of records per page. + AccessLogDefaultPageSize = 50 + // AccessLogMaxPageSize is the maximum number of records allowed per page. + AccessLogMaxPageSize = 100 + + accessLogDefaultSortBy = "timestamp" + accessLogDefaultSortOrder = "desc" + + // usageOverviewDefaultLookback bounds an unbounded usage-overview query so + // it never aggregates an account's entire history into memory. + usageOverviewDefaultLookback = 90 * 24 * time.Hour + // usageOverviewMaxRange caps how far back an explicit range may reach. + usageOverviewMaxRange = 366 * 24 * time.Hour +) + +// ApplyUsageOverviewBounds bounds a missing or over-wide date range so the +// in-memory usage aggregation can't load an account's full usage history. An +// absent range defaults to the last usageOverviewDefaultLookback; a range wider +// than usageOverviewMaxRange is clamped from the (possibly defaulted) end. +func (f *AgentNetworkAccessLogFilter) ApplyUsageOverviewBounds(now time.Time) { + end := now + if f.EndDate != nil { + end = *f.EndDate + } + f.EndDate = &end + if f.StartDate == nil { + start := end.Add(-usageOverviewDefaultLookback) + f.StartDate = &start + return + } + if end.Sub(*f.StartDate) > usageOverviewMaxRange { + start := end.Add(-usageOverviewMaxRange) + f.StartDate = &start + } +} + +// accessLogSortFields maps the API sort_by values to their database columns. +var accessLogSortFields = map[string]string{ + "timestamp": "timestamp", + "model": "model", + "provider": "provider", + "status_code": "status_code", + "duration": "duration", + "cost_usd": "cost_usd", + "total_tokens": "total_tokens", + "user_id": "user_id", + "decision": "decision", +} + +// sessionSortExprs maps the API sort_by values to the aggregate expression a +// session-grouped query sorts on. A session has no single row, so per-row +// columns become aggregates: "timestamp" (the default) is the session's last +// activity, "started_at" its first. Every expression is a plain SQL aggregate +// over the GROUP BY, so the ordering stays portable across SQLite and Postgres. +// Keys absent here (e.g. "model", "provider") fall back to the default — the +// grouped UI only offers the session-level sorts below. +var sessionSortExprs = map[string]string{ //nolint:gosec // G101 false positive: "total_tokens" sort key, not a credential + "timestamp": "MAX(timestamp)", + "started_at": "MIN(timestamp)", + "cost_usd": "SUM(cost_usd)", + "total_tokens": "SUM(total_tokens)", + "duration": "SUM(duration)", + "request_count": "COUNT(*)", + "status_code": "MAX(status_code)", + "user_id": "MIN(user_id)", + "decision": "MAX(decision)", // "deny" > "allow": DESC surfaces denied sessions first +} + +// AgentNetworkAccessLogFilter holds pagination, filtering and sorting +// parameters for the agent-network access-log listing. Group / provider / +// model are multi-valued (the UI uses multi-select; an entry matches when it +// matches any selected value). +type AgentNetworkAccessLogFilter struct { + Page int + PageSize int + + SortBy string + SortOrder string + + Search *string // log id, host, path, model, user email/name + UserID *string // exact user id (the dashboard sends the picked user's id) + SessionID *string // exact session id — groups one conversation / coding session + GroupIDs []string // authorising group ids (match any) + ProviderIDs []string // resolved provider ids (match any) + Models []string // models (match any) + Decision *string // policy decision (allow/deny) + PathPrefix *string // request path prefix (path LIKE 'prefix%') + StartDate *time.Time // timestamp >= start_date + EndDate *time.Time // timestamp <= end_date +} + +// ParseFromRequest fills the filter from the request query parameters. It +// returns a validation error when a supplied start_date / end_date is present +// but not valid RFC3339: silently dropping a malformed date would broaden the +// query (and, for the usage overview, fall back to the default window). +func (f *AgentNetworkAccessLogFilter) ParseFromRequest(r *http.Request) error { + q := r.URL.Query() + + f.Page = parseAccessLogPositiveInt(q.Get("page"), 1) + f.PageSize = min(parseAccessLogPositiveInt(q.Get("page_size"), AccessLogDefaultPageSize), AccessLogMaxPageSize) + + f.SortBy = parseAccessLogSortField(q.Get("sort_by")) + f.SortOrder = parseAccessLogSortOrder(q.Get("sort_order")) + + f.Search = parseAccessLogOptionalString(q.Get("search")) + f.UserID = parseAccessLogOptionalString(q.Get("user_id")) + f.SessionID = parseAccessLogOptionalString(q.Get("session_id")) + f.Decision = parseAccessLogOptionalString(q.Get("decision")) + f.PathPrefix = parseAccessLogOptionalString(q.Get("path")) + // Multi-value filters accept either repeated params (?group_id=a&group_id=b) + // or a single comma-separated value (?group_id=a,b) so both the OpenAPI + // array form and the dashboard's single-value query builder work. + f.GroupIDs = splitMultiValue(q["group_id"]) + f.ProviderIDs = splitMultiValue(q["provider_id"]) + f.Models = splitMultiValue(q["model"]) + + var err error + if f.StartDate, err = parseAccessLogOptionalRFC3339(q.Get("start_date")); err != nil { + return status.Errorf(status.InvalidArgument, "invalid start_date: %v", err) + } + if f.EndDate, err = parseAccessLogOptionalRFC3339(q.Get("end_date")); err != nil { + return status.Errorf(status.InvalidArgument, "invalid end_date: %v", err) + } + return nil +} + +// GetSortColumn returns the database column for the active sort field. +func (f *AgentNetworkAccessLogFilter) GetSortColumn() string { + if col, ok := accessLogSortFields[f.SortBy]; ok { + return col + } + return accessLogSortFields[accessLogDefaultSortBy] +} + +// GetSessionSortExpr returns the aggregate ORDER BY expression for the active +// sort field when listing session-grouped logs. Unknown / non-session sort +// fields fall back to the default (last activity). +func (f *AgentNetworkAccessLogFilter) GetSessionSortExpr() string { + if expr, ok := sessionSortExprs[f.SortBy]; ok { + return expr + } + return sessionSortExprs[accessLogDefaultSortBy] +} + +// GetSortOrder returns the normalised sort order ("ASC"/"DESC"). +func (f *AgentNetworkAccessLogFilter) GetSortOrder() string { + if strings.EqualFold(f.SortOrder, "asc") { + return "ASC" + } + return "DESC" +} + +// GetLimit returns the page size, defaulting/clamping when unset. +func (f *AgentNetworkAccessLogFilter) GetLimit() int { + if f.PageSize <= 0 { + return AccessLogDefaultPageSize + } + return min(f.PageSize, AccessLogMaxPageSize) +} + +// GetOffset returns the zero-based row offset for the active page. Page is +// user-controlled, so the multiplication is guarded against int overflow. +func (f *AgentNetworkAccessLogFilter) GetOffset() int { + limit := f.GetLimit() + if f.Page <= 1 || limit <= 0 { + return 0 + } + if f.Page-1 > math.MaxInt/limit { + return math.MaxInt - (math.MaxInt % limit) + } + return (f.Page - 1) * limit +} + +func parseAccessLogPositiveInt(s string, def int) int { + if v, err := strconv.Atoi(strings.TrimSpace(s)); err == nil && v > 0 { + return v + } + return def +} + +func parseAccessLogSortField(s string) string { + if _, ok := accessLogSortFields[s]; ok { + return s + } + // Session-grouped listings sort on aggregates (e.g. request_count, + // started_at) that aren't flat-row columns; accept those too. The flat + // listing maps any unknown field back to the default, so this stays safe + // for the non-grouped endpoint. + if _, ok := sessionSortExprs[s]; ok { + return s + } + return accessLogDefaultSortBy +} + +func parseAccessLogSortOrder(s string) string { + if strings.EqualFold(s, "asc") { + return "asc" + } + return accessLogDefaultSortOrder +} + +func parseAccessLogOptionalString(s string) *string { + if s = strings.TrimSpace(s); s != "" { + return &s + } + return nil +} + +func parseAccessLogOptionalRFC3339(s string) (*time.Time, error) { + if s = strings.TrimSpace(s); s == "" { + return nil, nil //nolint:nilnil // not provided: no value and no error + } + t, err := time.Parse(time.RFC3339, s) + if err != nil { + return nil, err + } + return &t, nil +} + +// splitMultiValue flattens repeated query params and comma-separated values +// into a single trimmed, blank-free list. Returns nil when nothing remains so +// callers can skip the filter entirely. +func splitMultiValue(values []string) []string { + out := make([]string, 0, len(values)) + for _, raw := range values { + for _, v := range strings.Split(raw, ",") { + if v = strings.TrimSpace(v); v != "" { + out = append(out, v) + } + } + } + if len(out) == 0 { + return nil + } + return out +} diff --git a/management/internals/modules/agentnetwork/types/budgetrule.go b/management/internals/modules/agentnetwork/types/budgetrule.go new file mode 100644 index 000000000..6f02a5b92 --- /dev/null +++ b/management/internals/modules/agentnetwork/types/budgetrule.go @@ -0,0 +1,106 @@ +package types + +import ( + "time" + + "github.com/rs/xid" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// AccountBudgetRule is an account-level, limit-only rule bound to groups +// and/or users. It mirrors the policy budget experience without any routing: +// it carries the same cap shape as a policy (PolicyLimits) but never selects a +// provider. Rules apply across policies as an always-on ceiling — every +// applicable rule binds (min-wins), so a rule can only tighten a caller's +// effective limit, never loosen it. +// +// TargetGroups matches when it intersects the caller's groups; TargetUsers +// binds a specific user directly. Empty TargetGroups and TargetUsers means the +// rule applies to every caller (the account-wide default). +type AccountBudgetRule struct { + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index"` + Name string + Enabled bool + TargetGroups []string `gorm:"serializer:json;column:target_groups"` + TargetUsers []string `gorm:"serializer:json;column:target_users"` + Limits PolicyLimits `gorm:"serializer:json;column:limits"` + + CreatedAt time.Time + UpdatedAt time.Time +} + +// TableName puts budget rules in their own table. +func (AccountBudgetRule) TableName() string { return "agent_network_budget_rules" } + +// NewAccountBudgetRule returns a new rule with a freshly minted ID. +func NewAccountBudgetRule(accountID string) *AccountBudgetRule { + now := time.Now().UTC() + return &AccountBudgetRule{ + ID: "ainbud_" + xid.New().String(), + AccountID: accountID, + Enabled: true, + CreatedAt: now, + UpdatedAt: now, + } +} + +// Copy returns a deep copy of the rule, including its target slices. +func (r *AccountBudgetRule) Copy() *AccountBudgetRule { + c := *r + c.TargetGroups = append([]string(nil), r.TargetGroups...) + c.TargetUsers = append([]string(nil), r.TargetUsers...) + return &c +} + +// EventMeta renders the rule for the activity log. +func (r *AccountBudgetRule) EventMeta() map[string]any { + return map[string]any{ + "name": r.Name, + "enabled": r.Enabled, + } +} + +// FromAPIRequest applies the request payload onto the receiver. +func (r *AccountBudgetRule) FromAPIRequest(req *api.AgentNetworkBudgetRuleRequest) { + r.Name = req.Name + if req.Enabled != nil { + r.Enabled = *req.Enabled + } + if req.TargetGroups != nil { + r.TargetGroups = append([]string(nil), (*req.TargetGroups)...) + } else { + r.TargetGroups = []string{} + } + if req.TargetUsers != nil { + r.TargetUsers = append([]string(nil), (*req.TargetUsers)...) + } else { + r.TargetUsers = []string{} + } + r.Limits = limitsFromAPI(req.Limits) +} + +// ToAPIResponse renders the rule as the API representation. +func (r *AccountBudgetRule) ToAPIResponse() *api.AgentNetworkBudgetRule { + groups := r.TargetGroups + if groups == nil { + groups = []string{} + } + users := r.TargetUsers + if users == nil { + users = []string{} + } + created := r.CreatedAt + updated := r.UpdatedAt + return &api.AgentNetworkBudgetRule{ + Id: r.ID, + Name: r.Name, + Enabled: r.Enabled, + TargetGroups: groups, + TargetUsers: users, + Limits: limitsToAPI(r.Limits), + CreatedAt: &created, + UpdatedAt: &updated, + } +} diff --git a/management/internals/modules/agentnetwork/types/consumption.go b/management/internals/modules/agentnetwork/types/consumption.go new file mode 100644 index 000000000..5295b5570 --- /dev/null +++ b/management/internals/modules/agentnetwork/types/consumption.go @@ -0,0 +1,69 @@ +package types + +import "time" + +// ConsumptionDimension classifies which kind of identity a consumption +// row counts against. The proxy-side enforcement layer ticks one row +// per dimension per request — typically one user row plus one group +// row. +type ConsumptionDimension string + +const ( + // DimensionUser counts tokens / spend for a single end user. The + // dim_id column carries the netbird user id (or peer.ID when the + // caller is a tunnel-peer principal). + DimensionUser ConsumptionDimension = "user" + // DimensionGroup counts tokens / spend for a single source group + // across every member of that group. The dim_id column carries + // the netbird group id. + DimensionGroup ConsumptionDimension = "group" +) + +// Consumption is a per-dimension token + USD counter for a fixed +// aligned window. The (account, dim_kind, dim_id, window_seconds, +// window_start) tuple is the primary key; rows are rolled forward by +// the proxy's post-flight RecordLLMUsage path on every request. +// +// The same dim_id (e.g. a group id) gets one row per distinct +// window_seconds length in scope across the account's policies, +// because two policies with different window lengths read independent +// counters even though they share the dimension. Two policies with +// identical window_seconds on the same dimension share one counter +// (correct: their caps are checked against the same shared bucket). +type Consumption struct { + AccountID string `gorm:"primaryKey;type:varchar(255)"` + DimensionKind ConsumptionDimension `gorm:"primaryKey;type:varchar(16);column:dim_kind"` + DimensionID string `gorm:"primaryKey;type:varchar(255);column:dim_id"` + WindowSeconds int64 `gorm:"primaryKey;column:window_seconds"` + WindowStartUTC time.Time `gorm:"primaryKey;column:window_start_utc"` + TokensInput int64 `gorm:"column:tokens_input"` + TokensOutput int64 `gorm:"column:tokens_output"` + CostUSD float64 `gorm:"column:cost_usd"` + UpdatedAt time.Time +} + +// TableName forces a stable name independent of GORM's pluraliser. +func (Consumption) TableName() string { return "agent_network_consumption" } + +// ConsumptionKey identifies a single consumption counter within an account: +// the (dim_kind, dim_id, window_seconds, window_start) part of the row's +// primary key. Used to batch-read and batch-increment many counters for one +// request in a single store round-trip / transaction. +type ConsumptionKey struct { + Kind ConsumptionDimension + DimID string + WindowSeconds int64 + WindowStartUTC time.Time +} + +// WindowStart returns the aligned UTC start of the window of length +// windowSeconds that contains t. Aligned to the unix epoch so the +// same bucket boundary is computed deterministically across processes. +func WindowStart(t time.Time, windowSeconds int64) time.Time { + if windowSeconds <= 0 { + return t.UTC() + } + step := windowSeconds * int64(time.Second) + bucketed := t.UTC().UnixNano() / step * step + return time.Unix(0, bucketed).UTC() +} diff --git a/management/internals/modules/agentnetwork/types/consumption_test.go b/management/internals/modules/agentnetwork/types/consumption_test.go new file mode 100644 index 000000000..596cc158c --- /dev/null +++ b/management/internals/modules/agentnetwork/types/consumption_test.go @@ -0,0 +1,141 @@ +package types + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// TestWindowStart_AlignedToUnixEpoch is the multi-node-convergence +// guarantee: any two proxies computing WindowStart(now, s) for the +// same s must land on the same boundary. The implementation aligns +// to the unix epoch (UTC) rather than local time, calendar weeks, or +// process start time — none of which are shared across nodes. +// +// Table covers the load-bearing window lengths (5m, 1h, 24h, 30d) +// plus a few odd values that still need to align cleanly. +func TestWindowStart_AlignedToUnixEpoch(t *testing.T) { + cases := []struct { + name string + instant time.Time + windowSeconds int64 + want time.Time + }{ + { + name: "5m window — drops seconds inside the bucket", + instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.UTC), + windowSeconds: 300, + want: time.Date(2026, 5, 6, 13, 45, 0, 0, time.UTC), + }, + { + name: "1h window — drops minutes / seconds, keeps the hour", + instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.UTC), + windowSeconds: 3600, + want: time.Date(2026, 5, 6, 13, 0, 0, 0, time.UTC), + }, + { + name: "24h window aligns to UTC midnight", + instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.UTC), + windowSeconds: 86_400, + want: time.Date(2026, 5, 6, 0, 0, 0, 0, time.UTC), + }, + { + name: "30d (2_592_000s) window aligns to the 30d epoch grid, not month boundaries", + instant: time.Date(2026, 5, 6, 0, 0, 0, 0, time.UTC), + windowSeconds: 2_592_000, + // 2026-05-06 UTC = 1778025600s; 1778025600 / 2592000 = 685 + // 685 * 2592000 = 1775520000s = 2026-04-07 00:00:00 UTC + want: time.Date(2026, 4, 7, 0, 0, 0, 0, time.UTC), + }, + { + name: "non-UTC input still anchors on UTC epoch boundaries", + instant: time.Date(2026, 5, 6, 13, 47, 23, 0, time.FixedZone("CEST", 2*3600)), + windowSeconds: 86_400, + // 2026-05-06 13:47:23 CEST = 11:47:23 UTC → bucket 2026-05-06 00:00:00 UTC + want: time.Date(2026, 5, 6, 0, 0, 0, 0, time.UTC), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := WindowStart(tc.instant, tc.windowSeconds) + assert.True(t, got.Equal(tc.want), + "WindowStart(%v, %ds) = %v, want %v", tc.instant, tc.windowSeconds, got, tc.want) + }) + } +} + +// TestWindowStart_WithinWindowConverges proves the determinism +// contract: any two timestamps inside the same window land on the +// exact same boundary. Two proxy nodes serving requests 7s apart +// must agree on which counter row to upsert. +func TestWindowStart_WithinWindowConverges(t *testing.T) { + t1 := time.Date(2026, 5, 6, 14, 0, 0, 0, time.UTC) + t2 := t1.Add(7 * time.Second) + t3 := t1.Add(59*time.Minute + 59*time.Second) + + a := WindowStart(t1, 3600) + b := WindowStart(t2, 3600) + c := WindowStart(t3, 3600) + + assert.True(t, a.Equal(b), "two timestamps 7s apart in the same 1h window must align to the same boundary") + assert.True(t, a.Equal(c), "the very last second of a 1h window still lands on the SAME bucket as the first second") +} + +// TestWindowStart_AcrossWindowsDiverges is the symmetric guarantee: +// two timestamps separated by a window's worth of time MUST land on +// different boundaries. Without this, a 24h window's "rollover" +// would never reset the counter. +func TestWindowStart_AcrossWindowsDiverges(t *testing.T) { + t1 := time.Date(2026, 5, 6, 23, 59, 59, 0, time.UTC) + t2 := t1.Add(2 * time.Second) // 2026-05-07 00:00:01 + + a := WindowStart(t1, 86_400) + b := WindowStart(t2, 86_400) + assert.False(t, a.Equal(b), + "timestamps straddling a 24h-window boundary must land on different buckets — otherwise daily caps never reset") +} + +// TestWindowStart_DifferentWindowsHaveDifferentBuckets locks the +// design fork "two policies with different window_seconds on the same +// group produce independent counters". A 24h boundary at noon is NOT +// the same as the 30d boundary that contains it. +func TestWindowStart_DifferentWindowsHaveDifferentBuckets(t *testing.T) { + now := time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC) + short := WindowStart(now, 86_400) + long := WindowStart(now, 2_592_000) + assert.False(t, short.Equal(long), + "the 24h bucket and 30d bucket containing the same instant must differ — independent counters require independent keys") +} + +// TestWindowStart_SubMinuteAndMinuteAlignment locks sub-hour windows. +// A 5-minute window must align to multiples of 300s from the unix +// epoch — minute marks 0/5/10/.../55 within an hour, deterministic +// across nodes regardless of clock drift. +func TestWindowStart_SubMinuteAndMinuteAlignment(t *testing.T) { + t1 := time.Date(2026, 5, 6, 14, 12, 30, 0, time.UTC) + t2 := time.Date(2026, 5, 6, 14, 14, 59, 0, time.UTC) + t3 := time.Date(2026, 5, 6, 14, 15, 0, 0, time.UTC) + + a := WindowStart(t1, 300) + b := WindowStart(t2, 300) + c := WindowStart(t3, 300) + + assert.True(t, a.Equal(b), + "14:12:30 and 14:14:59 fall in the same 5m bucket starting at 14:10:00") + assert.True(t, a.Equal(time.Date(2026, 5, 6, 14, 10, 0, 0, time.UTC)), + "5m bucket containing 14:12 starts at 14:10 — aligned to multiples of 300s from unix epoch") + assert.False(t, a.Equal(c), + "14:15:00 is the start of the next 5m bucket — must not fold into the previous one") +} + +// TestWindowStart_ZeroWindowReturnsInputUTC covers the defensive +// path: caller hands a zero / negative window (shouldn't happen, but +// might mid-refactor). The function returns the input as UTC rather +// than dividing by zero. +func TestWindowStart_ZeroWindowReturnsInputUTC(t *testing.T) { + now := time.Date(2026, 5, 6, 12, 30, 45, 0, time.FixedZone("CEST", 2*3600)) + got := WindowStart(now, 0) + assert.True(t, got.Equal(now.UTC()), "zero window must not panic — return input as UTC") +} diff --git a/management/internals/modules/agentnetwork/types/guardrail.go b/management/internals/modules/agentnetwork/types/guardrail.go new file mode 100644 index 000000000..12edd815c --- /dev/null +++ b/management/internals/modules/agentnetwork/types/guardrail.go @@ -0,0 +1,120 @@ +package types + +import ( + "time" + + "github.com/rs/xid" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// GuardrailChecks is the configurable parameter set persisted with each +// guardrail. Stored as a JSON blob to keep the table flat. +type GuardrailChecks struct { + ModelAllowlist GuardrailModelAllowlist `json:"model_allowlist"` + PromptCapture GuardrailPromptCapture `json:"prompt_capture"` +} + +type GuardrailModelAllowlist struct { + Enabled bool `json:"enabled"` + Models []string `json:"models"` +} + +type GuardrailPromptCapture struct { + Enabled bool `json:"enabled"` + RedactPii bool `json:"redact_pii"` +} + +// Guardrail is an Agent Network reusable guardrail set persisted per account. +type Guardrail struct { + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index"` + Name string + Description string + Checks GuardrailChecks `gorm:"serializer:json"` + CreatedAt time.Time + UpdatedAt time.Time +} + +// TableName uses an explicit name so guardrail rows live in their own +// table. +func (Guardrail) TableName() string { return "agent_network_guardrails" } + +// NewGuardrail returns a new Guardrail with a freshly minted ID. +func NewGuardrail(accountID string) *Guardrail { + now := time.Now().UTC() + return &Guardrail{ + ID: "ainguard_" + xid.New().String(), + AccountID: accountID, + Checks: GuardrailChecks{ModelAllowlist: GuardrailModelAllowlist{Models: []string{}}}, + CreatedAt: now, + UpdatedAt: now, + } +} + +// FromAPIRequest applies the request payload onto the receiver. +func (g *Guardrail) FromAPIRequest(req *api.AgentNetworkGuardrailRequest) { + g.Name = req.Name + if req.Description != nil { + g.Description = *req.Description + } + g.Checks = checksFromAPI(req.Checks) +} + +// ToAPIResponse renders the guardrail as the API representation. +func (g *Guardrail) ToAPIResponse() *api.AgentNetworkGuardrail { + created := g.CreatedAt + updated := g.UpdatedAt + return &api.AgentNetworkGuardrail{ + Id: g.ID, + Name: g.Name, + Description: g.Description, + Checks: checksToAPI(g.Checks), + CreatedAt: &created, + UpdatedAt: &updated, + } +} + +// Copy returns a deep copy of the guardrail. +func (g *Guardrail) Copy() *Guardrail { + clone := *g + if g.Checks.ModelAllowlist.Models != nil { + clone.Checks.ModelAllowlist.Models = append([]string(nil), g.Checks.ModelAllowlist.Models...) + } + return &clone +} + +// EventMeta is the audit-log payload for activity events. +func (g *Guardrail) EventMeta() map[string]any { + return map[string]any{"name": g.Name} +} + +func checksFromAPI(c api.AgentNetworkGuardrailChecks) GuardrailChecks { + models := append([]string(nil), c.ModelAllowlist.Models...) + if models == nil { + models = []string{} + } + return GuardrailChecks{ + ModelAllowlist: GuardrailModelAllowlist{ + Enabled: c.ModelAllowlist.Enabled, + Models: models, + }, + PromptCapture: GuardrailPromptCapture{ + Enabled: c.PromptCapture.Enabled, + RedactPii: c.PromptCapture.RedactPii, + }, + } +} + +func checksToAPI(c GuardrailChecks) api.AgentNetworkGuardrailChecks { + models := c.ModelAllowlist.Models + if models == nil { + models = []string{} + } + out := api.AgentNetworkGuardrailChecks{} + out.ModelAllowlist.Enabled = c.ModelAllowlist.Enabled + out.ModelAllowlist.Models = models + out.PromptCapture.Enabled = c.PromptCapture.Enabled + out.PromptCapture.RedactPii = c.PromptCapture.RedactPii + return out +} diff --git a/management/internals/modules/agentnetwork/types/policy.go b/management/internals/modules/agentnetwork/types/policy.go new file mode 100644 index 000000000..709b8149d --- /dev/null +++ b/management/internals/modules/agentnetwork/types/policy.go @@ -0,0 +1,192 @@ +package types + +import ( + "time" + + "github.com/rs/xid" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// Policy is an Agent Network policy persisted per account. A policy +// authorises members of SourceGroups to reach the listed +// DestinationProviderIDs under the attached GuardrailIDs and Limits. +// +// Token and budget limits live on the Policy itself (Limits field); +// guardrails carry only model allowlist and prompt capture. +type Policy struct { + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index"` + Name string + Description string + Enabled bool + SourceGroups []string `gorm:"serializer:json;column:source_groups"` + DestinationProviderIDs []string `gorm:"serializer:json;column:destination_provider_ids"` + GuardrailIDs []string `gorm:"serializer:json;column:guardrail_ids"` + Limits PolicyLimits `gorm:"serializer:json;column:limits"` + + CreatedAt time.Time + UpdatedAt time.Time +} + +// PolicyLimits aggregates the token and budget caps attached directly +// to a policy. Both halves are always present; their Enabled flags +// control whether the proxy enforces them. +type PolicyLimits struct { + TokenLimit PolicyTokenLimit `json:"token_limit"` + BudgetLimit PolicyBudgetLimit `json:"budget_limit"` +} + +// PolicyTokenLimit is a token-count cap evaluated over an aligned +// window of WindowSeconds seconds. GroupCap is applied to each +// source group independently — every group in the policy's +// SourceGroups gets its own bucket of GroupCap tokens. UserCap +// applies independently to each individual user. A zero cap means +// uncapped. WindowSeconds must be at least 60 (one minute) when the +// limit is enabled. +type PolicyTokenLimit struct { + Enabled bool `json:"enabled"` + GroupCap int64 `json:"group_cap"` + UserCap int64 `json:"user_cap"` + WindowSeconds int64 `json:"window_seconds"` +} + +// PolicyBudgetLimit is a USD spend cap evaluated over an aligned +// window of WindowSeconds seconds. GroupCapUsd is applied to each +// source group independently — every group in the policy's +// SourceGroups gets its own bucket of GroupCapUsd USD. UserCapUsd +// applies independently to each individual user. A zero cap means +// uncapped. WindowSeconds must be at least 60 (one minute) when the +// limit is enabled. +type PolicyBudgetLimit struct { + Enabled bool `json:"enabled"` + GroupCapUsd float64 `json:"group_cap_usd"` + UserCapUsd float64 `json:"user_cap_usd"` + WindowSeconds int64 `json:"window_seconds"` +} + +// TableName forces a unique GORM table to avoid collision with the access +// control Policy type, which also resolves to "policies" by default. +func (Policy) TableName() string { return "agent_network_policies" } + +// NewPolicy returns a new Policy with a freshly minted ID. +func NewPolicy(accountID string) *Policy { + now := time.Now().UTC() + return &Policy{ + ID: "ainpol_" + xid.New().String(), + AccountID: accountID, + Enabled: true, + CreatedAt: now, + UpdatedAt: now, + } +} + +// FromAPIRequest applies the request payload onto the receiver. +func (p *Policy) FromAPIRequest(req *api.AgentNetworkPolicyRequest) { + p.Name = req.Name + if req.Description != nil { + p.Description = *req.Description + } + if req.Enabled != nil { + p.Enabled = *req.Enabled + } + p.SourceGroups = append([]string(nil), req.SourceGroups...) + p.DestinationProviderIDs = append([]string(nil), req.DestinationProviderIds...) + if req.GuardrailIds != nil { + p.GuardrailIDs = append([]string(nil), (*req.GuardrailIds)...) + } else { + p.GuardrailIDs = []string{} + } + if req.Limits != nil { + p.Limits = limitsFromAPI(*req.Limits) + } else { + p.Limits = PolicyLimits{} + } +} + +// ToAPIResponse renders the policy as the API representation. +func (p *Policy) ToAPIResponse() *api.AgentNetworkPolicy { + src := p.SourceGroups + if src == nil { + src = []string{} + } + dst := p.DestinationProviderIDs + if dst == nil { + dst = []string{} + } + guardrails := p.GuardrailIDs + if guardrails == nil { + guardrails = []string{} + } + created := p.CreatedAt + updated := p.UpdatedAt + return &api.AgentNetworkPolicy{ + Id: p.ID, + Name: p.Name, + Description: p.Description, + Enabled: p.Enabled, + SourceGroups: src, + DestinationProviderIds: dst, + GuardrailIds: guardrails, + Limits: limitsToAPI(p.Limits), + CreatedAt: &created, + UpdatedAt: &updated, + } +} + +// Copy returns a deep copy of the policy. +func (p *Policy) Copy() *Policy { + clone := *p + if p.SourceGroups != nil { + clone.SourceGroups = append([]string(nil), p.SourceGroups...) + } + if p.DestinationProviderIDs != nil { + clone.DestinationProviderIDs = append([]string(nil), p.DestinationProviderIDs...) + } + if p.GuardrailIDs != nil { + clone.GuardrailIDs = append([]string(nil), p.GuardrailIDs...) + } + return &clone +} + +// EventMeta is the audit-log payload for activity events. +func (p *Policy) EventMeta() map[string]any { + return map[string]any{ + "name": p.Name, + "enabled": p.Enabled, + } +} + +func limitsFromAPI(in api.AgentNetworkPolicyLimits) PolicyLimits { + return PolicyLimits{ + TokenLimit: PolicyTokenLimit{ + Enabled: in.TokenLimit.Enabled, + GroupCap: in.TokenLimit.GroupCap, + UserCap: in.TokenLimit.UserCap, + WindowSeconds: in.TokenLimit.WindowSeconds, + }, + BudgetLimit: PolicyBudgetLimit{ + Enabled: in.BudgetLimit.Enabled, + GroupCapUsd: in.BudgetLimit.GroupCapUsd, + UserCapUsd: in.BudgetLimit.UserCapUsd, + WindowSeconds: in.BudgetLimit.WindowSeconds, + }, + } +} + +func limitsToAPI(in PolicyLimits) api.AgentNetworkPolicyLimits { + return api.AgentNetworkPolicyLimits{ + TokenLimit: api.AgentNetworkPolicyTokenLimit{ + Enabled: in.TokenLimit.Enabled, + GroupCap: in.TokenLimit.GroupCap, + UserCap: in.TokenLimit.UserCap, + WindowSeconds: in.TokenLimit.WindowSeconds, + }, + BudgetLimit: api.AgentNetworkPolicyBudgetLimit{ + Enabled: in.BudgetLimit.Enabled, + GroupCapUsd: in.BudgetLimit.GroupCapUsd, + UserCapUsd: in.BudgetLimit.UserCapUsd, + WindowSeconds: in.BudgetLimit.WindowSeconds, + }, + } +} diff --git a/management/internals/modules/agentnetwork/types/provider.go b/management/internals/modules/agentnetwork/types/provider.go new file mode 100644 index 000000000..28c8a94e2 --- /dev/null +++ b/management/internals/modules/agentnetwork/types/provider.go @@ -0,0 +1,252 @@ +package types + +import ( + "fmt" + "strings" + "time" + + "github.com/rs/xid" + + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/util/crypt" +) + +// ProviderModel is one row in the provider's models list. The operator +// pins the per-1k input/output price for cost tracking; ID is the +// model identifier the upstream provider expects on the wire. +type ProviderModel struct { + ID string `json:"id"` + InputPer1k float64 `json:"input_per_1k"` + OutputPer1k float64 `json:"output_per_1k"` +} + +// Provider is an Agent Network AI provider record persisted per account. +// The proxy cluster fronting the account lives on the per-account +// agent-network Settings row, not on the Provider — every provider in +// an account routes through the same cluster. +type Provider struct { + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index"` + ProviderID string `gorm:"index:idx_agent_network_provider"` + Name string + // UpstreamURL is the full upstream URL (e.g. https://api.openai.com) + // the operator selected. + UpstreamURL string `gorm:"column:upstream_url"` + APIKey string `gorm:"column:api_key"` + // ExtraValues holds operator-typed values for catalog-declared + // ExtraHeaders (see catalog.Provider.ExtraHeaders). Keyed by + // header name (e.g. "x-portkey-config"); a non-empty value is + // stamped on every upstream request to this provider via the + // proxy's identity-inject middleware (anti-spoof Remove + Add). + // Empty / missing keys = no header stamped. Stored as a JSON + // blob so the schema doesn't grow per-catalog-entry. + ExtraValues map[string]string `gorm:"serializer:json;column:extra_values"` + // Models is the operator's curated list of models exposed by this + // provider together with their per-1k input/output prices (USD). + // Empty means all catalog models are allowed at catalog prices. + Models []ProviderModel `gorm:"serializer:json"` + Enabled bool + // SessionPrivateKey + SessionPublicKey are the ed25519 keypair the + // synthesised reverse-proxy service uses to sign / verify session + // JWTs after a successful OIDC handshake. Generated once on + // provider create and never rotated by the manager so existing + // session cookies survive provider edits. SessionPrivateKey is + // encrypted at rest via EncryptSensitiveData / + // DecryptSensitiveData; SessionPublicKey is plain. + SessionPrivateKey string `gorm:"column:session_private_key"` + SessionPublicKey string `gorm:"column:session_public_key"` + // IdentityHeaderUserID + IdentityHeaderGroups are the operator- + // chosen wire header names for HeaderPair-style identity + // injection on catalog entries that flag the shape as + // Customizable (e.g. Bifrost, where the operator picks between + // the always-on x-bf-lh- log-metadata family and the + // label-declared x-bf-dim- telemetry family). Empty value + // disables stamping for that dimension; the inject middleware + // already no-ops on empty header names. Catalog entries with + // Customizable=false ignore these fields and use the static + // header names defined in their HeaderPairInjection block. + IdentityHeaderUserID string `gorm:"column:identity_header_user_id"` + IdentityHeaderGroups string `gorm:"column:identity_header_groups"` + CreatedAt time.Time + UpdatedAt time.Time +} + +// TableName uses an explicit name so the Agent Network provider rows live +// in their own table, separate from any future "providers"-named entity. +func (Provider) TableName() string { return "agent_network_providers" } + +// NewProvider returns a new Provider with a freshly minted ID. +func NewProvider(accountID string) *Provider { + now := time.Now().UTC() + return &Provider{ + ID: xid.New().String(), + AccountID: accountID, + CreatedAt: now, + UpdatedAt: now, + } +} + +// FromAPIRequest applies the request payload onto the receiver. The api_key +// is only overwritten when the caller provided one — empty/nil leaves the +// existing key intact, so updates can omit it. +func (p *Provider) FromAPIRequest(req *api.AgentNetworkProviderRequest) { + p.ProviderID = req.ProviderId + p.Name = req.Name + p.UpstreamURL = req.UpstreamUrl + if req.ApiKey != nil && strings.TrimSpace(*req.ApiKey) != "" { + p.APIKey = *req.ApiKey + } + if req.ExtraValues != nil { + // Replace the whole map (rather than merge) so unsetting a + // value on the dashboard actually clears it. Empty strings + // are dropped so we don't waste a row on no-op values. + next := make(map[string]string, len(*req.ExtraValues)) + for k, v := range *req.ExtraValues { + v = strings.TrimSpace(v) + if v != "" { + next[k] = v + } + } + if len(next) == 0 { + p.ExtraValues = nil + } else { + p.ExtraValues = next + } + } + p.Models = p.Models[:0] + if req.Models != nil { + for _, m := range *req.Models { + p.Models = append(p.Models, ProviderModel{ + ID: m.Id, + InputPer1k: m.InputPer1k, + OutputPer1k: m.OutputPer1k, + }) + } + } + if p.Models == nil { + p.Models = []ProviderModel{} + } + if req.Enabled != nil { + p.Enabled = *req.Enabled + } + // Identity-header overrides for catalogs flagged Customizable. + // nil pointer = "field omitted on the wire" → leave the stored + // value untouched (per the openapi description). Empty string is + // an explicit clear that disables stamping for this dimension. + if req.IdentityHeaderUserId != nil { + p.IdentityHeaderUserID = strings.TrimSpace(*req.IdentityHeaderUserId) + } + if req.IdentityHeaderGroups != nil { + p.IdentityHeaderGroups = strings.TrimSpace(*req.IdentityHeaderGroups) + } +} + +// ToAPIResponse renders the provider as the API representation. The API +// key is intentionally never surfaced. +func (p *Provider) ToAPIResponse() *api.AgentNetworkProvider { + models := make([]api.AgentNetworkProviderModel, 0, len(p.Models)) + for _, m := range p.Models { + models = append(models, api.AgentNetworkProviderModel{ + Id: m.ID, + InputPer1k: m.InputPer1k, + OutputPer1k: m.OutputPer1k, + }) + } + created := p.CreatedAt + updated := p.UpdatedAt + resp := &api.AgentNetworkProvider{ + Id: p.ID, + ProviderId: p.ProviderID, + Name: p.Name, + UpstreamUrl: p.UpstreamURL, + Models: models, + Enabled: p.Enabled, + CreatedAt: &created, + UpdatedAt: &updated, + } + if len(p.ExtraValues) > 0 { + out := make(map[string]string, len(p.ExtraValues)) + for k, v := range p.ExtraValues { + out[k] = v + } + resp.ExtraValues = &out + } + if p.IdentityHeaderUserID != "" { + v := p.IdentityHeaderUserID + resp.IdentityHeaderUserId = &v + } + if p.IdentityHeaderGroups != "" { + v := p.IdentityHeaderGroups + resp.IdentityHeaderGroups = &v + } + return resp +} + +// Copy returns a deep copy of the provider. +func (p *Provider) Copy() *Provider { + clone := *p + if p.Models != nil { + clone.Models = append([]ProviderModel(nil), p.Models...) + } + if p.ExtraValues != nil { + clone.ExtraValues = make(map[string]string, len(p.ExtraValues)) + for k, v := range p.ExtraValues { + clone.ExtraValues[k] = v + } + } + return &clone +} + +// EventMeta is the audit-log payload for activity events. +func (p *Provider) EventMeta() map[string]any { + return map[string]any{ + "name": p.Name, + "provider_id": p.ProviderID, + } +} + +// EncryptSensitiveData encrypts the upstream API key and the session +// signing key in place. +func (p *Provider) EncryptSensitiveData(enc *crypt.FieldEncrypt) error { + if enc == nil { + return nil + } + if p.APIKey != "" { + encrypted, err := enc.Encrypt(p.APIKey) + if err != nil { + return fmt.Errorf("encrypt agent network provider api key: %w", err) + } + p.APIKey = encrypted + } + if p.SessionPrivateKey != "" { + encrypted, err := enc.Encrypt(p.SessionPrivateKey) + if err != nil { + return fmt.Errorf("encrypt agent network provider session key: %w", err) + } + p.SessionPrivateKey = encrypted + } + return nil +} + +// DecryptSensitiveData decrypts the upstream API key and the session +// signing key in place. +func (p *Provider) DecryptSensitiveData(enc *crypt.FieldEncrypt) error { + if enc == nil { + return nil + } + if p.APIKey != "" { + decrypted, err := enc.Decrypt(p.APIKey) + if err != nil { + return fmt.Errorf("decrypt agent network provider api key: %w", err) + } + p.APIKey = decrypted + } + if p.SessionPrivateKey != "" { + decrypted, err := enc.Decrypt(p.SessionPrivateKey) + if err != nil { + return fmt.Errorf("decrypt agent network provider session key: %w", err) + } + p.SessionPrivateKey = decrypted + } + return nil +} diff --git a/management/internals/modules/agentnetwork/types/settings.go b/management/internals/modules/agentnetwork/types/settings.go new file mode 100644 index 000000000..d61d9deff --- /dev/null +++ b/management/internals/modules/agentnetwork/types/settings.go @@ -0,0 +1,78 @@ +package types + +import ( + "time" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// DefaultAccessLogRetentionDays is the retention applied to new accounts' +// agent-network access logs. Usage records are not subject to this — they are +// the long-term aggregate and are retained independently. +const DefaultAccessLogRetentionDays = 30 + +// Settings is the per-account agent-network configuration row. One +// row per account. Cluster + Subdomain are immutable once written and +// produce the public endpoint agents call (`.`). +type Settings struct { + AccountID string `gorm:"primaryKey"` + Cluster string + Subdomain string `gorm:"index:idx_agent_network_settings_cluster_subdomain"` + + // Account-level collection controls sourced by the synthesizer. + // EnableLogCollection gates the per-request access-log trail and defaults + // ON for new accounts. EnablePromptCollection is the master gate for + // request/response prompt capture (AND-gated with the policy-level + // guardrail). RedactPii enables PII redaction on captured prompts; + // effective redaction is account OR policy. + EnableLogCollection bool + EnablePromptCollection bool + RedactPii bool + + // AccessLogRetentionDays bounds how long full access-log rows are kept; a + // periodic sweep deletes older rows. <= 0 means keep indefinitely. Usage + // records are unaffected. + AccessLogRetentionDays int + + CreatedAt time.Time + UpdatedAt time.Time +} + +// TableName puts the rows in their own table to keep the agent-network +// schema cohesive. +func (Settings) TableName() string { return "agent_network_settings" } + +// Endpoint returns the bare hostname agents reach this account at: +// `.`. +func (s *Settings) Endpoint() string { + return s.Subdomain + "." + s.Cluster +} + +// ToAPIResponse renders the settings as the API representation. +func (s *Settings) ToAPIResponse() *api.AgentNetworkSettings { + created := s.CreatedAt + updated := s.UpdatedAt + retention := s.AccessLogRetentionDays + return &api.AgentNetworkSettings{ + Cluster: s.Cluster, + Subdomain: s.Subdomain, + Endpoint: s.Endpoint(), + EnableLogCollection: s.EnableLogCollection, + EnablePromptCollection: s.EnablePromptCollection, + RedactPii: s.RedactPii, + AccessLogRetentionDays: &retention, + CreatedAt: &created, + UpdatedAt: &updated, + } +} + +// FromAPIRequest applies the mutable settings fields from the request. Cluster +// and Subdomain are immutable and intentionally not touched here. +func (s *Settings) FromAPIRequest(req *api.AgentNetworkSettingsRequest) { + s.EnableLogCollection = req.EnableLogCollection + s.EnablePromptCollection = req.EnablePromptCollection + s.RedactPii = req.RedactPii + if req.AccessLogRetentionDays != nil { + s.AccessLogRetentionDays = *req.AccessLogRetentionDays + } +} diff --git a/management/internals/modules/agentnetwork/types/usage.go b/management/internals/modules/agentnetwork/types/usage.go new file mode 100644 index 000000000..dd01d4300 --- /dev/null +++ b/management/internals/modules/agentnetwork/types/usage.go @@ -0,0 +1,47 @@ +package types + +import ( + "time" +) + +// AgentNetworkUsage is the stripped, always-collected per-request usage record +// powering the Usage overview. Unlike AgentNetworkAccessLog it carries no +// request detail (host/path/source IP/prompt) — only the dimensions needed to +// aggregate and filter spend by user / group / provider / model over time. +// +// It is written unconditionally on every served agent-network request, +// independent of the account's EnableLogCollection toggle: when log collection +// is off the proxy ships a stripped, usage-only entry and management still +// records the usage row (but skips the full AgentNetworkAccessLog row). +type AgentNetworkUsage struct { + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index"` + Timestamp time.Time `gorm:"index"` + UserID string `gorm:"index"` + ResolvedProviderID string `gorm:"index"` + Provider string // vendor, e.g. "openai" + Model string `gorm:"index"` + SessionID string `gorm:"index"` // llm.session_id — groups a conversation / coding session + InputTokens int64 + OutputTokens int64 + TotalTokens int64 + CostUSD float64 + CreatedAt time.Time +} + +// TableName keeps usage records in their own stripped table. Named +// distinctly (…_request_usage) to avoid colliding with any pre-existing +// agent_network_usage table in a shared database. +func (AgentNetworkUsage) TableName() string { return "agent_network_request_usage" } + +// AgentNetworkUsageGroup is the normalised many-to-many row linking a usage +// record to one authorising group, mirroring AgentNetworkAccessLogGroup so the +// usage overview can filter by group with a `group_id IN (...)` join. +type AgentNetworkUsageGroup struct { + UsageID string `gorm:"primaryKey"` + GroupID string `gorm:"primaryKey;index"` + AccountID string `gorm:"index"` +} + +// TableName names the usage group child table. +func (AgentNetworkUsageGroup) TableName() string { return "agent_network_request_usage_group" } diff --git a/management/internals/modules/agentnetwork/types/usageoverview.go b/management/internals/modules/agentnetwork/types/usageoverview.go new file mode 100644 index 000000000..658832bec --- /dev/null +++ b/management/internals/modules/agentnetwork/types/usageoverview.go @@ -0,0 +1,96 @@ +package types + +import ( + "sort" + "time" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// UsageGranularity is the time-bucket width for the usage overview. New values +// can be added here and handled in bucketStart without touching the store. +type UsageGranularity string + +const ( + UsageGranularityDay UsageGranularity = "day" + UsageGranularityWeek UsageGranularity = "week" + UsageGranularityMonth UsageGranularity = "month" +) + +// ParseUsageGranularity maps the API query value to a granularity, defaulting +// to day for empty/unknown input. +func ParseUsageGranularity(s string) UsageGranularity { + switch UsageGranularity(s) { + case UsageGranularityWeek: + return UsageGranularityWeek + case UsageGranularityMonth: + return UsageGranularityMonth + default: + return UsageGranularityDay + } +} + +// AgentNetworkUsageBucket is one aggregated usage time bucket. PeriodStart is +// the UTC start of the bucket as YYYY-MM-DD. +type AgentNetworkUsageBucket struct { + PeriodStart string + InputTokens int64 + OutputTokens int64 + TotalTokens int64 + CostUSD float64 +} + +// ToAPIResponse renders the bucket as the API representation. +func (b *AgentNetworkUsageBucket) ToAPIResponse() api.AgentNetworkUsageBucket { + return api.AgentNetworkUsageBucket{ + PeriodStart: b.PeriodStart, + InputTokens: b.InputTokens, + OutputTokens: b.OutputTokens, + TotalTokens: b.TotalTokens, + CostUsd: b.CostUSD, + } +} + +// bucketStart truncates t (in UTC) to the start of its bucket for the given +// granularity. Week buckets start on Monday (ISO week). +func bucketStart(t time.Time, g UsageGranularity) time.Time { + t = t.UTC() + switch g { + case UsageGranularityWeek: + // Monday-start week. time.Weekday: Sunday=0..Saturday=6. + offset := (int(t.Weekday()) + 6) % 7 + day := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC) + return day.AddDate(0, 0, -offset) + case UsageGranularityMonth: + return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC) + default: // day + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC) + } +} + +// AggregateUsageByGranularity buckets the usage rows by the requested +// granularity and returns the buckets ordered oldest-first. Aggregation is done +// in Go (rather than per-engine SQL date_trunc) so granularities stay portable +// across SQLite/Postgres/MySQL and easy to extend. +func AggregateUsageByGranularity(rows []*AgentNetworkUsage, g UsageGranularity) []*AgentNetworkUsageBucket { + byPeriod := make(map[string]*AgentNetworkUsageBucket) + for _, r := range rows { + key := bucketStart(r.Timestamp, g).Format("2006-01-02") + b := byPeriod[key] + if b == nil { + b = &AgentNetworkUsageBucket{PeriodStart: key} + byPeriod[key] = b + } + b.InputTokens += r.InputTokens + b.OutputTokens += r.OutputTokens + b.TotalTokens += r.TotalTokens + b.CostUSD += r.CostUSD + } + + out := make([]*AgentNetworkUsageBucket, 0, len(byPeriod)) + for _, b := range byPeriod { + out = append(out, b) + } + sort.Slice(out, func(i, j int) bool { return out[i].PeriodStart < out[j].PeriodStart }) + return out +} diff --git a/management/internals/modules/agentnetwork/wire_shape_test.go b/management/internals/modules/agentnetwork/wire_shape_test.go new file mode 100644 index 000000000..b574ab3e1 --- /dev/null +++ b/management/internals/modules/agentnetwork/wire_shape_test.go @@ -0,0 +1,109 @@ +package agentnetwork + +import ( + "context" + "encoding/json" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// TestSynthesizedService_WireShape locks down the proto shape that +// flows from the synthesizer through ToProtoMapping to the proxy. +// Drift between this test and what the proxy expects manifests as +// "service not matching" — the proxy receives a mapping but can't +// register an SNI/HTTP route from it. +func TestSynthesizedService_WireShape(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + provider := newSynthTestProvider() + policy := newSynthTestPolicy(provider.ID, "grp-eng", "") + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{provider}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + svc := services[0] + mapping := svc.ToProtoMapping(rpservice.Create, "test-token", proxy.OIDCValidationConfig{}) + + // Identifiers — account-scoped service ID, settings-derived domain. + assert.Equal(t, "agent-net-svc-acct-1", mapping.GetId(), "stable account-scoped virtual service ID") + assert.Equal(t, testAccountID, mapping.GetAccountId(), "account id round-trips") + assert.Equal(t, testEndpoint, mapping.GetDomain(), "domain matches settings.Endpoint() output") + + // Mode + listen port — addMapping at proxy/server.go switches on Mode. + assert.Equal(t, "http", mapping.GetMode(), "synthesised services are HTTP mode") + assert.Equal(t, int32(0), mapping.GetListenPort(), "no custom listen port for HTTP services") + + // Auth token + private/tunnel shape: agent-network endpoints authenticate + // inbound agents via ValidateTunnelPeer against AccessGroups, not OIDC. + assert.Equal(t, "test-token", mapping.GetAuthToken(), "auth token round-trips for proxy CreateProxyPeer") + assert.True(t, mapping.GetPrivate(), "synthesised services are private (tunnel-peer auth via AccessGroups)") + require.NotNil(t, mapping.GetAuth(), "auth payload carries the session key") + assert.False(t, mapping.GetAuth().GetOidc(), "OIDC is off for tunnel-auth agent-network services") + + // Path mappings — proxy/server.go::setupHTTPMapping early-returns when + // len(mapping.GetPath()) == 0, so this is a critical assertion. + require.Len(t, mapping.GetPath(), 1, "exactly one path mapping for the cluster target") + pm := mapping.GetPath()[0] + assert.Equal(t, "/", pm.GetPath(), "default path is '/'") + assert.Equal(t, "https://noop.invalid/", pm.GetTarget(), + "target URL is the placeholder; the router middleware rewrites it per request") + require.NotNil(t, pm.GetOptions(), "target options must be populated so direct_upstream + middleware chain reach the proxy") + assert.True(t, pm.GetOptions().GetDirectUpstream(), "synth targets imply direct_upstream so the proxy dials via the host stack") + assert.True(t, pm.GetOptions().GetAgentNetwork(), "agent_network flag must travel on the wire so the proxy can tag access logs") + + mws := pm.GetOptions().GetMiddlewares() + require.Len(t, mws, 8, "eight middlewares reach the proxy: request_parser, router, limit_check, identity_inject, guardrail, limit_record, cost_meter, response_parser") + + assert.Equal(t, middlewareIDLLMRequestParser, mws[0].GetId(), "first middleware id") + assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[0].GetSlot(), "request parser slot") + + assert.Equal(t, middlewareIDLLMRouter, mws[1].GetId(), "second middleware id") + assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[1].GetSlot(), "router slot") + require.NotEmpty(t, mws[1].GetConfigJson(), "router config must travel on the wire") + var routerCfg routerConfig + require.NoError(t, json.Unmarshal(mws[1].GetConfigJson(), &routerCfg), "router config decodes") + require.Len(t, routerCfg.Providers, 1, "the only enabled provider reaches the router") + assert.Equal(t, provider.ID, routerCfg.Providers[0].ID, "router provider id matches synth provider") + assert.Equal(t, "Bearer sk-test-key", routerCfg.Providers[0].AuthHeaderValue, + "openai catalog template substitutes the API key on the wire") + + assert.Equal(t, middlewareIDLLMLimitCheck, mws[2].GetId(), + "limit_check runs after the router so the resolved provider id is available, before identity_inject so a deny doesn't pay the header-stamp cost") + assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[2].GetSlot()) + + assert.Equal(t, middlewareIDLLMIdentityInject, mws[3].GetId(), "fourth middleware id") + assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[3].GetSlot(), "identity inject slot") + require.NotEmpty(t, mws[3].GetConfigJson(), "identity inject config JSON must travel on the wire") + + assert.Equal(t, middlewareIDLLMGuardrail, mws[4].GetId(), "fifth middleware id") + assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, mws[4].GetSlot(), "guardrail slot") + require.NotEmpty(t, mws[4].GetConfigJson(), "guardrail middleware config JSON must travel on the wire") + + assert.Equal(t, middlewareIDLLMLimitRecord, mws[5].GetId(), + "limit_record sits FIRST in the response section so it RUNS LAST at runtime — slot order on the response leg is reverse-of-slice") + assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE, mws[5].GetSlot()) + + assert.Equal(t, middlewareIDCostMeter, mws[6].GetId(), "seventh middleware id") + assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE, mws[6].GetSlot(), "cost meter slot") + + assert.Equal(t, middlewareIDLLMResponseParser, mws[7].GetId(), "eighth middleware id") + assert.Equal(t, proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE, mws[7].GetSlot(), "response parser slot") +} diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go index e22d1e6e0..239d6b09c 100644 --- a/management/internals/modules/peers/manager.go +++ b/management/internals/modules/peers/manager.go @@ -220,12 +220,36 @@ func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, er func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error { existingPeerID, err := m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey) if err == nil && existingPeerID != "" { - // Peer already exists + // Same pubkey already registered — idempotent. return nil } + // Dedupe stale embedded peer records for the same (account, cluster). + // The proxy generates a fresh WireGuard keypair on every startup + // (proxy/internal/roundtrip/netbird.go), so without this sweep the + // prior embedded peer would linger forever — holding its CGNAT IP + // allocation, polluting other peers' rosters, and (most visibly) + // leaving the synth DNS pointing at the dead address. The + // (account, cluster) tuple identifies "the embedded peer for this + // proxy instance at this cluster"; any record matching that tuple + // with a different pubkey is by definition stale and must go. + staleIDs, err := m.findStaleEmbeddedProxyPeers(ctx, accountID, cluster, peerKey) + if err != nil { + return fmt.Errorf("scan for stale embedded proxy peers: %w", err) + } + if len(staleIDs) > 0 { + // userID="" + checkConnected=false: the deletion is initiated + // by management itself on behalf of the freshly-registering + // proxy, not by an end user; the stale peer may still be + // marked Connected from its prior session, but its session is + // dead by definition (its key no longer exists). + if err := m.DeletePeers(ctx, accountID, staleIDs, "", false); err != nil { + return fmt.Errorf("delete stale embedded proxy peers %v: %w", staleIDs, err) + } + } + name := fmt.Sprintf("proxy-%s", xid.New().String()) - peer := &peer.Peer{ + newPeer := &peer.Peer{ Ephemeral: true, ProxyMeta: peer.ProxyMeta{ Cluster: cluster, @@ -242,10 +266,36 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee }, } - _, _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true) + _, _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", newPeer, true) if err != nil { return fmt.Errorf("failed to create proxy peer: %w", err) } return nil } + +// findStaleEmbeddedProxyPeers returns the peer IDs of embedded proxy peer +// records in accountID that target the same cluster but carry a different +// WireGuard pubkey than the freshly-registering one. Used by CreateProxyPeer +// to garbage-collect stale records left behind when the proxy restarts with a +// regenerated keypair. +func (m *managerImpl) findStaleEmbeddedProxyPeers(ctx context.Context, accountID, cluster, newKey string) ([]string, error) { + account, err := m.store.GetAccount(ctx, accountID) + if err != nil { + return nil, err + } + var stale []string + for _, p := range account.Peers { + if p == nil || !p.ProxyMeta.Embedded { + continue + } + if p.ProxyMeta.Cluster != cluster { + continue + } + if p.Key == newKey { + continue + } + stale = append(stale, p.ID) + } + return stale, nil +} diff --git a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go index f2ecfd5f9..6705eab1a 100644 --- a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go +++ b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go @@ -39,6 +39,10 @@ type AccessLogEntry struct { BytesDownload int64 `gorm:"index"` Protocol AccessLogProtocol `gorm:"index"` Metadata map[string]string `gorm:"serializer:json"` + // AgentNetwork marks the entry as emitted by a synthesised agent-network + // service. Sourced from proto.AccessLog.AgentNetwork the proxy stamps + // before shipping. Indexed so the agent-network log surface filters cheaply. + AgentNetwork bool `gorm:"index"` } // FromProto creates an AccessLogEntry from a proto.AccessLog @@ -58,6 +62,7 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) { a.BytesDownload = serviceLog.GetBytesDownload() a.Protocol = AccessLogProtocol(serviceLog.GetProtocol()) a.Metadata = maps.Clone(serviceLog.GetMetadata()) + a.AgentNetwork = serviceLog.GetAgentNetwork() if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" { if addr, err := netip.ParseAddr(sourceIP); err == nil { diff --git a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go index ced2ec4d1..4b24adaa1 100644 --- a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go +++ b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/permissions" @@ -31,8 +32,14 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, geo g } } -// SaveAccessLog saves an access log entry to the database after enriching it +// SaveAccessLog saves an access log entry to the database after enriching it. +// Agent-network entries are flattened into their own dedicated table (queryable +// LLM columns + group child rows) instead of the shared reverse-proxy table. func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.AccessLogEntry) error { + if logEntry.AgentNetwork { + return agentnetwork.IngestAccessLog(ctx, m.store, logEntry) + } + if m.geo != nil && logEntry.GeoLocation.ConnectionIP != nil { location, err := m.geo.Lookup(logEntry.GeoLocation.ConnectionIP) if err != nil { diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go index ee1e3c8b2..b6438abde 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -66,6 +66,51 @@ type TargetOptions struct { // reachable without WireGuard (public APIs, LAN services, localhost // sidecars). Default false. DirectUpstream bool `json:"direct_upstream,omitempty"` + // Middlewares carries per-target agent-network middleware configs. Empty + // for private and operator-defined services; populated only by the + // agent-network synthesizer. + Middlewares []MiddlewareConfig `gorm:"serializer:json" json:"middlewares,omitempty"` + CaptureMaxRequestBytes int64 `json:"capture_max_request_bytes,omitempty"` + CaptureMaxResponseBytes int64 `json:"capture_max_response_bytes,omitempty"` + CaptureContentTypes []string `gorm:"serializer:json" json:"capture_content_types,omitempty"` + // AgentNetwork marks targets synthesised from Agent Network state. The + // proxy uses it to gate agent-network-specific behaviour (access log + // tagging, observability, etc.). + AgentNetwork bool `json:"agent_network,omitempty"` + // DisableAccessLog suppresses the per-request access-log emission for this + // target. Defaults false to preserve access-log behaviour for every + // non-agent-network target. The agent-network synthesizer sets this true + // only when the account's EnableLogCollection toggle is off. + DisableAccessLog bool `json:"disable_access_log,omitempty"` +} + +// MiddlewareSlot mirrors proto.MiddlewareSlot / middleware.Slot. +type MiddlewareSlot string + +const ( + MiddlewareSlotOnRequest MiddlewareSlot = "on_request" + MiddlewareSlotOnResponse MiddlewareSlot = "on_response" + MiddlewareSlotTerminal MiddlewareSlot = "terminal" +) + +// MiddlewareFailMode mirrors proto.MiddlewareConfig_FailMode. +type MiddlewareFailMode string + +const ( + MiddlewareFailOpen MiddlewareFailMode = "fail_open" + MiddlewareFailClosed MiddlewareFailMode = "fail_closed" +) + +// MiddlewareConfig is the per-target configuration for a single +// middleware instance. Mirrors proto.MiddlewareConfig. +type MiddlewareConfig struct { + ID string `json:"id"` + Enabled bool `json:"enabled"` + Slot MiddlewareSlot `json:"slot"` + ConfigJSON []byte `json:"config_json,omitempty"` + FailMode MiddlewareFailMode `json:"fail_mode,omitempty"` + TimeoutMs int32 `json:"timeout_ms,omitempty"` + CanMutate bool `json:"can_mutate"` } type Target struct { @@ -504,21 +549,75 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions { func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions { if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && - len(opts.CustomHeaders) == 0 && !opts.DirectUpstream { + len(opts.CustomHeaders) == 0 && !opts.DirectUpstream && + len(opts.Middlewares) == 0 && opts.CaptureMaxRequestBytes == 0 && + opts.CaptureMaxResponseBytes == 0 && len(opts.CaptureContentTypes) == 0 && + !opts.AgentNetwork && !opts.DisableAccessLog { return nil } popts := &proto.PathTargetOptions{ - SkipTlsVerify: opts.SkipTLSVerify, - PathRewrite: pathRewriteToProto(opts.PathRewrite), - CustomHeaders: opts.CustomHeaders, - DirectUpstream: opts.DirectUpstream, + SkipTlsVerify: opts.SkipTLSVerify, + PathRewrite: pathRewriteToProto(opts.PathRewrite), + CustomHeaders: opts.CustomHeaders, + DirectUpstream: opts.DirectUpstream, + AgentNetwork: opts.AgentNetwork, + DisableAccessLog: opts.DisableAccessLog, } if opts.RequestTimeout != 0 { popts.RequestTimeout = durationpb.New(opts.RequestTimeout) } + if len(opts.Middlewares) > 0 { + popts.Middlewares = middlewaresToProto(opts.Middlewares) + } + popts.CaptureMaxRequestBytes = opts.CaptureMaxRequestBytes + popts.CaptureMaxResponseBytes = opts.CaptureMaxResponseBytes + if len(opts.CaptureContentTypes) > 0 { + popts.CaptureContentTypes = append([]string(nil), opts.CaptureContentTypes...) + } return popts } +// middlewaresToProto converts the internal middleware slice to the proto +// representation sent to the proxy via the mapping stream. +func middlewaresToProto(in []MiddlewareConfig) []*proto.MiddlewareConfig { + out := make([]*proto.MiddlewareConfig, 0, len(in)) + for _, m := range in { + pm := &proto.MiddlewareConfig{ + Id: m.ID, + Enabled: m.Enabled, + Slot: middlewareSlotToProto(m.Slot), + ConfigJson: append([]byte(nil), m.ConfigJSON...), + CanMutate: m.CanMutate, + FailMode: middlewareFailModeToProto(m.FailMode), + } + if m.TimeoutMs > 0 { + pm.Timeout = durationpb.New(time.Duration(m.TimeoutMs) * time.Millisecond) + } + out = append(out, pm) + } + return out +} + +func middlewareSlotToProto(s MiddlewareSlot) proto.MiddlewareSlot { + switch s { + case MiddlewareSlotOnRequest: + return proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST + case MiddlewareSlotOnResponse: + return proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE + case MiddlewareSlotTerminal: + return proto.MiddlewareSlot_MIDDLEWARE_SLOT_TERMINAL + default: + return proto.MiddlewareSlot_MIDDLEWARE_SLOT_UNSPECIFIED + } +} + +func middlewareFailModeToProto(m MiddlewareFailMode) proto.MiddlewareConfig_FailMode { + if m == MiddlewareFailClosed { + return proto.MiddlewareConfig_FAIL_CLOSED + } + return proto.MiddlewareConfig_FAIL_OPEN +} + // l4TargetOptionsToProto converts L4-relevant target options to proto. func l4TargetOptionsToProto(target *Target) *proto.PathTargetOptions { if !target.ProxyProtocol && target.Options.RequestTimeout == 0 && target.Options.SessionIdleTimeout == 0 { diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index ae82b60fe..1c78af9d0 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -26,9 +26,11 @@ import ( "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" activitystore "github.com/netbirdio/netbird/management/server/activity/store" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork" nbcache "github.com/netbirdio/netbird/management/server/cache" nbContext "github.com/netbirdio/netbird/management/server/context" nbhttp "github.com/netbirdio/netbird/management/server/http" @@ -120,7 +122,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount, s.AgentNetworkManager()) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -223,11 +225,35 @@ func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { s.AfterInit(func(s *BaseServer) { proxyService.SetServiceManager(s.ServiceManager()) proxyService.SetProxyController(s.ServiceProxyController()) + proxyService.SetAgentNetworkSynthesizer(newAgentNetworkSynthesizer(s.Store())) + proxyService.SetAgentNetworkLimitsService(s.AgentNetworkManager()) }) return proxyService }) } +// agentNetworkSynthesizerAdapter implements nbgrpc.AgentNetworkSynthesizer by +// delegating to the agentnetwork package's store-backed synthesiser. +type agentNetworkSynthesizerAdapter struct { + store store.Store +} + +func newAgentNetworkSynthesizer(s store.Store) *agentNetworkSynthesizerAdapter { + return &agentNetworkSynthesizerAdapter{store: s} +} + +func (a *agentNetworkSynthesizerAdapter) SynthesizeServicesForCluster(ctx context.Context, clusterAddr string) ([]*rpservice.Service, error) { + return agentnetwork.SynthesizeServicesForCluster(ctx, a.store, clusterAddr) +} + +func (a *agentNetworkSynthesizerAdapter) SynthesizeServicesForAccount(ctx context.Context, accountID string) ([]*rpservice.Service, error) { + return agentnetwork.SynthesizeServices(ctx, a.store, accountID) +} + +func (a *agentNetworkSynthesizerAdapter) SynthesizeServiceForDomain(ctx context.Context, domain string) (*rpservice.Service, error) { + return agentnetwork.SynthesizeServiceForDomain(ctx, a.store, domain) +} + func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig { return Create(s, func() nbgrpc.ProxyOIDCConfig { return nbgrpc.ProxyOIDCConfig{ diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index a70da855a..6b1365f3b 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -20,6 +20,7 @@ import ( recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/idp" @@ -194,6 +195,24 @@ func (s *BaseServer) NetworksManager() networks.Manager { }) } +func (s *BaseServer) AgentNetworkManager() agentnetwork.Manager { + return Create(s, func() agentnetwork.Manager { + mgr := agentnetwork.NewManager( + s.Store(), + s.PermissionsManager(), + s.AccountManager(), + s.ServiceProxyController(), + ) + // Sweep expired agent-network access logs per account retention, + // reusing the reverse-proxy cleanup interval config. + mgr.StartAccessLogCleanup( + context.Background(), + s.Config.ReverseProxy.AccessLogCleanupIntervalHours, + ) + return mgr + }) +} + func (s *BaseServer) ZonesManager() zones.Manager { return Create(s, func() zones.Manager { return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain()) diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 76663f898..0dfa24bc4 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "math" "net" "net/http" "net/url" @@ -35,6 +36,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" proxyauth "github.com/netbirdio/netbird/proxy/auth" @@ -60,6 +62,23 @@ type ProxyTokenChecker interface { } // ProxyServiceServer implements the ProxyService gRPC server +// AgentNetworkSynthesizer produces in-memory reverse-proxy services from +// Agent Network provider/policy state for the proxy snapshot path; synthesised +// services never appear in the reverseproxy_services table. +type AgentNetworkSynthesizer interface { + SynthesizeServicesForCluster(ctx context.Context, clusterAddr string) ([]*rpservice.Service, error) + SynthesizeServicesForAccount(ctx context.Context, accountID string) ([]*rpservice.Service, error) + SynthesizeServiceForDomain(ctx context.Context, domain string) (*rpservice.Service, error) +} + +// AgentNetworkLimitsService is the minimal slice of agentnetwork.Manager the +// gRPC layer needs for CheckLLMPolicyLimits + RecordLLMUsage — kept narrow so +// the grpc package doesn't take a hard import on the full manager. +type AgentNetworkLimitsService interface { + SelectPolicyForRequest(ctx context.Context, in agentnetwork.PolicySelectionInput) (*agentnetwork.PolicySelectionResult, error) + RecordUsage(ctx context.Context, in agentnetwork.RecordUsageInput) error +} + type ProxyServiceServer struct { proto.UnimplementedProxyServiceServer @@ -72,6 +91,14 @@ type ProxyServiceServer struct { mu sync.RWMutex // Manager for reverse proxy operations serviceManager rpservice.Manager + // agentNetworkSynth produces synthesised reverse-proxy services from + // Agent Network state. Optional — when nil the snapshot path only ships + // persisted services. + agentNetworkSynth AgentNetworkSynthesizer + // agentNetworkLimits handles the pre-flight selection (CheckLLMPolicyLimits) + // and the post-flight consumption write (RecordLLMUsage). Optional — when + // nil both RPCs return Unimplemented. + agentNetworkLimits AgentNetworkLimitsService // ProxyController for service updates and cluster management proxyController proxy.Controller @@ -209,6 +236,127 @@ func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) { s.serviceManager = manager } +// SetAgentNetworkSynthesizer wires the agent-network service synthesiser. +// Optional — when nil the snapshot path skips agent-network synthesis. The +// modules layer injects this after both the proxy server and the agent-network +// manager are constructed. +func (s *ProxyServiceServer) SetAgentNetworkSynthesizer(synth AgentNetworkSynthesizer) { + s.mu.Lock() + s.agentNetworkSynth = synth + s.mu.Unlock() +} + +// SetAgentNetworkLimitsService wires the policy-selection + post-flight +// consumption sink. Pass nil to disable; both RPCs return Unimplemented while +// unset so partial wiring surfaces during integration. +func (s *ProxyServiceServer) SetAgentNetworkLimitsService(svc AgentNetworkLimitsService) { + s.mu.Lock() + s.agentNetworkLimits = svc + s.mu.Unlock() +} + +// agentNetworkSynthesizer returns the synthesiser under read lock. +func (s *ProxyServiceServer) agentNetworkSynthesizer() AgentNetworkSynthesizer { + s.mu.RLock() + defer s.mu.RUnlock() + return s.agentNetworkSynth +} + +// CheckLLMPolicyLimits is the pre-flight policy gate the proxy calls before +// forwarding an LLM request upstream. Delegates to the agent-network selector, +// which scores applicable policies by remaining headroom and returns the +// policy that pays for this request (or a deny when all are exhausted). +func (s *ProxyServiceServer) CheckLLMPolicyLimits(ctx context.Context, req *proto.CheckLLMPolicyLimitsRequest) (*proto.CheckLLMPolicyLimitsResponse, error) { + s.mu.RLock() + svc := s.agentNetworkLimits + s.mu.RUnlock() + if svc == nil { + return nil, status.Errorf(codes.Unimplemented, "agent-network limits service not configured on management") + } + if req.GetAccountId() == "" { + return nil, status.Errorf(codes.InvalidArgument, "account_id is required") + } + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + + res, err := svc.SelectPolicyForRequest(ctx, agentnetwork.PolicySelectionInput{ + AccountID: req.GetAccountId(), + UserID: req.GetUserId(), + GroupIDs: req.GetGroupIds(), + ProviderID: req.GetProviderId(), + }) + if err != nil { + log.WithContext(ctx).Errorf("select policy for request: %v", err) + return nil, status.Error(codes.Internal, "select policy failed") + } + + if !res.Allow { + return &proto.CheckLLMPolicyLimitsResponse{ + Decision: "deny", + SelectedPolicyId: res.SelectedPolicyID, + AttributionGroupId: res.AttributionGroupID, + WindowSeconds: res.WindowSeconds, + DenyCode: res.DenyCode, + DenyReason: res.DenyReason, + }, nil + } + return &proto.CheckLLMPolicyLimitsResponse{ + Decision: "allow", + SelectedPolicyId: res.SelectedPolicyID, + AttributionGroupId: res.AttributionGroupID, + WindowSeconds: res.WindowSeconds, + }, nil +} + +// RecordLLMUsage increments the per-(dimension, window) consumption counter for +// the user and optional attribution group after a served request. Returns +// Unimplemented when the agent-network limits service hasn't been wired. +func (s *ProxyServiceServer) RecordLLMUsage(ctx context.Context, req *proto.RecordLLMUsageRequest) (*proto.RecordLLMUsageResponse, error) { + s.mu.RLock() + svc := s.agentNetworkLimits + s.mu.RUnlock() + if svc == nil { + return nil, status.Errorf(codes.Unimplemented, "agent-network limits service not configured on management") + } + + accountID := req.GetAccountId() + if accountID == "" { + return nil, status.Errorf(codes.InvalidArgument, "account_id is required") + } + if err := enforceAccountScope(ctx, accountID); err != nil { + return nil, err + } + tokensIn := req.GetTokensInput() + tokensOut := req.GetTokensOutput() + costUSD := req.GetCostUsd() + + // Reject impossible counters at the boundary instead of recording them: + // a negative window, negative tokens, or a negative / non-finite cost + // would otherwise decrement or poison the persisted consumption totals. + if req.GetWindowSeconds() < 0 || tokensIn < 0 || tokensOut < 0 || costUSD < 0 || math.IsNaN(costUSD) || math.IsInf(costUSD, 0) { + return nil, status.Errorf(codes.InvalidArgument, "usage counters must be non-negative and finite") + } + + // Book the policy-window dimensions (when a policy cap bound this request) + // and every applicable account budget rule's window in a single batched + // transaction. + if err := svc.RecordUsage(ctx, agentnetwork.RecordUsageInput{ + AccountID: accountID, + UserID: req.GetUserId(), + AttributionGroupID: req.GetGroupId(), + GroupIDs: req.GetGroupIds(), + WindowSeconds: req.GetWindowSeconds(), + TokensIn: tokensIn, + TokensOut: tokensOut, + CostUSD: costUSD, + }); err != nil { + log.WithContext(ctx).Errorf("record usage: %v", err) + return nil, status.Error(codes.Internal, "record usage failed") + } + return &proto.RecordLLMUsageResponse{}, nil +} + // SetProxyController sets the proxy controller. Must be called before serving. func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) { s.mu.Lock() @@ -623,12 +771,40 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * return nil, fmt.Errorf("get services from store: %w", err) } + if synth := s.agentNetworkSynthesizer(); synth != nil { + var synthesised []*rpservice.Service + var serr error + // Account-scoped connections synthesise only their own account, so the + // snapshot can never carry another tenant's mappings (which embed the + // upstream auth header derived from that tenant's provider API key). + // Global connections still see the whole cluster. + if conn.accountID != nil { + synthesised, serr = synth.SynthesizeServicesForAccount(ctx, *conn.accountID) + } else { + synthesised, serr = synth.SynthesizeServicesForCluster(ctx, conn.address) + } + if serr != nil { + // Surface a real synthesis failure instead of silently shipping an + // incomplete snapshot (which would drop the account's agent-network + // routes). Consistent with the persisted-services error above; the + // proxy retries the snapshot on connection error. + return nil, fmt.Errorf("synthesise agent-network services: %w", serr) + } + services = append(services, synthesised...) + } + oidcCfg := s.GetOIDCValidationConfig() var mappings []*proto.ProxyMapping for _, service := range services { if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address { continue } + // Defense in depth: an account-scoped proxy must never receive another + // account's mapping, matching the per-account filtering the incremental + // update path already applies. + if conn.accountID != nil && service.AccountID != *conn.accountID { + continue + } m := service.ToProtoMapping(rpservice.Create, "", oidcCfg) if !proxyAcceptsMapping(conn, m) { @@ -1617,7 +1793,29 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val } func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) { - return s.serviceManager.GetServiceByDomain(ctx, domain) + service, err := s.serviceManager.GetServiceByDomain(ctx, domain) + if err == nil { + return service, nil + } + + // Fall back to the Agent Network synthesiser scoped directly to the domain's + // account. Synthesised services are never persisted, so they must resolve + // here for OIDC / session / tunnel-peer flows against agent-network + // endpoints. Resolving by domain synthesises only the owning account rather + // than every tenant on the cluster. + if synth := s.agentNetworkSynthesizer(); synth != nil { + svc, serr := synth.SynthesizeServiceForDomain(ctx, domain) + if serr != nil { + // A real synthesis failure must surface, not be masked by the + // original store miss — otherwise a transient DB error looks like + // "no such service". + return nil, fmt.Errorf("synthesize agent-network service for %s: %w", domain, serr) + } + if svc != nil { + return svc, nil + } + } + return nil, err } func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error { diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 852193a3b..cfd809871 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -245,6 +245,37 @@ const ( // tunnel. Distinct from UserLoggedInPeer (full interactive login). UserExtendedPeerSession Activity = 125 + // AgentNetworkProviderCreated indicates that a user created an Agent Network provider + AgentNetworkProviderCreated Activity = 126 + // AgentNetworkProviderUpdated indicates that a user updated an Agent Network provider + AgentNetworkProviderUpdated Activity = 127 + // AgentNetworkProviderDeleted indicates that a user deleted an Agent Network provider + AgentNetworkProviderDeleted Activity = 128 + + // AgentNetworkPolicyCreated indicates that a user created an Agent Network policy + AgentNetworkPolicyCreated Activity = 129 + // AgentNetworkPolicyUpdated indicates that a user updated an Agent Network policy + AgentNetworkPolicyUpdated Activity = 130 + // AgentNetworkPolicyDeleted indicates that a user deleted an Agent Network policy + AgentNetworkPolicyDeleted Activity = 131 + + // AgentNetworkGuardrailCreated indicates that a user created an Agent Network guardrail + AgentNetworkGuardrailCreated Activity = 132 + // AgentNetworkGuardrailUpdated indicates that a user updated an Agent Network guardrail + AgentNetworkGuardrailUpdated Activity = 133 + // AgentNetworkGuardrailDeleted indicates that a user deleted an Agent Network guardrail + AgentNetworkGuardrailDeleted Activity = 134 + + // AgentNetworkBudgetRuleCreated indicates that a user created an Agent Network budget rule + AgentNetworkBudgetRuleCreated Activity = 135 + // AgentNetworkBudgetRuleUpdated indicates that a user updated an Agent Network budget rule + AgentNetworkBudgetRuleUpdated Activity = 136 + // AgentNetworkBudgetRuleDeleted indicates that a user deleted an Agent Network budget rule + AgentNetworkBudgetRuleDeleted Activity = 137 + + // AgentNetworkSettingsUpdated indicates that a user updated Agent Network account settings + AgentNetworkSettingsUpdated Activity = 139 + AccountDeleted Activity = 99999 ) @@ -400,6 +431,24 @@ var activityMap = map[Activity]Code{ UserExtendedPeerSession: {"User extended peer session", "user.peer.session.extend"}, + AgentNetworkProviderCreated: {"Agent Network provider created", "agent_network.provider.create"}, + AgentNetworkProviderUpdated: {"Agent Network provider updated", "agent_network.provider.update"}, + AgentNetworkProviderDeleted: {"Agent Network provider deleted", "agent_network.provider.delete"}, + + AgentNetworkPolicyCreated: {"Agent Network policy created", "agent_network.policy.create"}, + AgentNetworkPolicyUpdated: {"Agent Network policy updated", "agent_network.policy.update"}, + AgentNetworkPolicyDeleted: {"Agent Network policy deleted", "agent_network.policy.delete"}, + + AgentNetworkGuardrailCreated: {"Agent Network guardrail created", "agent_network.guardrail.create"}, + AgentNetworkGuardrailUpdated: {"Agent Network guardrail updated", "agent_network.guardrail.update"}, + AgentNetworkGuardrailDeleted: {"Agent Network guardrail deleted", "agent_network.guardrail.delete"}, + + AgentNetworkBudgetRuleCreated: {"Agent Network budget rule created", "agent_network.budget_rule.create"}, + AgentNetworkBudgetRuleUpdated: {"Agent Network budget rule updated", "agent_network.budget_rule.update"}, + AgentNetworkBudgetRuleDeleted: {"Agent Network budget rule deleted", "agent_network.budget_rule.delete"}, + + AgentNetworkSettingsUpdated: {"Agent Network settings updated", "agent_network.settings.update"}, + DomainAdded: {"Domain added", "domain.add"}, DomainDeleted: {"Domain deleted", "domain.delete"}, DomainValidated: {"Domain validated", "domain.validate"}, diff --git a/management/server/affectedpeers/proxy_synth_test.go b/management/server/affectedpeers/proxy_synth_test.go new file mode 100644 index 000000000..07273b18b --- /dev/null +++ b/management/server/affectedpeers/proxy_synth_test.go @@ -0,0 +1,95 @@ +package affectedpeers + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/store" +) + +// fakeProxyStore implements only the two store methods loadProxyServices calls; +// the embedded nil store.Store panics if anything else is invoked, which keeps +// the test honest about the surface under test. +type fakeProxyStore struct { + store.Store + proxyByCluster map[string][]string + persisted []*rpservice.Service +} + +func (f *fakeProxyStore) GetEmbeddedProxyPeerIDsByCluster(_ context.Context, _ string) (map[string][]string, error) { + return f.proxyByCluster, nil +} + +func (f *fakeProxyStore) GetAccountServices(_ context.Context, _ store.LockingStrength, _ string) ([]*rpservice.Service, error) { + return f.persisted, nil +} + +func serviceIDs(svcs []*rpservice.Service) []string { + ids := make([]string, 0, len(svcs)) + for _, s := range svcs { + ids = append(ids, s.ID) + } + return ids +} + +// loadProxyServices must merge the synthesised agent-network services (which are +// never persisted) with the persisted ones, so the proxy-affected expansion can +// see agent-network AccessGroups. Without this the embedded proxy peer is never +// flagged on a client group change and only a full resync (restart) recovers. +func TestLoadProxyServices_MergesSynthesizedAgentNetworkServices(t *testing.T) { + prev := agentNetworkSynthesizer + t.Cleanup(func() { agentNetworkSynthesizer = prev }) + SetAgentNetworkSynthesizer(func(_ context.Context, _ store.Store, _ string) ([]*rpservice.Service, error) { + return []*rpservice.Service{ + {ID: "agent-net-svc-acc", ProxyCluster: "proxy.netbird.local", Private: true, AccessGroups: []string{"gB"}}, + }, nil + }) + + s := &fakeProxyStore{ + proxyByCluster: map[string][]string{"proxy.netbird.local": {"proxy-peer-1"}}, + persisted: []*rpservice.Service{{ID: "persisted-rp-svc", ProxyCluster: "proxy.netbird.local"}}, + } + snap := &Snapshot{} + require.NoError(t, snap.loadProxyServices(context.Background(), s, "acc")) + + ids := serviceIDs(snap.services) + assert.Contains(t, ids, "persisted-rp-svc", "persisted services must be kept") + assert.Contains(t, ids, "agent-net-svc-acc", "synthesised agent-network service must be merged in") +} + +// With no synthesiser registered, loadProxyServices falls back to persisted +// services only (no panic, no behaviour change for non-agent-network builds). +func TestLoadProxyServices_NoSynthesizerRegistered(t *testing.T) { + prev := agentNetworkSynthesizer + t.Cleanup(func() { agentNetworkSynthesizer = prev }) + agentNetworkSynthesizer = nil + + s := &fakeProxyStore{ + proxyByCluster: map[string][]string{"c": {"proxy-1"}}, + persisted: []*rpservice.Service{{ID: "persisted"}}, + } + snap := &Snapshot{} + require.NoError(t, snap.loadProxyServices(context.Background(), s, "acc")) + assert.Equal(t, []string{"persisted"}, serviceIDs(snap.services)) +} + +// No embedded proxy peers → skip entirely (don't even call the synthesiser). +func TestLoadProxyServices_NoEmbeddedProxyPeersSkips(t *testing.T) { + prev := agentNetworkSynthesizer + t.Cleanup(func() { agentNetworkSynthesizer = prev }) + called := false + SetAgentNetworkSynthesizer(func(_ context.Context, _ store.Store, _ string) ([]*rpservice.Service, error) { + called = true + return nil, nil + }) + + s := &fakeProxyStore{proxyByCluster: map[string][]string{}} + snap := &Snapshot{} + require.NoError(t, snap.loadProxyServices(context.Background(), s, "acc")) + assert.False(t, called, "synthesiser must not run for accounts without embedded proxy peers") + assert.Empty(t, snap.services) +} diff --git a/management/server/affectedpeers/resolver.go b/management/server/affectedpeers/resolver.go index 94e24ced6..16a795539 100644 --- a/management/server/affectedpeers/resolver.go +++ b/management/server/affectedpeers/resolver.go @@ -29,6 +29,19 @@ import ( "github.com/netbirdio/netbird/route" ) +// agentNetworkSynthesizer returns the account's synthesised (never-persisted) +// agent-network reverse-proxy services. It is registered at boot via +// SetAgentNetworkSynthesizer to avoid an import cycle (agentnetwork → account → +// affectedpeers). nil when agent-network is not wired, in which case only +// persisted services are considered. +var agentNetworkSynthesizer func(ctx context.Context, s store.Store, accountID string) ([]*rpservice.Service, error) + +// SetAgentNetworkSynthesizer registers the agent-network service synthesiser. +// Called once during boot, before any request is served. +func SetAgentNetworkSynthesizer(fn func(ctx context.Context, s store.Store, accountID string) ([]*rpservice.Service, error)) { + agentNetworkSynthesizer = fn +} + // Snapshot is an in-memory view of the collections needed to expand a Change. // Loaded in-tx, walked by Expand after commit. Only the collections the Change // can touch are loaded; the rest stay nil (see Load). @@ -124,7 +137,12 @@ func (snap *Snapshot) loadDNS(ctx context.Context, s store.Store, accountID stri } // loadProxyServices loads the embedded-proxy cluster index, and the services only -// when the account actually has embedded proxy peers. +// when the account actually has embedded proxy peers. Both the persisted +// reverse-proxy services and the synthesised agent-network services are loaded: +// agent-network services are never persisted, so without synthesising them here +// collectFromProxyServices can't fold the embedded proxy peer into the affected +// set when a client's group changes, and the proxy never learns a newly +// authorised client until it reconnects (full network-map resync). func (snap *Snapshot) loadProxyServices(ctx context.Context, s store.Store, accountID string) error { var err error if snap.proxyByCluster, err = s.GetEmbeddedProxyPeerIDsByCluster(ctx, accountID); err != nil { @@ -133,8 +151,21 @@ func (snap *Snapshot) loadProxyServices(ctx context.Context, s store.Store, acco if len(snap.proxyByCluster) == 0 { return nil } - snap.services, err = s.GetAccountServices(ctx, store.LockingStrengthNone, accountID) - return err + if snap.services, err = s.GetAccountServices(ctx, store.LockingStrengthNone, accountID); err != nil { + return err + } + if agentNetworkSynthesizer == nil { + return nil + } + synth, serr := agentNetworkSynthesizer(ctx, s, accountID) + if serr != nil { + // Non-fatal: fall back to persisted services. The next full + // network-map resync still converges the proxy. + log.WithContext(ctx).Warnf("affectedpeers: synthesise agent-network services for account %s: %v", accountID, serr) + return nil + } + snap.services = append(snap.services, synth...) + return nil } // loadGroupIndex loads all groups (for group.Resources) and builds the diff --git a/management/server/agentnetwork_budgetrule_realstack_test.go b/management/server/agentnetwork_budgetrule_realstack_test.go new file mode 100644 index 000000000..d17f2e26a --- /dev/null +++ b/management/server/agentnetwork_budgetrule_realstack_test.go @@ -0,0 +1,126 @@ +package server + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork" + agenttypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" +) + +// TestAgentNetwork_BudgetRuleCRUD_RealManager is the GC-1 no-mock guard for the +// account budget-rule manager surface: real DefaultAccountManager, real store, +// real permissions. It exercises create/get/list/update/delete through the +// permission-gated manager (not the store directly) and asserts the reused +// PolicyLimits cap shape and targets survive each step. +func TestAgentNetwork_BudgetRuleCRUD_RealManager(t *testing.T) { + am, _, err := createManager(t) + require.NoError(t, err, "createManager must succeed") + ctx := context.Background() + + const ( + accountID = "agent-net-budget-acct" + adminUserID = "agent-net-budget-admin" + ) + account := newAccountWithId(ctx, accountID, adminUserID, "agent-net.test", "", "", false) + require.NoError(t, am.Store.SaveAccount(ctx, account), "SaveAccount must succeed") + + mgr := agentnetwork.NewManager(am.Store, permissions.NewManager(am.Store), am, nil) + + created, err := mgr.CreateBudgetRule(ctx, adminUserID, &agenttypes.AccountBudgetRule{ + AccountID: accountID, + Name: "eng-monthly", + Enabled: true, + TargetGroups: []string{"grp-eng"}, + TargetUsers: []string{"user-alice"}, + Limits: agenttypes.PolicyLimits{ + TokenLimit: agenttypes.PolicyTokenLimit{Enabled: true, GroupCap: 100_000, UserCap: 10_000, WindowSeconds: 2_592_000}, + BudgetLimit: agenttypes.PolicyBudgetLimit{Enabled: true, GroupCapUsd: 500, WindowSeconds: 2_592_000}, + }, + }) + require.NoError(t, err, "CreateBudgetRule must succeed") + require.NotEmpty(t, created.ID, "create must mint an ID") + + got, err := mgr.GetBudgetRule(ctx, accountID, adminUserID, created.ID) + require.NoError(t, err, "GetBudgetRule must succeed") + assert.Equal(t, "eng-monthly", got.Name, "name round-trips through the manager") + assert.Equal(t, []string{"grp-eng"}, got.TargetGroups, "target groups round-trip") + assert.Equal(t, int64(100_000), got.Limits.TokenLimit.GroupCap, "token group cap round-trips") + + list, err := mgr.GetAllBudgetRules(ctx, accountID, adminUserID) + require.NoError(t, err, "GetAllBudgetRules must succeed") + require.Len(t, list, 1, "exactly the one created rule must be listed") + + created.Limits.TokenLimit.GroupCap = 200_000 + updated, err := mgr.UpdateBudgetRule(ctx, adminUserID, created) + require.NoError(t, err, "UpdateBudgetRule must succeed") + assert.Equal(t, int64(200_000), updated.Limits.TokenLimit.GroupCap, "updated cap must persist") + + require.NoError(t, mgr.DeleteBudgetRule(ctx, accountID, adminUserID, created.ID), "DeleteBudgetRule must succeed") + _, err = mgr.GetBudgetRule(ctx, accountID, adminUserID, created.ID) + assert.Error(t, err, "get after delete must fail") +} + +// TestAgentNetwork_UpdateSettings_PreservesImmutableAndTogglesCollection is the +// GC-1 guard for UpdateSettings: it must apply the collection toggles while +// preserving the immutable Cluster/Subdomain pinned at bootstrap. +func TestAgentNetwork_UpdateSettings_PreservesImmutableAndTogglesCollection(t *testing.T) { + am, _, err := createManager(t) + require.NoError(t, err, "createManager must succeed") + ctx := context.Background() + + const ( + accountID = "agent-net-settings-acct" + adminUserID = "agent-net-settings-admin" + clusterAddr = "eu.proxy.netbird.io" + ) + account := newAccountWithId(ctx, accountID, adminUserID, "agent-net.test", "", "", false) + require.NoError(t, am.Store.SaveAccount(ctx, account), "SaveAccount must succeed") + + mgr := agentnetwork.NewManager(am.Store, permissions.NewManager(am.Store), am, nil) + + // Creating a provider bootstraps the settings row (cluster + subdomain). + _, err = mgr.CreateProvider(ctx, adminUserID, &agenttypes.Provider{ + AccountID: accountID, + ProviderID: "openai_api", + Name: "openai", + UpstreamURL: "https://api.openai.com", + APIKey: "sk-test", + Enabled: true, + Models: []agenttypes.ProviderModel{{ID: "gpt-5.4"}}, + }, clusterAddr) + require.NoError(t, err, "CreateProvider must bootstrap settings") + + before, err := mgr.GetSettings(ctx, accountID, adminUserID) + require.NoError(t, err, "GetSettings must succeed after bootstrap") + require.Equal(t, clusterAddr, before.Cluster, "cluster pinned at bootstrap") + require.NotEmpty(t, before.Subdomain, "subdomain pinned at bootstrap") + assert.False(t, before.EnablePromptCollection, "prompt collection defaults off") + + // Attempt to flip toggles AND smuggle a different cluster/subdomain — the + // immutable fields must be ignored. + updated, err := mgr.UpdateSettings(ctx, adminUserID, &agenttypes.Settings{ + AccountID: accountID, + Cluster: "attacker.cluster", + Subdomain: "evil", + EnableLogCollection: true, + EnablePromptCollection: true, + RedactPii: true, + }) + require.NoError(t, err, "UpdateSettings must succeed") + assert.Equal(t, before.Cluster, updated.Cluster, "cluster is immutable and must be preserved") + assert.Equal(t, before.Subdomain, updated.Subdomain, "subdomain is immutable and must be preserved") + assert.True(t, updated.EnableLogCollection, "log collection toggle must apply") + assert.True(t, updated.EnablePromptCollection, "prompt collection toggle must apply") + assert.True(t, updated.RedactPii, "redact toggle must apply") + + reloaded, err := am.Store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID) + require.NoError(t, err) + assert.Equal(t, before.Cluster, reloaded.Cluster, "persisted cluster unchanged") + assert.True(t, reloaded.EnablePromptCollection, "persisted prompt collection toggled on") +} diff --git a/management/server/agentnetwork_proxypeer_restart_test.go b/management/server/agentnetwork_proxypeer_restart_test.go new file mode 100644 index 000000000..1e4b8d016 --- /dev/null +++ b/management/server/agentnetwork_proxypeer_restart_test.go @@ -0,0 +1,199 @@ +package server + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork" + agenttypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +// TestAgentNetwork_ProxyRestart_PropagatesNewPeerAndDropsStale is the no-mock +// regression guard for the bug the user reported: restarting the proxy creates +// a fresh embedded peer with a NEW WireGuard public key (the proxy generates +// the keypair on every startup at proxy/internal/roundtrip/netbird.go:312). +// The PRIOR embedded peer record is never deleted on management, so the +// account accumulates a stale peer holding a stale CGNAT IP. Other peers +// in the account either keep routing to the dead IP, or — if synth DNS +// picks the wrong record — never see the new IP at all. +// +// What this test exercises (no mocks): +// - real SQLite test store +// - real DefaultAccountManager, network-map controller, peer-update channels +// - real peers.Manager.CreateProxyPeer path (the very method the proxy +// invokes over gRPC on every startup) +// - real agentnetwork.Manager + synth chain so the client receives a +// concrete DNS record that must point at the LATEST proxy peer. +// +// Pre-fix expected behavior (red): two embedded peers exist after the +// "restart"; the synth DNS record points at the stale one; the client +// receives an update reflecting the new peer but the old one lingers. +// Post-fix expected behavior (green): exactly one embedded peer exists +// after restart (with the new key) AND the client's network map carries +// the synth DNS pointing at that new peer's CGNAT IP. +func TestAgentNetwork_ProxyRestart_PropagatesNewPeerAndDropsStale(t *testing.T) { + am, updateManager, err := createManager(t) + require.NoError(t, err, "createManager must succeed") + ctx := context.Background() + + const ( + accountID = "an-restart-acct" + adminUserID = "an-restart-admin" + groupAID = "an-restart-grp-A" + clusterAddr = "eu.proxy.netbird.io" + clientKey = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=" + // Two different proxy pubkeys — the "before" and "after" of a + // proxy-process restart with fresh-keypair generation. + proxyKey1 = "Aaaaa1aaaaYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=" + proxyKey2 = "Bbbbb2bbbbYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=" + ) + + // --- Account scaffold --- + account := newAccountWithId(ctx, accountID, adminUserID, "an-restart.test", "", "", false) + require.NoError(t, am.Store.SaveAccount(ctx, account)) + + clientPeer := &nbpeer.Peer{ + Key: clientKey, + Name: "an-restart-client", + DNSLabel: "an-restart-client", + Meta: nbpeer.PeerSystemMeta{Hostname: "an-restart-client", GoOS: "linux", WtVersion: "development"}, + } + addedClient, _, _, _, err := am.AddPeer(ctx, "", "", adminUserID, clientPeer, false) + require.NoError(t, err, "AddPeer for client must succeed") + require.NoError(t, am.MarkPeerConnected(ctx, clientKey, accountID, time.Now().UnixNano(), &types.NetworkMap{}), + "MarkPeerConnected for the client peer must succeed (affected-peer fan-out skips disconnected peers)") + + // Place the client in group A so the synth policy reaches it. + account, err = am.Store.GetAccount(ctx, accountID) + require.NoError(t, err) + account.Groups[groupAID] = &types.Group{ID: groupAID, Name: "groupA", Peers: []string{addedClient.ID}} + require.NoError(t, am.Store.SaveAccount(ctx, account), "SaveAccount must persist group A") + + // --- Real peers + agent-network managers --- + permMgr := permissions.NewManager(am.Store) + peersMgr := peers.NewManager(am.Store, permMgr) + peersMgr.SetAccountManager(am) + peersMgr.SetNetworkMapController(am.networkMapController) + agentMgr := agentnetwork.NewManager(am.Store, permMgr, am, nil) + + // Subscribe BEFORE any state-mutating call so we don't lose the update + // that contains the synth DNS record. + clientCh := updateManager.CreateChannel(ctx, addedClient.ID) + t.Cleanup(func() { updateManager.CloseChannel(ctx, addedClient.ID) }) + drain(clientCh) + + // --- First proxy startup: register peer key K1, then mark it + // connected. In production the proxy follows CreateProxyPeer with the + // regular sync stream which lands on MarkPeerConnected; the synth DNS + // path filters out peers that aren't Connected (types/account.go:323), + // so without this step no DNS record would be emitted. + require.NoError(t, peersMgr.CreateProxyPeer(ctx, accountID, proxyKey1, clusterAddr), + "first CreateProxyPeer (proxy startup) must succeed") + + peer1ID, err := am.Store.GetPeerIDByKey(ctx, store.LockingStrengthNone, proxyKey1) + require.NoError(t, err, "proxy peer for K1 must be persisted after CreateProxyPeer") + require.NotEmpty(t, peer1ID) + + require.NoError(t, am.MarkPeerConnected(ctx, proxyKey1, accountID, time.Now().UnixNano(), &types.NetworkMap{}), + "MarkPeerConnected for K1 must succeed") + + account, err = am.Store.GetAccount(ctx, accountID) + require.NoError(t, err) + proxyIP1 := account.Peers[peer1ID].IP.String() + require.NotEmpty(t, proxyIP1, "K1 must have an assigned overlay IP") + + // --- Provider + policy. CreateProvider / CreatePolicy trigger the + // agentnetwork reconcile which runs UpdateAccountPeers; the resulting + // NetworkMap delivered to the client carries the synth DNS record + // pointing at K1's IP. --- + provider, err := agentMgr.CreateProvider(ctx, adminUserID, &agenttypes.Provider{ + AccountID: accountID, + ProviderID: "openai_api", + Name: "openai-test", + UpstreamURL: "https://api.openai.com", + APIKey: "sk-test-key", + Enabled: true, + Models: []agenttypes.ProviderModel{{ID: "gpt-5.4"}}, + }, clusterAddr) + require.NoError(t, err, "CreateProvider must succeed") + + _, err = agentMgr.CreatePolicy(ctx, adminUserID, &agenttypes.Policy{ + AccountID: accountID, + Name: "p1", + Enabled: true, + SourceGroups: []string{groupAID}, + DestinationProviderIDs: []string{provider.ID}, + }) + require.NoError(t, err, "CreatePolicy must succeed") + + settings, err := am.Store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID) + require.NoError(t, err) + fqdn := settings.Endpoint() + + rdata1 := awaitZoneRData(clientCh, clusterAddr, fqdn, true) + require.Equal(t, proxyIP1, rdata1, + "client must receive a synth DNS record pointing at K1's overlay IP after the synth path runs") + drain(clientCh) + + // --- Proxy restart: NEW keypair K2, same account, same cluster --- + require.NoError(t, peersMgr.CreateProxyPeer(ctx, accountID, proxyKey2, clusterAddr), + "second CreateProxyPeer (proxy restart with fresh keypair) must succeed") + + peer2ID, err := am.Store.GetPeerIDByKey(ctx, store.LockingStrengthNone, proxyKey2) + require.NoError(t, err, "proxy peer for K2 must be persisted after restart") + require.NotEmpty(t, peer2ID) + + require.NoError(t, am.MarkPeerConnected(ctx, proxyKey2, accountID, time.Now().UnixNano(), &types.NetworkMap{}), + "MarkPeerConnected for K2 must succeed") + + // In production the agent's sync stream pulls a fresh NetworkMap as + // part of its normal reconcile cadence; in this isolated test + // MarkPeerConnected's affected-peer fan-out can race the channel-side + // buffer in a way that swallows the synth-DNS-bearing update before + // our await reads it. Trigger an explicit account-wide fan-out so the + // assertion below tests what production actually delivers, not the + // in-test buffer race. + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate}) + + account, err = am.Store.GetAccount(ctx, accountID) + require.NoError(t, err) + proxyIP2 := account.Peers[peer2ID].IP.String() + require.NotEmpty(t, proxyIP2, "K2 must have an assigned overlay IP") + require.NotEqual(t, proxyIP1, proxyIP2, "K2 must get a different overlay IP than K1 (sanity)") + + // CRITICAL ASSERTION 1: K1 must no longer be in the store. The SqlStore + // returns ("", nil) for a missing key rather than NotFound, so assert + // on the returned ID being empty. + staleID, err := am.Store.GetPeerIDByKey(ctx, store.LockingStrengthNone, proxyKey1) + require.NoError(t, err, "GetPeerIDByKey for a missing peer must not error") + assert.Empty(t, staleID, + "stale embedded proxy peer K1 must be removed when a new embedded peer registers for the same (account, cluster); pre-fix this assertion fails because management never cleans up the prior peer record") + + // CRITICAL ASSERTION 2: exactly one embedded proxy peer remains, and it + // is K2. + account, err = am.Store.GetAccount(ctx, accountID) + require.NoError(t, err) + embeddedKeys := []string{} + for _, p := range account.Peers { + if p.ProxyMeta.Embedded { + embeddedKeys = append(embeddedKeys, p.Key) + } + } + assert.Equal(t, []string{proxyKey2}, embeddedKeys, + "after a proxy restart exactly one embedded proxy peer should remain — the one with the new key K2") + + // CRITICAL ASSERTION 3: the synth DNS record the client receives now + // points at K2's IP, not K1's. + rdata2 := awaitZoneRData(clientCh, clusterAddr, fqdn, true) + assert.Equal(t, proxyIP2, rdata2, + "after proxy restart, the client's synth DNS record must point at the NEW embedded peer's IP, not the stale K1 IP") +} diff --git a/management/server/agentnetwork_realstack_test.go b/management/server/agentnetwork_realstack_test.go new file mode 100644 index 000000000..e7855c575 --- /dev/null +++ b/management/server/agentnetwork_realstack_test.go @@ -0,0 +1,212 @@ +package server + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + networkmap "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork" + agenttypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + nbproto "github.com/netbirdio/netbird/shared/management/proto" +) + +// TestAgentNetwork_ProviderCRUD_FansOutToProxyAndClientPeers is the no-mock +// integration test for the live propagation path: a provider/policy mutation +// through the real agentnetwork.Manager triggers the real +// DefaultAccountManager.UpdateAccountPeers, which runs the real network-map +// controller (including AN-2b's injectAllProxyPolicies), and a network map is +// computed and fanned out to BOTH the embedded proxy peer and the client peer. +// +// Unlike the synthesizer/reconcile unit tests, nothing here is mocked: real +// SQLite store, real account manager + network-map controller, real +// agentnetwork manager, real peer update channels. The client peer's delivered +// map is asserted to actually carry the synth DNS surface, and provider +// create/delete are exercised end to end. +func TestAgentNetwork_ProviderCRUD_FansOutToProxyAndClientPeers(t *testing.T) { + am, updateManager, err := createManager(t) + require.NoError(t, err, "createManager must succeed") + ctx := context.Background() + + const ( + accountID = "agent-net-acct-1" + adminUserID = "agent-net-admin-1" + groupAID = "agent-net-grp-A" + clusterAddr = "eu.proxy.netbird.io" + clientKey = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=" + proxyPeerID = "agent-net-proxy-peer-1" + proxyPeerKey = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI=" + proxyIP = "100.64.0.99" + ) + + account := newAccountWithId(ctx, accountID, adminUserID, "agent-net.test", "", "", false) + require.NoError(t, am.Store.SaveAccount(ctx, account), "SaveAccount must succeed") + + // Real client peer through the production AddPeer path. + clientPeer := &nbpeer.Peer{ + Key: clientKey, + Name: "agent-net-client", + DNSLabel: "agent-net-client", + Meta: nbpeer.PeerSystemMeta{Hostname: "agent-net-client", GoOS: "linux", WtVersion: "development"}, + } + addedClient, _, _, _, err := am.AddPeer(ctx, "", "", adminUserID, clientPeer, false) + require.NoError(t, err, "AddPeer must add the client peer") + + // Inject a connected embedded proxy peer + put the client in the source group. + account, err = am.Store.GetAccount(ctx, accountID) + require.NoError(t, err) + account.Peers[proxyPeerID] = &nbpeer.Peer{ + ID: proxyPeerID, + AccountID: accountID, + Key: proxyPeerKey, + IP: netip.MustParseAddr(proxyIP), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: clusterAddr}, + DNSLabel: "agent-net-proxy", + } + account.Groups[groupAID] = &types.Group{ID: groupAID, Name: "groupA", Peers: []string{addedClient.ID}} + require.NoError(t, am.Store.SaveAccount(ctx, account), "SaveAccount must persist proxy peer + group") + + // Subscribe to BOTH peers' update channels — this is how we observe the + // real fan-out. + clientCh := updateManager.CreateChannel(ctx, addedClient.ID) + proxyCh := updateManager.CreateChannel(ctx, proxyPeerID) + t.Cleanup(func() { + updateManager.CloseChannel(ctx, addedClient.ID) + updateManager.CloseChannel(ctx, proxyPeerID) + }) + drain(clientCh) + drain(proxyCh) + + // Real agentnetwork manager wired to the real account manager. proxyController + // is nil (no gRPC cluster fan-out here) — the reconcile still fires + // UpdateAccountPeers, which is the path under test. + agentMgr := agentnetwork.NewManager(am.Store, permissions.NewManager(am.Store), am, nil) + + provider, err := agentMgr.CreateProvider(ctx, adminUserID, &agenttypes.Provider{ + AccountID: accountID, + ProviderID: "openai_api", + Name: "openai-test", + UpstreamURL: "https://api.openai.com", + APIKey: "sk-test-key", + Enabled: true, + Models: []agenttypes.ProviderModel{{ID: "gpt-5.4"}}, + }, clusterAddr) + require.NoError(t, err, "CreateProvider must succeed") + + policy, err := agentMgr.CreatePolicy(ctx, adminUserID, &agenttypes.Policy{ + AccountID: accountID, + Name: "p1", + Enabled: true, + SourceGroups: []string{groupAID}, + DestinationProviderIDs: []string{provider.ID}, + }) + require.NoError(t, err, "CreatePolicy must succeed") + + settings, err := am.Store.GetAgentNetworkSettings(ctx, store.LockingStrengthNone, accountID) + require.NoError(t, err) + fqdn := settings.Endpoint() + + // Both peers must receive a fan-out. The provider-create reconcile fires + // before the policy exists (synth service then has no AccessGroups, so no + // zone), and the async update buffer can collapse/reorder updates — so we + // poll until the client's delivered map actually carries the synth record. + rdata := awaitZoneRData(clientCh, clusterAddr, fqdn, true) + assert.Equal(t, proxyIP, rdata, + "client peer's delivered network map must contain the synth DNS record pointing at the embedded proxy peer") + require.True(t, awaitUpdate(proxyCh), "embedded proxy peer must also receive a netmap update after create") + + // UPDATE the provider — a new model on the existing service must still + // reconcile and keep the private surface routable (the live MODIFIED path). + provider.Models = append(provider.Models, agenttypes.ProviderModel{ID: "gpt-5.4-mini"}) + _, err = agentMgr.UpdateProvider(ctx, adminUserID, provider) + require.NoError(t, err, "UpdateProvider must succeed") + assert.Equal(t, proxyIP, awaitZoneRData(clientCh, clusterAddr, fqdn, true), + "client peer must still resolve the synth record after the provider is updated") + require.True(t, awaitUpdate(proxyCh), "embedded proxy peer must also receive a netmap update after update") + + // DELETE: detach the policy first (provider is in use), then drop the + // provider. Both peers update again and the synth surface disappears. + require.NoError(t, agentMgr.DeletePolicy(ctx, accountID, adminUserID, policy.ID), "DeletePolicy must succeed") + require.NoError(t, agentMgr.DeleteProvider(ctx, accountID, adminUserID, provider.ID), "DeleteProvider must succeed") + + require.True(t, awaitUpdate(proxyCh), "embedded proxy peer must also receive a netmap update after delete") + assert.Empty(t, awaitZoneRData(clientCh, clusterAddr, fqdn, false), + "synth DNS record must be gone from the client's map after the provider is deleted") +} + +// awaitZoneRData drains the channel for up to 8s. When wantPresent is true it +// returns as soon as the synth record appears (its RData). When false it drains +// to quiescence and returns the RData of the last delivered map (expected empty +// once the provider is gone), tolerating stale buffered updates that still +// carry the zone. +func awaitZoneRData(ch <-chan *networkmap.UpdateMessage, clusterAddr, fqdn string, wantPresent bool) string { + deadline := time.After(8 * time.Second) + last := "" + for { + select { + case m := <-ch: + if m == nil { + continue + } + last = synthZoneRData(m.Update, clusterAddr, fqdn) + if wantPresent && last != "" { + return last + } + case <-time.After(750 * time.Millisecond): + return last + case <-deadline: + return last + } + } +} + +// awaitUpdate reports whether at least one update arrives within the window. +func awaitUpdate(ch <-chan *networkmap.UpdateMessage) bool { + select { + case m := <-ch: + return m != nil + case <-time.After(5 * time.Second): + return false + } +} + +// drain empties any buffered updates (e.g. from AddPeer/SaveAccount) so the +// next observation reflects the operation under test. +func drain(ch <-chan *networkmap.UpdateMessage) { + for { + select { + case <-ch: + case <-time.After(200 * time.Millisecond): + return + } + } +} + +// synthZoneRData returns the RData of the synth A record (record name == fqdn) +// inside the cluster's custom zone, or "" when absent. +func synthZoneRData(sync *nbproto.SyncResponse, clusterAddr, fqdn string) string { + if sync == nil { + return "" + } + for _, zone := range sync.GetNetworkMap().GetDNSConfig().GetCustomZones() { + if zone.GetDomain() != dns.Fqdn(clusterAddr) { + continue + } + for _, rec := range zone.GetRecords() { + if rec.GetName() == dns.Fqdn(fqdn) { + return rec.GetRData() + } + } + } + return "" +} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 0abdb854d..a57f44b3c 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -23,6 +23,8 @@ import ( idpmanager "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork" + agentnetworkhandlers "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/handlers" "github.com/netbirdio/netbird/management/internals/modules/zones" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" "github.com/netbirdio/netbird/management/internals/modules/zones/records" @@ -59,7 +61,7 @@ import ( ) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc, agentNetworkManager agentnetwork.Manager) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -124,6 +126,9 @@ func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager accou zonesManager.RegisterEndpoints(router, zManager) recordsManager.RegisterEndpoints(router, rManager) idp.AddEndpoints(accountManager, router) + if agentNetworkManager != nil { + agentnetworkhandlers.RegisterEndpoints(agentNetworkManager, router) + } instance.AddEndpoints(instanceManager, accountManager, router) instance.AddVersionEndpoint(instanceManager, router) if serviceManager != nil && reverseProxyDomainManager != nil { diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 61584a615..8b05b2ddf 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -137,7 +137,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter() - apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -267,7 +267,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter() - apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index efe50c88f..43fbec15d 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -55,6 +55,7 @@ type DataSource interface { GetStoreEngine() types.Engine GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) GetProxyMetrics(ctx context.Context) (store.ProxyMetrics, error) + GetAgentNetworkMetrics(ctx context.Context) (store.AgentNetworkMetrics, error) } // ConnManager peer connection manager that holds state for current active connections @@ -413,6 +414,13 @@ func (w *Worker) generateProperties(ctx context.Context) properties { log.WithContext(ctx).Debugf("collect proxy metrics: %v", err) } + // Agent-network adoption + usage, aggregated across all accounts in a few + // cheap queries; nil on FileStore. + agentNetworkMetrics, err := w.dataSource.GetAgentNetworkMetrics(ctx) + if err != nil { + log.WithContext(ctx).Debugf("collect agent network metrics: %v", err) + } + minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions) metricsProperties["uptime"] = uptime metricsProperties["accounts"] = accounts @@ -471,6 +479,14 @@ func (w *Worker) generateProperties(ctx context.Context) properties { metricsProperties["proxies_connected"] = proxyMetrics.ProxiesConnected metricsProperties["custom_domains"] = customDomains metricsProperties["custom_domains_validated"] = customDomainsValidated + metricsProperties["agent_network_accounts"] = agentNetworkMetrics.Accounts + metricsProperties["agent_network_providers"] = agentNetworkMetrics.Providers + metricsProperties["agent_network_policies"] = agentNetworkMetrics.Policies + metricsProperties["agent_network_budget_rules"] = agentNetworkMetrics.BudgetRules + metricsProperties["agent_network_log_collection_enabled"] = agentNetworkMetrics.LogCollectionEnabled + metricsProperties["agent_network_input_tokens"] = agentNetworkMetrics.InputTokens + metricsProperties["agent_network_output_tokens"] = agentNetworkMetrics.OutputTokens + metricsProperties["agent_network_cost_usd"] = agentNetworkMetrics.CostUSD for targetType, count := range servicesTargetType { metricsProperties["services_target_type_"+string(targetType)] = count diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index ca9e10262..1fef89a2c 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -277,6 +277,21 @@ func (mockDatasource) GetProxyMetrics(_ context.Context) (store.ProxyMetrics, er }, nil } +// GetAgentNetworkMetrics returns canned agent-network counts so the +// generateProperties test can assert the adoption/usage signals end-to-end. +func (mockDatasource) GetAgentNetworkMetrics(_ context.Context) (store.AgentNetworkMetrics, error) { + return store.AgentNetworkMetrics{ + Accounts: 2, + Providers: 5, + Policies: 3, + BudgetRules: 1, + LogCollectionEnabled: 2, + InputTokens: 1000, + OutputTokens: 500, + CostUSD: 1.25, + }, nil +} + // TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties func TestGenerateProperties(t *testing.T) { ds := mockDatasource{} diff --git a/management/server/permissions/modules/module.go b/management/server/permissions/modules/module.go index 93007d4c1..a3a9c554d 100644 --- a/management/server/permissions/modules/module.go +++ b/management/server/permissions/modules/module.go @@ -19,6 +19,7 @@ const ( Pats Module = "pats" IdentityProviders Module = "identity_providers" Services Module = "services" + AgentNetwork Module = "agent_network" ) var All = map[Module]struct{}{ @@ -38,4 +39,5 @@ var All = map[Module]struct{}{ Pats: {}, IdentityProviders: {}, Services: {}, + AgentNetwork: {}, } diff --git a/management/server/store/file_store.go b/management/server/store/file_store.go index bcf563cd0..a776a8b42 100644 --- a/management/server/store/file_store.go +++ b/management/server/store/file_store.go @@ -280,3 +280,9 @@ func (s *FileStore) GetCustomDomainsCounts(_ context.Context) (int64, int64, err func (s *FileStore) GetProxyMetrics(_ context.Context) (ProxyMetrics, error) { return ProxyMetrics{}, nil } + +// GetAgentNetworkMetrics is a no-op for FileStore — agent-network state isn't +// persisted in the JSON file format. +func (s *FileStore) GetAgentNetworkMetrics(_ context.Context) (AgentNetworkMetrics, error) { + return AgentNetworkMetrics{}, nil +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 8bc4bcd7d..18be1b6ed 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -33,6 +33,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" @@ -137,6 +138,10 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{}, &accesslogs.AccessLogEntry{}, &proxy.Proxy{}, + &agentNetworkTypes.Provider{}, &agentNetworkTypes.Policy{}, &agentNetworkTypes.Guardrail{}, &agentNetworkTypes.Settings{}, + &agentNetworkTypes.Consumption{}, &agentNetworkTypes.AccountBudgetRule{}, + &agentNetworkTypes.AgentNetworkAccessLog{}, &agentNetworkTypes.AgentNetworkAccessLogGroup{}, + &agentNetworkTypes.AgentNetworkUsage{}, &agentNetworkTypes.AgentNetworkUsageGroup{}, ) if err != nil { return nil, fmt.Errorf("auto migratePreAuto: %w", err) @@ -5573,6 +5578,340 @@ func (s *SqlStore) CreateAccessLog(ctx context.Context, logEntry *accesslogs.Acc return nil } +// CreateAgentNetworkAccessLog persists a flattened agent-network access-log +// entry together with its authorising-group child rows in a single +// transaction. +func (s *SqlStore) CreateAgentNetworkAccessLog(ctx context.Context, entry *agentNetworkTypes.AgentNetworkAccessLog, groups []agentNetworkTypes.AgentNetworkAccessLogGroup) error { + err := s.db.Transaction(func(tx *gorm.DB) error { + // Idempotent on the log id / (log_id, group_id) so a proxy resend of the + // same entry can't fail the request. + if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(entry).Error; err != nil { + return err + } + if len(groups) > 0 { + if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&groups).Error; err != nil { + return err + } + } + return nil + }) + if err != nil { + log.WithContext(ctx).WithFields(log.Fields{ + "account_id": entry.AccountID, + "service_id": entry.ServiceID, + "model": entry.Model, + }).Errorf("failed to create agent-network access log entry in store: %v", err) + return status.Errorf(status.Internal, "failed to create agent-network access log entry in store") + } + return nil +} + +// CreateAgentNetworkUsage persists a stripped agent-network usage record +// together with its authorising-group child rows in a single transaction. +func (s *SqlStore) CreateAgentNetworkUsage(ctx context.Context, usage *agentNetworkTypes.AgentNetworkUsage, groups []agentNetworkTypes.AgentNetworkUsageGroup) error { + err := s.db.Transaction(func(tx *gorm.DB) error { + // Idempotent on the usage id / (usage_id, group_id) so a proxy resend of + // the same entry can't fail the request. + if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(usage).Error; err != nil { + return err + } + if len(groups) > 0 { + if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&groups).Error; err != nil { + return err + } + } + return nil + }) + if err != nil { + log.WithContext(ctx).WithFields(log.Fields{ + "account_id": usage.AccountID, + "model": usage.Model, + }).Errorf("failed to create agent-network usage record in store: %v", err) + return status.Errorf(status.Internal, "failed to create agent-network usage record in store") + } + return nil +} + +// DeleteOldAgentNetworkAccessLogs deletes an account's access-log rows (and +// their authorising-group child rows) older than the cutoff. Usage records are +// untouched — they are the long-term aggregate. Returns the number of log rows +// deleted. +func (s *SqlStore) DeleteOldAgentNetworkAccessLogs(ctx context.Context, accountID string, olderThan time.Time) (int64, error) { + var deleted int64 + err := s.db.Transaction(func(tx *gorm.DB) error { + // Remove group child rows for the soon-to-be-deleted logs first. + if err := tx.Exec( + "DELETE FROM agent_network_access_log_group WHERE account_id = ? AND log_id IN (SELECT id FROM agent_network_access_log WHERE account_id = ? AND timestamp < ?)", + accountID, accountID, olderThan, + ).Error; err != nil { + return err + } + res := tx.Where("account_id = ? AND timestamp < ?", accountID, olderThan). + Delete(&agentNetworkTypes.AgentNetworkAccessLog{}) + if res.Error != nil { + return res.Error + } + deleted = res.RowsAffected + return nil + }) + if err != nil { + log.WithContext(ctx).Errorf("failed to delete old agent-network access logs for account %s: %v", accountID, err) + return 0, status.Errorf(status.Internal, "failed to delete old agent-network access logs") + } + return deleted, nil +} + +// GetAgentNetworkUsageRows returns the stripped usage rows for an account that +// match the filter (date / user / group / provider / model). Aggregation into +// time buckets happens in the manager so granularities stay engine-portable. +func (s *SqlStore) GetAgentNetworkUsageRows(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkUsage, error) { + var rows []*agentNetworkTypes.AgentNetworkUsage + + query := s.applyAgentNetworkUsageFilters( + s.db.Where(accountIDCondition, accountID), + filter, + ).Order("timestamp ASC") + + if lockStrength != LockingStrengthNone { + query = query.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + if err := query.Find(&rows).Error; err != nil { + log.WithContext(ctx).Errorf("failed to get agent-network usage rows from store: %v", err) + return nil, status.Errorf(status.Internal, "failed to get agent-network usage rows from store") + } + return rows, nil +} + +// applyAgentNetworkUsageFilters applies the shared access-log filter's +// date/user/group/provider/model conditions to a usage-table query. Pagination, +// sort and free-text search are ignored — the overview is an aggregate. +func (s *SqlStore) applyAgentNetworkUsageFilters(query *gorm.DB, filter agentNetworkTypes.AgentNetworkAccessLogFilter) *gorm.DB { + if filter.UserID != nil { + query = query.Where("user_id = ?", *filter.UserID) + } + if filter.SessionID != nil { + query = query.Where("session_id = ?", *filter.SessionID) + } + if len(filter.ProviderIDs) > 0 { + query = query.Where("resolved_provider_id IN ?", filter.ProviderIDs) + } + if len(filter.Models) > 0 { + query = query.Where("model IN ?", filter.Models) + } + if len(filter.GroupIDs) > 0 { + query = query.Where( + "id IN (SELECT usage_id FROM agent_network_request_usage_group WHERE group_id IN ?)", + filter.GroupIDs, + ) + } + if filter.StartDate != nil { + query = query.Where("timestamp >= ?", *filter.StartDate) + } + if filter.EndDate != nil { + query = query.Where("timestamp <= ?", *filter.EndDate) + } + return query +} + +// GetAgentNetworkAccessLogs retrieves flattened agent-network access logs for +// an account with server-side pagination, filtering and sorting. Authorising +// group ids are hydrated from the group child table for the returned page. +func (s *SqlStore) GetAgentNetworkAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkAccessLog, int64, error) { + var logs []*agentNetworkTypes.AgentNetworkAccessLog + var totalCount int64 + + countQuery := s.applyAgentNetworkAccessLogFilters( + s.db.Model(&agentNetworkTypes.AgentNetworkAccessLog{}).Where(accountIDCondition, accountID), + filter, + ) + if err := countQuery.Count(&totalCount).Error; err != nil { + log.WithContext(ctx).Errorf("failed to count agent-network access logs: %v", err) + return nil, 0, status.Errorf(status.Internal, "failed to count agent-network access logs") + } + + query := s.applyAgentNetworkAccessLogFilters( + s.db.Where(accountIDCondition, accountID), + filter, + ). + Order(filter.GetSortColumn() + " " + filter.GetSortOrder()). + Limit(filter.GetLimit()). + Offset(filter.GetOffset()) + + if lockStrength != LockingStrengthNone { + query = query.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + if err := query.Find(&logs).Error; err != nil { + log.WithContext(ctx).Errorf("failed to get agent-network access logs from store: %v", err) + return nil, 0, status.Errorf(status.Internal, "failed to get agent-network access logs from store") + } + + if err := s.hydrateAgentNetworkAccessLogGroups(ctx, accountID, logs); err != nil { + return nil, 0, err + } + + return logs, totalCount, nil +} + +// applyAgentNetworkAccessLogFilters applies the filter conditions to a query. +func (s *SqlStore) applyAgentNetworkAccessLogFilters(query *gorm.DB, filter agentNetworkTypes.AgentNetworkAccessLogFilter) *gorm.DB { + if filter.Search != nil { + p := "%" + *filter.Search + "%" + query = query.Where( + "id LIKE ? OR host LIKE ? OR path LIKE ? OR model LIKE ? OR user_id IN (SELECT id FROM users WHERE email LIKE ? OR name LIKE ?)", + p, p, p, p, p, p, + ) + } + if filter.UserID != nil { + query = query.Where("user_id = ?", *filter.UserID) + } + if filter.SessionID != nil { + query = query.Where("session_id = ?", *filter.SessionID) + } + if filter.Decision != nil { + query = query.Where("decision = ?", *filter.Decision) + } + if filter.PathPrefix != nil { + query = query.Where("path LIKE ?", *filter.PathPrefix+"%") + } + if len(filter.ProviderIDs) > 0 { + query = query.Where("resolved_provider_id IN ?", filter.ProviderIDs) + } + if len(filter.Models) > 0 { + query = query.Where("model IN ?", filter.Models) + } + if len(filter.GroupIDs) > 0 { + query = query.Where( + "id IN (SELECT log_id FROM agent_network_access_log_group WHERE group_id IN ?)", + filter.GroupIDs, + ) + } + if filter.StartDate != nil { + query = query.Where("timestamp >= ?", *filter.StartDate) + } + if filter.EndDate != nil { + query = query.Where("timestamp <= ?", *filter.EndDate) + } + return query +} + +// hydrateAgentNetworkAccessLogGroups loads the authorising group ids for the +// given page of entries and assigns them onto each entry's GroupIDs field. +func (s *SqlStore) hydrateAgentNetworkAccessLogGroups(ctx context.Context, accountID string, logs []*agentNetworkTypes.AgentNetworkAccessLog) error { + if len(logs) == 0 { + return nil + } + + ids := make([]string, 0, len(logs)) + for _, l := range logs { + ids = append(ids, l.ID) + } + + var rows []agentNetworkTypes.AgentNetworkAccessLogGroup + if err := s.db. + Where(accountIDCondition, accountID). + Where("log_id IN ?", ids). + Find(&rows).Error; err != nil { + log.WithContext(ctx).Errorf("failed to hydrate agent-network access log groups: %v", err) + return status.Errorf(status.Internal, "failed to hydrate agent-network access log groups") + } + + byLog := make(map[string][]string, len(logs)) + for _, r := range rows { + byLog[r.LogID] = append(byLog[r.LogID], r.GroupID) + } + for _, l := range logs { + l.GroupIDs = byLog[l.ID] + } + return nil +} + +// agentNetworkSessionKeyExpr is the SQL group key for session-grouped access +// logs: the row's session id, or — when the client sent none — the row id, so +// session-less requests each form their own singleton group. COALESCE/NULLIF +// are standard SQL, so this stays portable across SQLite and Postgres. +const agentNetworkSessionKeyExpr = "COALESCE(NULLIF(session_id, ''), id)" + +// GetAgentNetworkAccessLogSessions retrieves agent-network access logs grouped +// by session, with server-side pagination, filtering and sorting at the session +// level. It paginates over the distinct session keys (ordered by the requested +// session-level aggregate), fetches every entry for the page's sessions, and +// folds them into per-session summaries. The returned count is the number of +// matching sessions. Filters apply to the entries, so a session's summary +// reflects only its filter-matching requests. +func (s *SqlStore) GetAgentNetworkAccessLogSessions(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkAccessLogSession, int64, error) { + // Count distinct sessions via a grouped subquery — portable and avoids + // relying on COUNT(DISTINCT ) quoting quirks. + sessionsSubquery := s.applyAgentNetworkAccessLogFilters( + s.db.Model(&agentNetworkTypes.AgentNetworkAccessLog{}).Where(accountIDCondition, accountID), + filter, + ). + Select(agentNetworkSessionKeyExpr + " AS session_key"). + Group(agentNetworkSessionKeyExpr) + + var totalCount int64 + if err := s.db.Table("(?) AS sessions", sessionsSubquery).Count(&totalCount).Error; err != nil { + log.WithContext(ctx).Errorf("failed to count agent-network access-log sessions: %v", err) + return nil, 0, status.Errorf(status.Internal, "failed to count agent-network access-log sessions") + } + + // The page of session keys, ordered by the session-level aggregate. The + // session-key tiebreaker keeps pagination deterministic when the primary + // aggregate ties. + type sessionKeyRow struct { + SessionKey string + } + var keyRows []sessionKeyRow + keyQuery := s.applyAgentNetworkAccessLogFilters( + s.db.Model(&agentNetworkTypes.AgentNetworkAccessLog{}).Where(accountIDCondition, accountID), + filter, + ). + Select(agentNetworkSessionKeyExpr + " AS session_key"). + Group(agentNetworkSessionKeyExpr). + Order(filter.GetSessionSortExpr() + " " + filter.GetSortOrder()). + Order("session_key ASC"). + Limit(filter.GetLimit()). + Offset(filter.GetOffset()) + if err := keyQuery.Scan(&keyRows).Error; err != nil { + log.WithContext(ctx).Errorf("failed to list agent-network access-log session keys: %v", err) + return nil, 0, status.Errorf(status.Internal, "failed to list agent-network access-log session keys") + } + if len(keyRows) == 0 { + return nil, totalCount, nil + } + + keys := make([]string, 0, len(keyRows)) + for _, r := range keyRows { + keys = append(keys, r.SessionKey) + } + + // All entries for the page's sessions, contiguous per session and oldest + // first within each — the fold relies on that ordering. + var entries []*agentNetworkTypes.AgentNetworkAccessLog + entriesQuery := s.applyAgentNetworkAccessLogFilters( + s.db.Where(accountIDCondition, accountID), + filter, + ). + Where(agentNetworkSessionKeyExpr+" IN ?", keys). + Order(agentNetworkSessionKeyExpr + ", timestamp ASC") + + if lockStrength != LockingStrengthNone { + entriesQuery = entriesQuery.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + if err := entriesQuery.Find(&entries).Error; err != nil { + log.WithContext(ctx).Errorf("failed to get agent-network access-log session entries: %v", err) + return nil, 0, status.Errorf(status.Internal, "failed to get agent-network access-log session entries") + } + + if err := s.hydrateAgentNetworkAccessLogGroups(ctx, accountID, entries); err != nil { + return nil, 0, err + } + + return agentNetworkTypes.FoldAccessLogSessions(keys, entries), totalCount, nil +} + // GetAccountAccessLogs retrieves access logs for a given account with pagination and filtering func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) { var logs []*accesslogs.AccessLogEntry diff --git a/management/server/store/sql_store_agentnetwork.go b/management/server/store/sql_store_agentnetwork.go new file mode 100644 index 000000000..b0df0cd2a --- /dev/null +++ b/management/server/store/sql_store_agentnetwork.go @@ -0,0 +1,664 @@ +package store + +import ( + "context" + "errors" + "fmt" + "math" + "time" + + log "github.com/sirupsen/logrus" + "gorm.io/gorm" + "gorm.io/gorm/clause" + + agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/shared/management/status" +) + +// GetAllAgentNetworkProviders returns Agent Network providers across +// every account. Used by the synthesizer to build the global service map. +func (s *SqlStore) GetAllAgentNetworkProviders(ctx context.Context, lockStrength LockingStrength) ([]*agentNetworkTypes.Provider, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var providers []*agentNetworkTypes.Provider + if result := tx.Find(&providers); result.Error != nil { + log.WithContext(ctx).Errorf("failed to get all agent network providers from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get all agent network providers from store") + } + + for _, provider := range providers { + if err := provider.DecryptSensitiveData(s.fieldEncrypt); err != nil { + log.WithContext(ctx).Errorf("failed to decrypt agent network provider %s: %v", provider.ID, err) + return nil, status.Errorf(status.Internal, "failed to decrypt agent network provider") + } + } + + return providers, nil +} + +// GetAgentNetworkMetrics returns aggregated agent-network adoption + usage +// counts for the self-hosted metrics worker. Each value is a single cheap +// aggregate; token/cost are summed over the always-collected per-request usage +// ledger (independent of the log-collection toggle) so they reflect real usage. +func (s *SqlStore) GetAgentNetworkMetrics(ctx context.Context) (AgentNetworkMetrics, error) { + var m AgentNetworkMetrics + db := s.db.WithContext(ctx) + + // Providers + distinct adopting accounts in one round-trip. + provRow := db.Model(&agentNetworkTypes.Provider{}). + Select("COUNT(*) AS providers, COUNT(DISTINCT account_id) AS accounts").Row() + if err := provRow.Scan(&m.Providers, &m.Accounts); err != nil { + return AgentNetworkMetrics{}, fmt.Errorf("scan agent network provider metrics: %w", err) + } + + if err := db.Model(&agentNetworkTypes.Policy{}).Count(&m.Policies).Error; err != nil { + return AgentNetworkMetrics{}, fmt.Errorf("count agent network policies: %w", err) + } + + if err := db.Model(&agentNetworkTypes.AccountBudgetRule{}).Count(&m.BudgetRules).Error; err != nil { + return AgentNetworkMetrics{}, fmt.Errorf("count agent network budget rules: %w", err) + } + + if err := db.Model(&agentNetworkTypes.Settings{}). + Where("enable_log_collection = ?", true).Count(&m.LogCollectionEnabled).Error; err != nil { + return AgentNetworkMetrics{}, fmt.Errorf("count agent network log-collection accounts: %w", err) + } + + // COALESCE so an empty ledger scans as 0 instead of NULL. + usageRow := db.Model(&agentNetworkTypes.AgentNetworkUsage{}). + Select("COALESCE(SUM(input_tokens), 0) AS input_tokens, " + + "COALESCE(SUM(output_tokens), 0) AS output_tokens, " + + "COALESCE(SUM(cost_usd), 0) AS cost_usd").Row() + if err := usageRow.Scan(&m.InputTokens, &m.OutputTokens, &m.CostUSD); err != nil { + return AgentNetworkMetrics{}, fmt.Errorf("scan agent network usage metrics: %w", err) + } + + return m, nil +} + +func (s *SqlStore) GetAccountAgentNetworkProviders(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Provider, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var providers []*agentNetworkTypes.Provider + result := tx.Find(&providers, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get agent network providers from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get agent network providers from store") + } + + for _, provider := range providers { + if err := provider.DecryptSensitiveData(s.fieldEncrypt); err != nil { + log.WithContext(ctx).Errorf("failed to decrypt agent network provider %s: %v", provider.ID, err) + return nil, status.Errorf(status.Internal, "failed to decrypt agent network provider") + } + } + + return providers, nil +} + +func (s *SqlStore) GetAgentNetworkProviderByID(ctx context.Context, lockStrength LockingStrength, accountID, providerID string) (*agentNetworkTypes.Provider, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var provider *agentNetworkTypes.Provider + result := tx.Take(&provider, accountAndIDQueryCondition, accountID, providerID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAgentNetworkProviderNotFoundError(providerID) + } + + log.WithContext(ctx).Errorf("failed to get agent network provider from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get agent network provider from store") + } + + if err := provider.DecryptSensitiveData(s.fieldEncrypt); err != nil { + log.WithContext(ctx).Errorf("failed to decrypt agent network provider %s: %v", provider.ID, err) + return nil, status.Errorf(status.Internal, "failed to decrypt agent network provider") + } + + return provider, nil +} + +func (s *SqlStore) SaveAgentNetworkProvider(ctx context.Context, provider *agentNetworkTypes.Provider) error { + providerCopy := provider.Copy() + if err := providerCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { + log.WithContext(ctx).Errorf("failed to encrypt agent network provider %s: %v", provider.ID, err) + return status.Errorf(status.Internal, "failed to encrypt agent network provider") + } + + result := s.db.Save(providerCopy) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save agent network provider to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save agent network provider to store") + } + + return nil +} + +func (s *SqlStore) DeleteAgentNetworkProvider(ctx context.Context, accountID, providerID string) error { + result := s.db.Delete(&agentNetworkTypes.Provider{}, accountAndIDQueryCondition, accountID, providerID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete agent network provider from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete agent network provider from store") + } + + if result.RowsAffected == 0 { + return status.NewAgentNetworkProviderNotFoundError(providerID) + } + + return nil +} + +func (s *SqlStore) GetAccountAgentNetworkPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Policy, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var policies []*agentNetworkTypes.Policy + result := tx.Find(&policies, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get agent network policies from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get agent network policies from store") + } + + return policies, nil +} + +func (s *SqlStore) GetAgentNetworkPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*agentNetworkTypes.Policy, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var policy *agentNetworkTypes.Policy + result := tx.Take(&policy, accountAndIDQueryCondition, accountID, policyID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAgentNetworkPolicyNotFoundError(policyID) + } + + log.WithContext(ctx).Errorf("failed to get agent network policy from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get agent network policy from store") + } + + return policy, nil +} + +func (s *SqlStore) SaveAgentNetworkPolicy(ctx context.Context, policy *agentNetworkTypes.Policy) error { + result := s.db.Save(policy) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save agent network policy to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save agent network policy to store") + } + + return nil +} + +func (s *SqlStore) DeleteAgentNetworkPolicy(ctx context.Context, accountID, policyID string) error { + result := s.db.Delete(&agentNetworkTypes.Policy{}, accountAndIDQueryCondition, accountID, policyID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete agent network policy from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete agent network policy from store") + } + + if result.RowsAffected == 0 { + return status.NewAgentNetworkPolicyNotFoundError(policyID) + } + + return nil +} + +func (s *SqlStore) GetAccountAgentNetworkGuardrails(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Guardrail, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var guardrails []*agentNetworkTypes.Guardrail + result := tx.Find(&guardrails, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get agent network guardrails from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get agent network guardrails from store") + } + + return guardrails, nil +} + +func (s *SqlStore) GetAgentNetworkGuardrailByID(ctx context.Context, lockStrength LockingStrength, accountID, guardrailID string) (*agentNetworkTypes.Guardrail, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var guardrail *agentNetworkTypes.Guardrail + result := tx.Take(&guardrail, accountAndIDQueryCondition, accountID, guardrailID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAgentNetworkGuardrailNotFoundError(guardrailID) + } + + log.WithContext(ctx).Errorf("failed to get agent network guardrail from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get agent network guardrail from store") + } + + return guardrail, nil +} + +func (s *SqlStore) SaveAgentNetworkGuardrail(ctx context.Context, guardrail *agentNetworkTypes.Guardrail) error { + result := s.db.Save(guardrail) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save agent network guardrail to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save agent network guardrail to store") + } + + return nil +} + +func (s *SqlStore) DeleteAgentNetworkGuardrail(ctx context.Context, accountID, guardrailID string) error { + result := s.db.Delete(&agentNetworkTypes.Guardrail{}, accountAndIDQueryCondition, accountID, guardrailID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete agent network guardrail from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete agent network guardrail from store") + } + + if result.RowsAffected == 0 { + return status.NewAgentNetworkGuardrailNotFoundError(guardrailID) + } + + return nil +} + +// GetAgentNetworkSettings returns the per-account Agent Network +// settings row. Returns status.NotFound when no row exists. +func (s *SqlStore) GetAgentNetworkSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*agentNetworkTypes.Settings, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var settings agentNetworkTypes.Settings + result := tx.Take(&settings, "account_id = ?", accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "agent network settings for account %s not found", accountID) + } + + log.WithContext(ctx).Errorf("failed to get agent network settings from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get agent network settings from store") + } + + return &settings, nil +} + +// GetAllAgentNetworkSettings returns every account's settings row. Used by the +// access-log retention sweep to learn each account's retention window. +func (s *SqlStore) GetAllAgentNetworkSettings(ctx context.Context, lockStrength LockingStrength) ([]*agentNetworkTypes.Settings, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var settings []*agentNetworkTypes.Settings + if err := tx.Find(&settings).Error; err != nil { + log.WithContext(ctx).Errorf("failed to list agent network settings: %v", err) + return nil, status.Errorf(status.Internal, "failed to list agent network settings") + } + return settings, nil +} + +// GetAgentNetworkSettingsByCluster returns every Settings row pinned to +// the given proxy cluster. Used by the bootstrap label generator to +// build the set of subdomains already taken on a cluster. +func (s *SqlStore) GetAgentNetworkSettingsByCluster(ctx context.Context, lockStrength LockingStrength, cluster string) ([]*agentNetworkTypes.Settings, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var settings []*agentNetworkTypes.Settings + result := tx.Find(&settings, "cluster = ?", cluster) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get agent network settings by cluster from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get agent network settings by cluster from store") + } + + return settings, nil +} + +// SaveAgentNetworkSettings upserts the per-account Agent Network +// settings row. +func (s *SqlStore) SaveAgentNetworkSettings(ctx context.Context, settings *agentNetworkTypes.Settings) error { + result := s.db.Save(settings) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save agent network settings to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save agent network settings to store") + } + + return nil +} + +// IncrementAgentNetworkConsumption atomically upserts the consumption +// row keyed on (account, dim_kind, dim_id, window_seconds, window_start) +// and adds the supplied deltas. Concurrent calls from multiple proxy +// nodes converge — the database performs the increment server-side via +// ON CONFLICT DO UPDATE so no read-modify-write race exists. +func (s *SqlStore) IncrementAgentNetworkConsumption( + ctx context.Context, + accountID string, + kind agentNetworkTypes.ConsumptionDimension, + dimID string, + windowSeconds int64, + windowStart time.Time, + tokensIn, tokensOut int64, + costUSD float64, +) error { + if accountID == "" || dimID == "" || windowSeconds <= 0 { + return status.Errorf(status.InvalidArgument, "account_id, dim_id and window_seconds must be set") + } + // Deltas are added server-side via ON CONFLICT; a negative or non-finite + // value would silently decrement / poison the persisted totals. + if tokensIn < 0 || tokensOut < 0 || costUSD < 0 || math.IsNaN(costUSD) || math.IsInf(costUSD, 0) { + return status.Errorf(status.InvalidArgument, "consumption deltas must be non-negative and finite") + } + row := agentNetworkTypes.Consumption{ + AccountID: accountID, + DimensionKind: kind, + DimensionID: dimID, + WindowSeconds: windowSeconds, + WindowStartUTC: windowStart.UTC(), + TokensInput: tokensIn, + TokensOutput: tokensOut, + CostUSD: costUSD, + UpdatedAt: time.Now().UTC(), + } + const tbl = "agent_network_consumption" + err := s.db.Clauses(clause.OnConflict{ + Columns: []clause.Column{ + {Name: "account_id"}, + {Name: "dim_kind"}, + {Name: "dim_id"}, + {Name: "window_seconds"}, + {Name: "window_start_utc"}, + }, + DoUpdates: clause.Assignments(map[string]any{ + "tokens_input": gorm.Expr(tbl+".tokens_input + ?", tokensIn), + "tokens_output": gorm.Expr(tbl+".tokens_output + ?", tokensOut), + "cost_usd": gorm.Expr(tbl+".cost_usd + ?", costUSD), + "updated_at": time.Now().UTC(), + }), + }).Create(&row).Error + if err != nil { + log.WithContext(ctx).Errorf("failed to increment agent network consumption: %v", err) + return status.Errorf(status.Internal, "failed to increment agent network consumption") + } + return nil +} + +// GetAgentNetworkConsumption returns the consumption row for the exact +// window key. Returns a zero-valued row (not found mapped to zero) so +// callers can use the result as the headroom basis without nil checks. +func (s *SqlStore) GetAgentNetworkConsumption( + ctx context.Context, + lockStrength LockingStrength, + accountID string, + kind agentNetworkTypes.ConsumptionDimension, + dimID string, + windowSeconds int64, + windowStart time.Time, +) (*agentNetworkTypes.Consumption, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var row agentNetworkTypes.Consumption + result := tx.Take(&row, + "account_id = ? AND dim_kind = ? AND dim_id = ? AND window_seconds = ? AND window_start_utc = ?", + accountID, kind, dimID, windowSeconds, windowStart.UTC()) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return &agentNetworkTypes.Consumption{ + AccountID: accountID, + DimensionKind: kind, + DimensionID: dimID, + WindowSeconds: windowSeconds, + WindowStartUTC: windowStart.UTC(), + }, nil + } + log.WithContext(ctx).Errorf("failed to get agent network consumption: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get agent network consumption") + } + return &row, nil +} + +// GetAgentNetworkConsumptionBatch reads many consumption counters for one +// account in a single query, returning a map keyed by the exact +// ConsumptionKey. Missing counters are simply absent from the map (callers +// treat absence as a zero counter). Replaces the per-cap point reads the +// policy selector previously issued one at a time. +func (s *SqlStore) GetAgentNetworkConsumptionBatch( + ctx context.Context, + lockStrength LockingStrength, + accountID string, + keys []agentNetworkTypes.ConsumptionKey, +) (map[agentNetworkTypes.ConsumptionKey]*agentNetworkTypes.Consumption, error) { + out := make(map[agentNetworkTypes.ConsumptionKey]*agentNetworkTypes.Consumption, len(keys)) + if len(keys) == 0 { + return out, nil + } + + // Collect the distinct dim ids, windows and window starts so a single + // query scopes to exactly the current windows in play, then filter the + // returned rows down to the exact requested keys. + wanted := make(map[agentNetworkTypes.ConsumptionKey]struct{}, len(keys)) + dimSet := make(map[string]struct{}) + winSet := make(map[int64]struct{}) + startSet := make(map[time.Time]struct{}) + for _, k := range keys { + k.WindowStartUTC = k.WindowStartUTC.UTC() + wanted[k] = struct{}{} + dimSet[k.DimID] = struct{}{} + winSet[k.WindowSeconds] = struct{}{} + startSet[k.WindowStartUTC] = struct{}{} + } + dimIDs := make([]string, 0, len(dimSet)) + for d := range dimSet { + dimIDs = append(dimIDs, d) + } + windows := make([]int64, 0, len(winSet)) + for w := range winSet { + windows = append(windows, w) + } + starts := make([]time.Time, 0, len(startSet)) + for t := range startSet { + starts = append(starts, t) + } + + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var rows []*agentNetworkTypes.Consumption + result := tx.Find(&rows, + "account_id = ? AND dim_id IN ? AND window_seconds IN ? AND window_start_utc IN ?", + accountID, dimIDs, windows, starts) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to batch-get agent network consumption: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get agent network consumption") + } + for _, row := range rows { + k := agentNetworkTypes.ConsumptionKey{ + Kind: row.DimensionKind, + DimID: row.DimensionID, + WindowSeconds: row.WindowSeconds, + WindowStartUTC: row.WindowStartUTC.UTC(), + } + if _, ok := wanted[k]; ok { + out[k] = row + } + } + return out, nil +} + +// IncrementAgentNetworkConsumptionBatch applies the same usage delta to every +// supplied counter inside a single transaction, so all per-(dimension, window) +// counters a served request books are written atomically in one round-trip +// instead of one upsert per counter. Keys are deduplicated by the caller. +func (s *SqlStore) IncrementAgentNetworkConsumptionBatch( + ctx context.Context, + accountID string, + keys []agentNetworkTypes.ConsumptionKey, + tokensIn, tokensOut int64, + costUSD float64, +) error { + if accountID == "" { + return status.Errorf(status.InvalidArgument, "account_id must be set") + } + if tokensIn < 0 || tokensOut < 0 || costUSD < 0 || math.IsNaN(costUSD) || math.IsInf(costUSD, 0) { + return status.Errorf(status.InvalidArgument, "consumption deltas must be non-negative and finite") + } + if len(keys) == 0 { + return nil + } + + const tbl = "agent_network_consumption" + err := s.db.Transaction(func(tx *gorm.DB) error { + for _, k := range keys { + if k.DimID == "" || k.WindowSeconds <= 0 { + return status.Errorf(status.InvalidArgument, "dim_id and window_seconds must be set") + } + now := time.Now().UTC() + row := agentNetworkTypes.Consumption{ + AccountID: accountID, + DimensionKind: k.Kind, + DimensionID: k.DimID, + WindowSeconds: k.WindowSeconds, + WindowStartUTC: k.WindowStartUTC.UTC(), + TokensInput: tokensIn, + TokensOutput: tokensOut, + CostUSD: costUSD, + UpdatedAt: now, + } + if err := tx.Clauses(clause.OnConflict{ + Columns: []clause.Column{ + {Name: "account_id"}, + {Name: "dim_kind"}, + {Name: "dim_id"}, + {Name: "window_seconds"}, + {Name: "window_start_utc"}, + }, + DoUpdates: clause.Assignments(map[string]any{ + "tokens_input": gorm.Expr(tbl+".tokens_input + ?", tokensIn), + "tokens_output": gorm.Expr(tbl+".tokens_output + ?", tokensOut), + "cost_usd": gorm.Expr(tbl+".cost_usd + ?", costUSD), + "updated_at": now, + }), + }).Create(&row).Error; err != nil { + return err + } + } + return nil + }) + if err != nil { + log.WithContext(ctx).Errorf("failed to batch-increment agent network consumption: %v", err) + return status.Errorf(status.Internal, "failed to increment agent network consumption") + } + return nil +} + +// ListAgentNetworkConsumption returns every consumption row recorded +// for the account, ordered by window_start descending. Backs the +// dashboard's basic counter view. +func (s *SqlStore) ListAgentNetworkConsumption( + ctx context.Context, + lockStrength LockingStrength, + accountID string, +) ([]*agentNetworkTypes.Consumption, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var rows []*agentNetworkTypes.Consumption + result := tx. + Order("window_start_utc DESC"). + Find(&rows, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to list agent network consumption: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to list agent network consumption") + } + return rows, nil +} + +// GetAccountAgentNetworkBudgetRules returns every account-level budget rule for +// the account. +func (s *SqlStore) GetAccountAgentNetworkBudgetRules(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.AccountBudgetRule, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var rules []*agentNetworkTypes.AccountBudgetRule + result := tx.Find(&rules, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get agent network budget rules from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get agent network budget rules from store") + } + + return rules, nil +} + +// GetAgentNetworkBudgetRuleByID returns a single budget rule scoped to the +// account, or a NotFound error. +func (s *SqlStore) GetAgentNetworkBudgetRuleByID(ctx context.Context, lockStrength LockingStrength, accountID, ruleID string) (*agentNetworkTypes.AccountBudgetRule, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var rule *agentNetworkTypes.AccountBudgetRule + result := tx.Take(&rule, accountAndIDQueryCondition, accountID, ruleID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAgentNetworkBudgetRuleNotFoundError(ruleID) + } + + log.WithContext(ctx).Errorf("failed to get agent network budget rule from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get agent network budget rule from store") + } + + return rule, nil +} + +// SaveAgentNetworkBudgetRule upserts a budget rule. +func (s *SqlStore) SaveAgentNetworkBudgetRule(ctx context.Context, rule *agentNetworkTypes.AccountBudgetRule) error { + result := s.db.Save(rule) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save agent network budget rule to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save agent network budget rule to store") + } + + return nil +} + +// DeleteAgentNetworkBudgetRule removes a budget rule scoped to the account. +func (s *SqlStore) DeleteAgentNetworkBudgetRule(ctx context.Context, accountID, ruleID string) error { + result := s.db.Delete(&agentNetworkTypes.AccountBudgetRule{}, accountAndIDQueryCondition, accountID, ruleID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete agent network budget rule from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete agent network budget rule from store") + } + + if result.RowsAffected == 0 { + return status.NewAgentNetworkBudgetRuleNotFoundError(ruleID) + } + + return nil +} diff --git a/management/server/store/sql_store_agentnetwork_accesslog_test.go b/management/server/store/sql_store_agentnetwork_accesslog_test.go new file mode 100644 index 000000000..793c82d79 --- /dev/null +++ b/management/server/store/sql_store_agentnetwork_accesslog_test.go @@ -0,0 +1,302 @@ +package store + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" +) + +// TestAgentNetworkUsage_RealStore_RoundTrip drives CreateAgentNetworkUsage and +// CreateAgentNetworkAccessLog through a real sqlite store to prove the schema +// migrates and the inserts succeed for both a populated (allowed) entry and a +// stripped (denied) entry. +func TestAgentNetworkUsage_RealStore_RoundTrip(t *testing.T) { + ctx := context.Background() + s, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + const accountID = "acc-anet-usage-1" + now := time.Now().UTC() + + // Populated (allowed) usage row with two authorising groups. + usage := &agentNetworkTypes.AgentNetworkUsage{ + ID: "log-allowed-1", + AccountID: accountID, + Timestamp: now, + UserID: "user-alice", + ResolvedProviderID: "prov-openai-1", + Provider: "openai", + Model: "gpt-4o", + SessionID: "sess-round-trip-1", + InputTokens: 1200, + OutputTokens: 640, + TotalTokens: 1840, + CostUSD: 0.0231, + } + usageGroups := []agentNetworkTypes.AgentNetworkUsageGroup{ + {UsageID: usage.ID, GroupID: "grp-eng", AccountID: accountID}, + {UsageID: usage.ID, GroupID: "grp-oncall", AccountID: accountID}, + } + require.NoError(t, s.CreateAgentNetworkUsage(ctx, usage, usageGroups), "populated usage insert must succeed") + + // Stripped (denied / 403) usage row: no provider/model/tokens, no groups. + denied := &agentNetworkTypes.AgentNetworkUsage{ + ID: "log-denied-1", + AccountID: accountID, + Timestamp: now, + UserID: "user-bob", + } + require.NoError(t, s.CreateAgentNetworkUsage(ctx, denied, nil), "stripped usage insert must succeed") + + // Idempotency: re-inserting the same id must not error. + require.NoError(t, s.CreateAgentNetworkUsage(ctx, usage, usageGroups), "duplicate usage insert must be idempotent") + + // Access-log row + group children. + entry := &agentNetworkTypes.AgentNetworkAccessLog{ + ID: "log-allowed-1", + AccountID: accountID, + ServiceID: "agent-net-svc-1", + Timestamp: now, + UserID: "user-alice", + StatusCode: 200, + Provider: "openai", + Model: "gpt-4o", + SessionID: "sess-round-trip-1", + InputTokens: 1200, + OutputTokens: 640, + TotalTokens: 1840, + CostUSD: 0.0231, + } + entryGroups := []agentNetworkTypes.AgentNetworkAccessLogGroup{ + {LogID: entry.ID, GroupID: "grp-eng", AccountID: accountID}, + {LogID: entry.ID, GroupID: "grp-oncall", AccountID: accountID}, + } + require.NoError(t, s.CreateAgentNetworkAccessLog(ctx, entry, entryGroups), "access-log insert must succeed") + + // Read back through the filtered list + verify group hydration. + logs, total, err := s.GetAgentNetworkAccessLogs(ctx, LockingStrengthNone, accountID, agentNetworkTypes.AgentNetworkAccessLogFilter{Page: 1, PageSize: 50}) + require.NoError(t, err, "list must succeed") + assert.Equal(t, int64(1), total, "one access-log row expected") + require.Len(t, logs, 1) + assert.ElementsMatch(t, []string{"grp-eng", "grp-oncall"}, logs[0].GroupIDs, "group ids must hydrate") + assert.Equal(t, "sess-round-trip-1", logs[0].SessionID, "session id must persist and read back on the access-log row") + + // Session filter narrows the access-log listing to one conversation. + sessionID := "sess-round-trip-1" + sessLogs, sessTotal, err := s.GetAgentNetworkAccessLogs(ctx, LockingStrengthNone, accountID, + agentNetworkTypes.AgentNetworkAccessLogFilter{Page: 1, PageSize: 50, SessionID: &sessionID}) + require.NoError(t, err) + assert.Equal(t, int64(1), sessTotal, "session filter must match the one row with that session id") + require.Len(t, sessLogs, 1) + assert.Equal(t, entry.ID, sessLogs[0].ID, "session filter must return the matching log row") + + bogus := "no-such-session" + _, emptyTotal, err := s.GetAgentNetworkAccessLogs(ctx, LockingStrengthNone, accountID, + agentNetworkTypes.AgentNetworkAccessLogFilter{Page: 1, PageSize: 50, SessionID: &bogus}) + require.NoError(t, err) + assert.Equal(t, int64(0), emptyTotal, "unknown session id must match nothing") + + // Session filter also narrows the always-on usage rows. + sessUsage, err := s.GetAgentNetworkUsageRows(ctx, LockingStrengthNone, accountID, + agentNetworkTypes.AgentNetworkAccessLogFilter{SessionID: &sessionID}) + require.NoError(t, err) + require.Len(t, sessUsage, 1, "session filter must narrow usage rows to the matching session") + assert.Equal(t, "sess-round-trip-1", sessUsage[0].SessionID, "usage row must carry the session id") +} + +// TestAgentNetworkUsageOverview_DailyAggregation drives GetAgentNetworkUsageRows +// + AggregateUsageByGranularity end-to-end against a real sqlite store, with +// two rows on the same day and one on another, plus a model filter. +func TestAgentNetworkUsageOverview_DailyAggregation(t *testing.T) { + ctx := context.Background() + s, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + const accountID = "acc-anet-overview-1" + day1 := time.Date(2026, 5, 5, 10, 0, 0, 0, time.UTC) + day1b := time.Date(2026, 5, 5, 22, 0, 0, 0, time.UTC) + day2 := time.Date(2026, 5, 6, 9, 0, 0, 0, time.UTC) + + mk := func(id string, ts time.Time, model string, in, out int64, cost float64) *agentNetworkTypes.AgentNetworkUsage { + return &agentNetworkTypes.AgentNetworkUsage{ + ID: id, AccountID: accountID, Timestamp: ts, Model: model, + InputTokens: in, OutputTokens: out, TotalTokens: in + out, CostUSD: cost, + } + } + require.NoError(t, s.CreateAgentNetworkUsage(ctx, mk("u1", day1, "gpt-4o", 100, 50, 0.10), nil)) + require.NoError(t, s.CreateAgentNetworkUsage(ctx, mk("u2", day1b, "gpt-4o", 200, 80, 0.20), nil)) + require.NoError(t, s.CreateAgentNetworkUsage(ctx, mk("u3", day2, "claude-3", 10, 5, 0.01), nil)) + + rows, err := s.GetAgentNetworkUsageRows(ctx, LockingStrengthNone, accountID, agentNetworkTypes.AgentNetworkAccessLogFilter{}) + require.NoError(t, err) + require.Len(t, rows, 3, "all three usage rows expected") + + buckets := agentNetworkTypes.AggregateUsageByGranularity(rows, agentNetworkTypes.UsageGranularityDay) + require.Len(t, buckets, 2, "two distinct days expected") + assert.Equal(t, "2026-05-05", buckets[0].PeriodStart, "oldest-first ordering") + assert.Equal(t, int64(300), buckets[0].InputTokens, "same-day input tokens summed") + assert.Equal(t, int64(130), buckets[0].OutputTokens) + assert.InDelta(t, 0.30, buckets[0].CostUSD, 1e-9, "same-day cost summed") + assert.Equal(t, "2026-05-06", buckets[1].PeriodStart) + assert.Equal(t, int64(15), buckets[1].TotalTokens) + + // Model filter narrows to a single day. + model := "claude-3" + filtered, err := s.GetAgentNetworkUsageRows(ctx, LockingStrengthNone, accountID, agentNetworkTypes.AgentNetworkAccessLogFilter{Models: []string{model}}) + require.NoError(t, err) + require.Len(t, filtered, 1, "model filter must narrow rows") + assert.Equal(t, "u3", filtered[0].ID) +} + +// TestAgentNetworkAccessLogSessions_RealStore drives GetAgentNetworkAccessLogSessions +// against a real sqlite store: session grouping + aggregation, recency ordering, +// singleton groups for session-less requests, session pagination, the model +// filter narrowing sessions, and aggregate sorting. +func TestAgentNetworkAccessLogSessions_RealStore(t *testing.T) { + ctx := context.Background() + s, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + const accountID = "acc-anet-sessions-1" + base := time.Date(2026, 5, 5, 10, 0, 0, 0, time.UTC) + at := func(h int) time.Time { return base.Add(time.Duration(h) * time.Hour) } + + mk := func(id, session, user, provider, model, decision string, ts time.Time, cost float64) *agentNetworkTypes.AgentNetworkAccessLog { + return &agentNetworkTypes.AgentNetworkAccessLog{ + ID: id, AccountID: accountID, ServiceID: "svc", Timestamp: ts, + UserID: user, StatusCode: 200, Provider: provider, Model: model, + SessionID: session, Decision: decision, + InputTokens: 100, OutputTokens: 50, TotalTokens: 150, CostUSD: cost, + } + } + + // Two-request session s1 (alice), a one-request denied session s2 (bob), and + // two session-less requests (empty session id) that must each form their own + // singleton group. + require.NoError(t, s.CreateAgentNetworkAccessLog(ctx, mk("s1-a", "s1", "alice", "openai", "gpt-4o", "allow", at(1), 0.10), + []agentNetworkTypes.AgentNetworkAccessLogGroup{{LogID: "s1-a", GroupID: "grp-eng", AccountID: accountID}})) + require.NoError(t, s.CreateAgentNetworkAccessLog(ctx, mk("s1-b", "s1", "alice", "openai", "gpt-4o", "allow", at(2), 0.20), + []agentNetworkTypes.AgentNetworkAccessLogGroup{{LogID: "s1-b", GroupID: "grp-oncall", AccountID: accountID}})) + require.NoError(t, s.CreateAgentNetworkAccessLog(ctx, mk("s2-a", "s2", "bob", "anthropic", "claude-3", "deny", at(3), 0.05), nil)) + require.NoError(t, s.CreateAgentNetworkAccessLog(ctx, mk("se-old", "", "carol", "openai", "o1", "allow", at(0), 0.01), nil)) + require.NoError(t, s.CreateAgentNetworkAccessLog(ctx, mk("se-new", "", "dave", "mistral", "mistral-large", "allow", at(4), 0.02), nil)) + + // Default sort: last activity (MAX timestamp) descending. + sessions, total, err := s.GetAgentNetworkAccessLogSessions(ctx, LockingStrengthNone, accountID, + agentNetworkTypes.AgentNetworkAccessLogFilter{Page: 1, PageSize: 50}) + require.NoError(t, err) + assert.Equal(t, int64(4), total, "four sessions: s1, s2, and two singletons") + require.Len(t, sessions, 4) + + // se-new(t4) > s2(t3) > s1(t2) > se-old(t0) + assert.Equal(t, "", sessions[0].SessionID, "newest is a session-less singleton") + assert.Equal(t, "se-new", sessions[0].Entries[0].ID) + assert.Equal(t, "s2", sessions[1].SessionID) + assert.Equal(t, "s1", sessions[2].SessionID) + assert.Equal(t, "se-old", sessions[3].Entries[0].ID) + + // s1 aggregation. + s1 := sessions[2] + assert.Equal(t, 2, s1.RequestCount, "s1 has two requests") + assert.Equal(t, int64(300), s1.TotalTokens, "tokens summed across the session") + assert.InDelta(t, 0.30, s1.CostUSD, 1e-9, "cost summed across the session") + assert.Equal(t, "alice", s1.UserID) + assert.Equal(t, "allow", s1.Decision) + // SQLite hands times back in time.Local; normalise to UTC so the instant is + // compared, not the (differing) *Location pointer. + assert.Equal(t, at(1), s1.StartedAt.UTC(), "started = earliest entry") + assert.Equal(t, at(2), s1.EndedAt.UTC(), "ended = latest entry") + assert.ElementsMatch(t, []string{"openai"}, s1.Providers) + assert.ElementsMatch(t, []string{"gpt-4o"}, s1.Models) + assert.ElementsMatch(t, []string{"grp-eng", "grp-oncall"}, s1.GroupIDs, "union of the entries' authorising groups") + + // Denied session rolls up to deny. + assert.Equal(t, "deny", sessions[1].Decision, "any denied request makes the session deny") + + // Pagination over sessions: 2 per page. + page1, total, err := s.GetAgentNetworkAccessLogSessions(ctx, LockingStrengthNone, accountID, + agentNetworkTypes.AgentNetworkAccessLogFilter{Page: 1, PageSize: 2}) + require.NoError(t, err) + assert.Equal(t, int64(4), total, "total still counts all sessions") + require.Len(t, page1, 2) + assert.Equal(t, "se-new", page1[0].Entries[0].ID) + assert.Equal(t, "s2", page1[1].SessionID) + + page2, _, err := s.GetAgentNetworkAccessLogSessions(ctx, LockingStrengthNone, accountID, + agentNetworkTypes.AgentNetworkAccessLogFilter{Page: 2, PageSize: 2}) + require.NoError(t, err) + require.Len(t, page2, 2) + assert.Equal(t, "s1", page2[0].SessionID) + assert.Equal(t, "se-old", page2[1].Entries[0].ID) + + // Model filter narrows to the session(s) with matching entries. + model := "claude-3" + filtered, fTotal, err := s.GetAgentNetworkAccessLogSessions(ctx, LockingStrengthNone, accountID, + agentNetworkTypes.AgentNetworkAccessLogFilter{Page: 1, PageSize: 50, Models: []string{model}}) + require.NoError(t, err) + assert.Equal(t, int64(1), fTotal, "only s2 has a claude-3 request") + require.Len(t, filtered, 1) + assert.Equal(t, "s2", filtered[0].SessionID) + + // Sort by total session cost, descending: s1 (0.30) leads despite not being + // the most recent. + byCost, _, err := s.GetAgentNetworkAccessLogSessions(ctx, LockingStrengthNone, accountID, + agentNetworkTypes.AgentNetworkAccessLogFilter{Page: 1, PageSize: 50, SortBy: "cost_usd", SortOrder: "desc"}) + require.NoError(t, err) + require.Len(t, byCost, 4) + assert.Equal(t, "s1", byCost[0].SessionID, "highest-cost session sorts first") +} + +// TestDeleteOldAgentNetworkAccessLogs verifies the retention sweep removes only +// access-log rows (and their group children) older than the cutoff, leaving +// recent rows — and never touching usage records. +func TestDeleteOldAgentNetworkAccessLogs(t *testing.T) { + ctx := context.Background() + s, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + const accountID = "acc-anet-retention-1" + old := time.Now().UTC().AddDate(0, 0, -40) + recent := time.Now().UTC().AddDate(0, 0, -1) + + mkLog := func(id string, ts time.Time) (*agentNetworkTypes.AgentNetworkAccessLog, []agentNetworkTypes.AgentNetworkAccessLogGroup) { + return &agentNetworkTypes.AgentNetworkAccessLog{ + ID: id, AccountID: accountID, ServiceID: "svc", Timestamp: ts, StatusCode: 200, Model: "gpt-4o", + }, []agentNetworkTypes.AgentNetworkAccessLogGroup{ + {LogID: id, GroupID: "grp-eng", AccountID: accountID}, + } + } + oldEntry, oldGroups := mkLog("old-1", old) + recentEntry, recentGroups := mkLog("recent-1", recent) + require.NoError(t, s.CreateAgentNetworkAccessLog(ctx, oldEntry, oldGroups)) + require.NoError(t, s.CreateAgentNetworkAccessLog(ctx, recentEntry, recentGroups)) + // A usage row for the old request must survive the access-log sweep. + require.NoError(t, s.CreateAgentNetworkUsage(ctx, &agentNetworkTypes.AgentNetworkUsage{ + ID: "old-1", AccountID: accountID, Timestamp: old, Model: "gpt-4o", InputTokens: 10, TotalTokens: 10, + }, nil)) + + cutoff := time.Now().UTC().AddDate(0, 0, -30) + deleted, err := s.DeleteOldAgentNetworkAccessLogs(ctx, accountID, cutoff) + require.NoError(t, err) + assert.Equal(t, int64(1), deleted, "only the 40-day-old log is deleted") + + logs, total, err := s.GetAgentNetworkAccessLogs(ctx, LockingStrengthNone, accountID, agentNetworkTypes.AgentNetworkAccessLogFilter{Page: 1, PageSize: 50}) + require.NoError(t, err) + assert.Equal(t, int64(1), total, "the recent log remains") + require.Len(t, logs, 1) + assert.Equal(t, "recent-1", logs[0].ID) + + // Usage is untouched by the access-log retention sweep. + usage, err := s.GetAgentNetworkUsageRows(ctx, LockingStrengthNone, accountID, agentNetworkTypes.AgentNetworkAccessLogFilter{}) + require.NoError(t, err) + require.Len(t, usage, 1, "usage record for the deleted log must survive") +} diff --git a/management/server/store/sql_store_agentnetwork_budgetrule_test.go b/management/server/store/sql_store_agentnetwork_budgetrule_test.go new file mode 100644 index 000000000..3bf7b797d --- /dev/null +++ b/management/server/store/sql_store_agentnetwork_budgetrule_test.go @@ -0,0 +1,112 @@ +package store + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" +) + +// TestAgentNetworkBudgetRule_RealStore_RoundTrip is the GC-0 no-mock guard: it +// drives the budget-rule CRUD through a real sqlite store and asserts the full +// object — targets and the reused PolicyLimits cap shape — survives the +// save → gorm/JSON serialize → reload round-trip, then that delete removes it +// and a second delete reports NotFound. +func TestAgentNetworkBudgetRule_RealStore_RoundTrip(t *testing.T) { + ctx := context.Background() + s, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + defer cleanup() + + const accountID = "acc-budgetrule-1" + rule := agentNetworkTypes.NewAccountBudgetRule(accountID) + rule.Name = "eng-monthly" + rule.TargetGroups = []string{"grp-eng", "grp-oncall"} + rule.TargetUsers = []string{"user-alice"} + rule.Limits = agentNetworkTypes.PolicyLimits{ + TokenLimit: agentNetworkTypes.PolicyTokenLimit{ + Enabled: true, GroupCap: 100_000, UserCap: 10_000, WindowSeconds: 2_592_000, + }, + BudgetLimit: agentNetworkTypes.PolicyBudgetLimit{ + Enabled: true, GroupCapUsd: 500, UserCapUsd: 50, WindowSeconds: 2_592_000, + }, + } + require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, rule), "save must succeed") + + got, err := s.GetAgentNetworkBudgetRuleByID(ctx, LockingStrengthNone, accountID, rule.ID) + require.NoError(t, err, "get by id must succeed after save") + assert.Equal(t, rule.Name, got.Name, "name must round-trip") + assert.Equal(t, []string{"grp-eng", "grp-oncall"}, got.TargetGroups, "target groups must round-trip") + assert.Equal(t, []string{"user-alice"}, got.TargetUsers, "target users must round-trip") + assert.Equal(t, rule.Limits, got.Limits, "the reused PolicyLimits cap shape must round-trip intact") + assert.True(t, got.Enabled, "enabled must round-trip") + + list, err := s.GetAccountAgentNetworkBudgetRules(ctx, LockingStrengthNone, accountID) + require.NoError(t, err, "list must succeed") + require.Len(t, list, 1, "exactly the one saved rule must be listed") + assert.Equal(t, rule.ID, list[0].ID, "listed rule id must match") + + require.NoError(t, s.DeleteAgentNetworkBudgetRule(ctx, accountID, rule.ID), "delete must succeed") + + _, err = s.GetAgentNetworkBudgetRuleByID(ctx, LockingStrengthNone, accountID, rule.ID) + assert.Error(t, err, "get after delete must report not found") + + err = s.DeleteAgentNetworkBudgetRule(ctx, accountID, rule.ID) + assert.Error(t, err, "deleting an absent rule must report not found") +} + +// TestAgentNetworkBudgetRule_RealStore_ScopedByAccount pins that rules are +// account-scoped: a rule under one account is invisible to another. +func TestAgentNetworkBudgetRule_RealStore_ScopedByAccount(t *testing.T) { + ctx := context.Background() + s, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + defer cleanup() + + ruleA := agentNetworkTypes.NewAccountBudgetRule("acc-A") + require.NoError(t, s.SaveAgentNetworkBudgetRule(ctx, ruleA)) + + list, err := s.GetAccountAgentNetworkBudgetRules(ctx, LockingStrengthNone, "acc-B") + require.NoError(t, err) + assert.Empty(t, list, "account B must not see account A's budget rule") + + _, err = s.GetAgentNetworkBudgetRuleByID(ctx, LockingStrengthNone, "acc-B", ruleA.ID) + assert.Error(t, err, "cross-account get by id must not resolve") +} + +// TestAgentNetworkSettings_RealStore_CollectionTogglesRoundTrip pins the GC-0 +// additive settings columns: the three collection toggles default off on a +// fresh row and survive a save/reload at their set values. +func TestAgentNetworkSettings_RealStore_CollectionTogglesRoundTrip(t *testing.T) { + ctx := context.Background() + s, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + defer cleanup() + + const accountID = "acc-settings-toggles" + require.NoError(t, s.SaveAgentNetworkSettings(ctx, &agentNetworkTypes.Settings{ + AccountID: accountID, + Cluster: "eu.proxy.netbird.io", + Subdomain: "violet", + })) + + got, err := s.GetAgentNetworkSettings(ctx, LockingStrengthNone, accountID) + require.NoError(t, err) + assert.False(t, got.EnableLogCollection, "log collection must default off") + assert.False(t, got.EnablePromptCollection, "prompt collection must default off") + assert.False(t, got.RedactPii, "redact pii must default off") + + got.EnableLogCollection = true + got.EnablePromptCollection = true + got.RedactPii = true + require.NoError(t, s.SaveAgentNetworkSettings(ctx, got)) + + reloaded, err := s.GetAgentNetworkSettings(ctx, LockingStrengthNone, accountID) + require.NoError(t, err) + assert.True(t, reloaded.EnableLogCollection, "log collection must round-trip on") + assert.True(t, reloaded.EnablePromptCollection, "prompt collection must round-trip on") + assert.True(t, reloaded.RedactPii, "redact pii must round-trip on") +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 066ab285d..908c199f5 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -37,6 +37,7 @@ import ( "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util/crypt" + agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" "github.com/netbirdio/netbird/management/server/migration" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" @@ -300,6 +301,12 @@ type Store interface { CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) + CreateAgentNetworkAccessLog(ctx context.Context, entry *agentNetworkTypes.AgentNetworkAccessLog, groups []agentNetworkTypes.AgentNetworkAccessLogGroup) error + CreateAgentNetworkUsage(ctx context.Context, usage *agentNetworkTypes.AgentNetworkUsage, groups []agentNetworkTypes.AgentNetworkUsageGroup) error + GetAgentNetworkAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkAccessLog, int64, error) + GetAgentNetworkAccessLogSessions(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkAccessLogSession, int64, error) + GetAgentNetworkUsageRows(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkUsage, error) + DeleteOldAgentNetworkAccessLogs(ctx context.Context, accountID string, olderThan time.Time) (int64, error) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error @@ -328,7 +335,40 @@ type Store interface { // return a zero-valued struct. GetProxyMetrics(ctx context.Context) (ProxyMetrics, error) + // GetAgentNetworkMetrics returns aggregated agent-network adoption + usage + // counts for the self-hosted metrics worker. Self-hosted only — file-based + // stores return a zero-valued struct. + GetAgentNetworkMetrics(ctx context.Context) (AgentNetworkMetrics, error) + GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) + + // Agent Network persistence (providers, policies, guardrails, settings). + GetAllAgentNetworkProviders(ctx context.Context, lockStrength LockingStrength) ([]*agentNetworkTypes.Provider, error) + GetAccountAgentNetworkProviders(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Provider, error) + GetAgentNetworkProviderByID(ctx context.Context, lockStrength LockingStrength, accountID, providerID string) (*agentNetworkTypes.Provider, error) + SaveAgentNetworkProvider(ctx context.Context, provider *agentNetworkTypes.Provider) error + DeleteAgentNetworkProvider(ctx context.Context, accountID, providerID string) error + GetAccountAgentNetworkPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Policy, error) + GetAgentNetworkPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*agentNetworkTypes.Policy, error) + SaveAgentNetworkPolicy(ctx context.Context, policy *agentNetworkTypes.Policy) error + DeleteAgentNetworkPolicy(ctx context.Context, accountID, policyID string) error + GetAccountAgentNetworkGuardrails(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Guardrail, error) + GetAgentNetworkGuardrailByID(ctx context.Context, lockStrength LockingStrength, accountID, guardrailID string) (*agentNetworkTypes.Guardrail, error) + SaveAgentNetworkGuardrail(ctx context.Context, guardrail *agentNetworkTypes.Guardrail) error + DeleteAgentNetworkGuardrail(ctx context.Context, accountID, guardrailID string) error + GetAgentNetworkSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*agentNetworkTypes.Settings, error) + GetAllAgentNetworkSettings(ctx context.Context, lockStrength LockingStrength) ([]*agentNetworkTypes.Settings, error) + GetAgentNetworkSettingsByCluster(ctx context.Context, lockStrength LockingStrength, cluster string) ([]*agentNetworkTypes.Settings, error) + SaveAgentNetworkSettings(ctx context.Context, settings *agentNetworkTypes.Settings) error + IncrementAgentNetworkConsumption(ctx context.Context, accountID string, kind agentNetworkTypes.ConsumptionDimension, dimID string, windowSeconds int64, windowStart time.Time, tokensIn, tokensOut int64, costUSD float64) error + IncrementAgentNetworkConsumptionBatch(ctx context.Context, accountID string, keys []agentNetworkTypes.ConsumptionKey, tokensIn, tokensOut int64, costUSD float64) error + GetAgentNetworkConsumption(ctx context.Context, lockStrength LockingStrength, accountID string, kind agentNetworkTypes.ConsumptionDimension, dimID string, windowSeconds int64, windowStart time.Time) (*agentNetworkTypes.Consumption, error) + GetAgentNetworkConsumptionBatch(ctx context.Context, lockStrength LockingStrength, accountID string, keys []agentNetworkTypes.ConsumptionKey) (map[agentNetworkTypes.ConsumptionKey]*agentNetworkTypes.Consumption, error) + ListAgentNetworkConsumption(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Consumption, error) + GetAccountAgentNetworkBudgetRules(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.AccountBudgetRule, error) + GetAgentNetworkBudgetRuleByID(ctx context.Context, lockStrength LockingStrength, accountID, ruleID string) (*agentNetworkTypes.AccountBudgetRule, error) + SaveAgentNetworkBudgetRule(ctx context.Context, rule *agentNetworkTypes.AccountBudgetRule) error + DeleteAgentNetworkBudgetRule(ctx context.Context, accountID, ruleID string) error } // ProxyMetrics aggregates self-hosted proxy + cluster usage signals @@ -355,6 +395,32 @@ type ProxyMetrics struct { ProxiesConnected int64 } +// AgentNetworkMetrics aggregates self-hosted agent-network adoption + usage +// signals surfaced to the telemetry payload. Each field is best-effort: when a +// store cannot answer (e.g. FileStore) all fields are zero. +type AgentNetworkMetrics struct { + // Accounts is the number of distinct accounts with at least one provider + // configured (agent-network adoption). + Accounts int64 + // Providers is the total number of configured providers across all accounts. + Providers int64 + // Policies is the total number of agent-network policies across all accounts. + Policies int64 + // BudgetRules is the total number of account-level budget rules ("budget + // limits") across all accounts. + BudgetRules int64 + // LogCollectionEnabled is the number of accounts that have agent-network + // log collection turned on. + LogCollectionEnabled int64 + // InputTokens / OutputTokens / CostUSD are summed over the always-collected + // per-request usage ledger (agent_network_request_usage), independent of the + // log-collection toggle. They reflect total metered LLM usage served through + // agent networks. + InputTokens int64 + OutputTokens int64 + CostUSD float64 +} + const ( postgresDsnEnv = "NB_STORE_ENGINE_POSTGRES_DSN" postgresDsnEnvLegacy = "NETBIRD_STORE_ENGINE_POSTGRES_DSN" diff --git a/management/server/store/store_mock_agentnetwork.go b/management/server/store/store_mock_agentnetwork.go new file mode 100644 index 000000000..18adf20f0 --- /dev/null +++ b/management/server/store/store_mock_agentnetwork.go @@ -0,0 +1,495 @@ +package store + +import ( + context "context" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + + agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" +) + +// GetAllAgentNetworkProviders mocks base method. +func (m *MockStore) GetAllAgentNetworkProviders(ctx context.Context, lockStrength LockingStrength) ([]*agentNetworkTypes.Provider, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllAgentNetworkProviders", ctx, lockStrength) + ret0, _ := ret[0].([]*agentNetworkTypes.Provider) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllAgentNetworkProviders indicates an expected call of GetAllAgentNetworkProviders. +func (mr *MockStoreMockRecorder) GetAllAgentNetworkProviders(ctx, lockStrength interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllAgentNetworkProviders", reflect.TypeOf((*MockStore)(nil).GetAllAgentNetworkProviders), ctx, lockStrength) +} + +// GetAgentNetworkMetrics mocks base method. +func (m *MockStore) GetAgentNetworkMetrics(ctx context.Context) (AgentNetworkMetrics, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAgentNetworkMetrics", ctx) + ret0, _ := ret[0].(AgentNetworkMetrics) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAgentNetworkMetrics indicates an expected call of GetAgentNetworkMetrics. +func (mr *MockStoreMockRecorder) GetAgentNetworkMetrics(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkMetrics", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkMetrics), ctx) +} + +// GetAccountAgentNetworkProviders mocks base method. +func (m *MockStore) GetAccountAgentNetworkProviders(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Provider, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountAgentNetworkProviders", ctx, lockStrength, accountID) + ret0, _ := ret[0].([]*agentNetworkTypes.Provider) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountAgentNetworkProviders indicates an expected call of GetAccountAgentNetworkProviders. +func (mr *MockStoreMockRecorder) GetAccountAgentNetworkProviders(ctx, lockStrength, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountAgentNetworkProviders", reflect.TypeOf((*MockStore)(nil).GetAccountAgentNetworkProviders), ctx, lockStrength, accountID) +} + +// GetAgentNetworkProviderByID mocks base method. +func (m *MockStore) GetAgentNetworkProviderByID(ctx context.Context, lockStrength LockingStrength, accountID, providerID string) (*agentNetworkTypes.Provider, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAgentNetworkProviderByID", ctx, lockStrength, accountID, providerID) + ret0, _ := ret[0].(*agentNetworkTypes.Provider) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAgentNetworkProviderByID indicates an expected call of GetAgentNetworkProviderByID. +func (mr *MockStoreMockRecorder) GetAgentNetworkProviderByID(ctx, lockStrength, accountID, providerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkProviderByID", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkProviderByID), ctx, lockStrength, accountID, providerID) +} + +// SaveAgentNetworkProvider mocks base method. +func (m *MockStore) SaveAgentNetworkProvider(ctx context.Context, provider *agentNetworkTypes.Provider) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveAgentNetworkProvider", ctx, provider) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveAgentNetworkProvider indicates an expected call of SaveAgentNetworkProvider. +func (mr *MockStoreMockRecorder) SaveAgentNetworkProvider(ctx, provider interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAgentNetworkProvider", reflect.TypeOf((*MockStore)(nil).SaveAgentNetworkProvider), ctx, provider) +} + +// DeleteAgentNetworkProvider mocks base method. +func (m *MockStore) DeleteAgentNetworkProvider(ctx context.Context, accountID, providerID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAgentNetworkProvider", ctx, accountID, providerID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAgentNetworkProvider indicates an expected call of DeleteAgentNetworkProvider. +func (mr *MockStoreMockRecorder) DeleteAgentNetworkProvider(ctx, accountID, providerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAgentNetworkProvider", reflect.TypeOf((*MockStore)(nil).DeleteAgentNetworkProvider), ctx, accountID, providerID) +} + +// GetAccountAgentNetworkPolicies mocks base method. +func (m *MockStore) GetAccountAgentNetworkPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Policy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountAgentNetworkPolicies", ctx, lockStrength, accountID) + ret0, _ := ret[0].([]*agentNetworkTypes.Policy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountAgentNetworkPolicies indicates an expected call of GetAccountAgentNetworkPolicies. +func (mr *MockStoreMockRecorder) GetAccountAgentNetworkPolicies(ctx, lockStrength, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountAgentNetworkPolicies", reflect.TypeOf((*MockStore)(nil).GetAccountAgentNetworkPolicies), ctx, lockStrength, accountID) +} + +// GetAgentNetworkPolicyByID mocks base method. +func (m *MockStore) GetAgentNetworkPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*agentNetworkTypes.Policy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAgentNetworkPolicyByID", ctx, lockStrength, accountID, policyID) + ret0, _ := ret[0].(*agentNetworkTypes.Policy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAgentNetworkPolicyByID indicates an expected call of GetAgentNetworkPolicyByID. +func (mr *MockStoreMockRecorder) GetAgentNetworkPolicyByID(ctx, lockStrength, accountID, policyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkPolicyByID", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkPolicyByID), ctx, lockStrength, accountID, policyID) +} + +// SaveAgentNetworkPolicy mocks base method. +func (m *MockStore) SaveAgentNetworkPolicy(ctx context.Context, policy *agentNetworkTypes.Policy) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveAgentNetworkPolicy", ctx, policy) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveAgentNetworkPolicy indicates an expected call of SaveAgentNetworkPolicy. +func (mr *MockStoreMockRecorder) SaveAgentNetworkPolicy(ctx, policy interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAgentNetworkPolicy", reflect.TypeOf((*MockStore)(nil).SaveAgentNetworkPolicy), ctx, policy) +} + +// DeleteAgentNetworkPolicy mocks base method. +func (m *MockStore) DeleteAgentNetworkPolicy(ctx context.Context, accountID, policyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAgentNetworkPolicy", ctx, accountID, policyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAgentNetworkPolicy indicates an expected call of DeleteAgentNetworkPolicy. +func (mr *MockStoreMockRecorder) DeleteAgentNetworkPolicy(ctx, accountID, policyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAgentNetworkPolicy", reflect.TypeOf((*MockStore)(nil).DeleteAgentNetworkPolicy), ctx, accountID, policyID) +} + +// GetAccountAgentNetworkGuardrails mocks base method. +func (m *MockStore) GetAccountAgentNetworkGuardrails(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Guardrail, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountAgentNetworkGuardrails", ctx, lockStrength, accountID) + ret0, _ := ret[0].([]*agentNetworkTypes.Guardrail) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountAgentNetworkGuardrails indicates an expected call of GetAccountAgentNetworkGuardrails. +func (mr *MockStoreMockRecorder) GetAccountAgentNetworkGuardrails(ctx, lockStrength, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountAgentNetworkGuardrails", reflect.TypeOf((*MockStore)(nil).GetAccountAgentNetworkGuardrails), ctx, lockStrength, accountID) +} + +// GetAgentNetworkGuardrailByID mocks base method. +func (m *MockStore) GetAgentNetworkGuardrailByID(ctx context.Context, lockStrength LockingStrength, accountID, guardrailID string) (*agentNetworkTypes.Guardrail, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAgentNetworkGuardrailByID", ctx, lockStrength, accountID, guardrailID) + ret0, _ := ret[0].(*agentNetworkTypes.Guardrail) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAgentNetworkGuardrailByID indicates an expected call of GetAgentNetworkGuardrailByID. +func (mr *MockStoreMockRecorder) GetAgentNetworkGuardrailByID(ctx, lockStrength, accountID, guardrailID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkGuardrailByID", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkGuardrailByID), ctx, lockStrength, accountID, guardrailID) +} + +// SaveAgentNetworkGuardrail mocks base method. +func (m *MockStore) SaveAgentNetworkGuardrail(ctx context.Context, guardrail *agentNetworkTypes.Guardrail) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveAgentNetworkGuardrail", ctx, guardrail) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveAgentNetworkGuardrail indicates an expected call of SaveAgentNetworkGuardrail. +func (mr *MockStoreMockRecorder) SaveAgentNetworkGuardrail(ctx, guardrail interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAgentNetworkGuardrail", reflect.TypeOf((*MockStore)(nil).SaveAgentNetworkGuardrail), ctx, guardrail) +} + +// DeleteAgentNetworkGuardrail mocks base method. +func (m *MockStore) DeleteAgentNetworkGuardrail(ctx context.Context, accountID, guardrailID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAgentNetworkGuardrail", ctx, accountID, guardrailID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAgentNetworkGuardrail indicates an expected call of DeleteAgentNetworkGuardrail. +func (mr *MockStoreMockRecorder) DeleteAgentNetworkGuardrail(ctx, accountID, guardrailID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAgentNetworkGuardrail", reflect.TypeOf((*MockStore)(nil).DeleteAgentNetworkGuardrail), ctx, accountID, guardrailID) +} + +// GetAgentNetworkSettings mocks base method. +func (m *MockStore) GetAgentNetworkSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*agentNetworkTypes.Settings, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAgentNetworkSettings", ctx, lockStrength, accountID) + ret0, _ := ret[0].(*agentNetworkTypes.Settings) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAgentNetworkSettings indicates an expected call of GetAgentNetworkSettings. +func (mr *MockStoreMockRecorder) GetAgentNetworkSettings(ctx, lockStrength, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkSettings", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkSettings), ctx, lockStrength, accountID) +} + +// GetAgentNetworkSettingsByCluster mocks base method. +func (m *MockStore) GetAgentNetworkSettingsByCluster(ctx context.Context, lockStrength LockingStrength, cluster string) ([]*agentNetworkTypes.Settings, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAgentNetworkSettingsByCluster", ctx, lockStrength, cluster) + ret0, _ := ret[0].([]*agentNetworkTypes.Settings) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAgentNetworkSettingsByCluster indicates an expected call of GetAgentNetworkSettingsByCluster. +func (mr *MockStoreMockRecorder) GetAgentNetworkSettingsByCluster(ctx, lockStrength, cluster interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkSettingsByCluster", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkSettingsByCluster), ctx, lockStrength, cluster) +} + +// SaveAgentNetworkSettings mocks base method. +func (m *MockStore) SaveAgentNetworkSettings(ctx context.Context, settings *agentNetworkTypes.Settings) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveAgentNetworkSettings", ctx, settings) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveAgentNetworkSettings indicates an expected call of SaveAgentNetworkSettings. +func (mr *MockStoreMockRecorder) SaveAgentNetworkSettings(ctx, settings interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAgentNetworkSettings", reflect.TypeOf((*MockStore)(nil).SaveAgentNetworkSettings), ctx, settings) +} + +// IncrementAgentNetworkConsumption mocks base method. +func (m *MockStore) IncrementAgentNetworkConsumption(ctx context.Context, accountID string, kind agentNetworkTypes.ConsumptionDimension, dimID string, windowSeconds int64, windowStart time.Time, tokensIn, tokensOut int64, costUSD float64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IncrementAgentNetworkConsumption", ctx, accountID, kind, dimID, windowSeconds, windowStart, tokensIn, tokensOut, costUSD) + ret0, _ := ret[0].(error) + return ret0 +} + +// IncrementAgentNetworkConsumption indicates an expected call of IncrementAgentNetworkConsumption. +func (mr *MockStoreMockRecorder) IncrementAgentNetworkConsumption(ctx, accountID, kind, dimID, windowSeconds, windowStart, tokensIn, tokensOut, costUSD interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementAgentNetworkConsumption", reflect.TypeOf((*MockStore)(nil).IncrementAgentNetworkConsumption), ctx, accountID, kind, dimID, windowSeconds, windowStart, tokensIn, tokensOut, costUSD) +} + +// GetAgentNetworkConsumption mocks base method. +func (m *MockStore) GetAgentNetworkConsumption(ctx context.Context, lockStrength LockingStrength, accountID string, kind agentNetworkTypes.ConsumptionDimension, dimID string, windowSeconds int64, windowStart time.Time) (*agentNetworkTypes.Consumption, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAgentNetworkConsumption", ctx, lockStrength, accountID, kind, dimID, windowSeconds, windowStart) + ret0, _ := ret[0].(*agentNetworkTypes.Consumption) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAgentNetworkConsumption indicates an expected call of GetAgentNetworkConsumption. +func (mr *MockStoreMockRecorder) GetAgentNetworkConsumption(ctx, lockStrength, accountID, kind, dimID, windowSeconds, windowStart interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkConsumption", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkConsumption), ctx, lockStrength, accountID, kind, dimID, windowSeconds, windowStart) +} + +// GetAgentNetworkConsumptionBatch mocks base method. +func (m *MockStore) GetAgentNetworkConsumptionBatch(ctx context.Context, lockStrength LockingStrength, accountID string, keys []agentNetworkTypes.ConsumptionKey) (map[agentNetworkTypes.ConsumptionKey]*agentNetworkTypes.Consumption, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAgentNetworkConsumptionBatch", ctx, lockStrength, accountID, keys) + ret0, _ := ret[0].(map[agentNetworkTypes.ConsumptionKey]*agentNetworkTypes.Consumption) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAgentNetworkConsumptionBatch indicates an expected call of GetAgentNetworkConsumptionBatch. +func (mr *MockStoreMockRecorder) GetAgentNetworkConsumptionBatch(ctx, lockStrength, accountID, keys interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkConsumptionBatch", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkConsumptionBatch), ctx, lockStrength, accountID, keys) +} + +// IncrementAgentNetworkConsumptionBatch mocks base method. +func (m *MockStore) IncrementAgentNetworkConsumptionBatch(ctx context.Context, accountID string, keys []agentNetworkTypes.ConsumptionKey, tokensIn, tokensOut int64, costUSD float64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IncrementAgentNetworkConsumptionBatch", ctx, accountID, keys, tokensIn, tokensOut, costUSD) + ret0, _ := ret[0].(error) + return ret0 +} + +// IncrementAgentNetworkConsumptionBatch indicates an expected call of IncrementAgentNetworkConsumptionBatch. +func (mr *MockStoreMockRecorder) IncrementAgentNetworkConsumptionBatch(ctx, accountID, keys, tokensIn, tokensOut, costUSD interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementAgentNetworkConsumptionBatch", reflect.TypeOf((*MockStore)(nil).IncrementAgentNetworkConsumptionBatch), ctx, accountID, keys, tokensIn, tokensOut, costUSD) +} + +// ListAgentNetworkConsumption mocks base method. +func (m *MockStore) ListAgentNetworkConsumption(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.Consumption, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListAgentNetworkConsumption", ctx, lockStrength, accountID) + ret0, _ := ret[0].([]*agentNetworkTypes.Consumption) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListAgentNetworkConsumption indicates an expected call of ListAgentNetworkConsumption. +func (mr *MockStoreMockRecorder) ListAgentNetworkConsumption(ctx, lockStrength, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAgentNetworkConsumption", reflect.TypeOf((*MockStore)(nil).ListAgentNetworkConsumption), ctx, lockStrength, accountID) +} + +// GetAccountAgentNetworkBudgetRules mocks base method. +func (m *MockStore) GetAccountAgentNetworkBudgetRules(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*agentNetworkTypes.AccountBudgetRule, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountAgentNetworkBudgetRules", ctx, lockStrength, accountID) + ret0, _ := ret[0].([]*agentNetworkTypes.AccountBudgetRule) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountAgentNetworkBudgetRules indicates an expected call of GetAccountAgentNetworkBudgetRules. +func (mr *MockStoreMockRecorder) GetAccountAgentNetworkBudgetRules(ctx, lockStrength, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountAgentNetworkBudgetRules", reflect.TypeOf((*MockStore)(nil).GetAccountAgentNetworkBudgetRules), ctx, lockStrength, accountID) +} + +// GetAgentNetworkBudgetRuleByID mocks base method. +func (m *MockStore) GetAgentNetworkBudgetRuleByID(ctx context.Context, lockStrength LockingStrength, accountID, ruleID string) (*agentNetworkTypes.AccountBudgetRule, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAgentNetworkBudgetRuleByID", ctx, lockStrength, accountID, ruleID) + ret0, _ := ret[0].(*agentNetworkTypes.AccountBudgetRule) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAgentNetworkBudgetRuleByID indicates an expected call of GetAgentNetworkBudgetRuleByID. +func (mr *MockStoreMockRecorder) GetAgentNetworkBudgetRuleByID(ctx, lockStrength, accountID, ruleID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkBudgetRuleByID", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkBudgetRuleByID), ctx, lockStrength, accountID, ruleID) +} + +// SaveAgentNetworkBudgetRule mocks base method. +func (m *MockStore) SaveAgentNetworkBudgetRule(ctx context.Context, rule *agentNetworkTypes.AccountBudgetRule) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveAgentNetworkBudgetRule", ctx, rule) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveAgentNetworkBudgetRule indicates an expected call of SaveAgentNetworkBudgetRule. +func (mr *MockStoreMockRecorder) SaveAgentNetworkBudgetRule(ctx, rule interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAgentNetworkBudgetRule", reflect.TypeOf((*MockStore)(nil).SaveAgentNetworkBudgetRule), ctx, rule) +} + +// DeleteAgentNetworkBudgetRule mocks base method. +func (m *MockStore) DeleteAgentNetworkBudgetRule(ctx context.Context, accountID, ruleID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAgentNetworkBudgetRule", ctx, accountID, ruleID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAgentNetworkBudgetRule indicates an expected call of DeleteAgentNetworkBudgetRule. +func (mr *MockStoreMockRecorder) DeleteAgentNetworkBudgetRule(ctx, accountID, ruleID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAgentNetworkBudgetRule", reflect.TypeOf((*MockStore)(nil).DeleteAgentNetworkBudgetRule), ctx, accountID, ruleID) +} + +// CreateAgentNetworkAccessLog mocks base method. +func (m *MockStore) CreateAgentNetworkAccessLog(ctx context.Context, entry *agentNetworkTypes.AgentNetworkAccessLog, groups []agentNetworkTypes.AgentNetworkAccessLogGroup) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateAgentNetworkAccessLog", ctx, entry, groups) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateAgentNetworkAccessLog indicates an expected call of CreateAgentNetworkAccessLog. +func (mr *MockStoreMockRecorder) CreateAgentNetworkAccessLog(ctx, entry, groups interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAgentNetworkAccessLog", reflect.TypeOf((*MockStore)(nil).CreateAgentNetworkAccessLog), ctx, entry, groups) +} + +// CreateAgentNetworkUsage mocks base method. +func (m *MockStore) CreateAgentNetworkUsage(ctx context.Context, usage *agentNetworkTypes.AgentNetworkUsage, groups []agentNetworkTypes.AgentNetworkUsageGroup) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateAgentNetworkUsage", ctx, usage, groups) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateAgentNetworkUsage indicates an expected call of CreateAgentNetworkUsage. +func (mr *MockStoreMockRecorder) CreateAgentNetworkUsage(ctx, usage, groups interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAgentNetworkUsage", reflect.TypeOf((*MockStore)(nil).CreateAgentNetworkUsage), ctx, usage, groups) +} + +// GetAgentNetworkAccessLogs mocks base method. +func (m *MockStore) GetAgentNetworkAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkAccessLog, int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAgentNetworkAccessLogs", ctx, lockStrength, accountID, filter) + ret0, _ := ret[0].([]*agentNetworkTypes.AgentNetworkAccessLog) + ret1, _ := ret[1].(int64) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetAgentNetworkAccessLogs indicates an expected call of GetAgentNetworkAccessLogs. +func (mr *MockStoreMockRecorder) GetAgentNetworkAccessLogs(ctx, lockStrength, accountID, filter interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkAccessLogs", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkAccessLogs), ctx, lockStrength, accountID, filter) +} + +// GetAgentNetworkAccessLogSessions mocks base method. +func (m *MockStore) GetAgentNetworkAccessLogSessions(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkAccessLogSession, int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAgentNetworkAccessLogSessions", ctx, lockStrength, accountID, filter) + ret0, _ := ret[0].([]*agentNetworkTypes.AgentNetworkAccessLogSession) + ret1, _ := ret[1].(int64) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetAgentNetworkAccessLogSessions indicates an expected call of GetAgentNetworkAccessLogSessions. +func (mr *MockStoreMockRecorder) GetAgentNetworkAccessLogSessions(ctx, lockStrength, accountID, filter interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkAccessLogSessions", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkAccessLogSessions), ctx, lockStrength, accountID, filter) +} + +// GetAgentNetworkUsageRows mocks base method. +func (m *MockStore) GetAgentNetworkUsageRows(ctx context.Context, lockStrength LockingStrength, accountID string, filter agentNetworkTypes.AgentNetworkAccessLogFilter) ([]*agentNetworkTypes.AgentNetworkUsage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAgentNetworkUsageRows", ctx, lockStrength, accountID, filter) + ret0, _ := ret[0].([]*agentNetworkTypes.AgentNetworkUsage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAgentNetworkUsageRows indicates an expected call of GetAgentNetworkUsageRows. +func (mr *MockStoreMockRecorder) GetAgentNetworkUsageRows(ctx, lockStrength, accountID, filter interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentNetworkUsageRows", reflect.TypeOf((*MockStore)(nil).GetAgentNetworkUsageRows), ctx, lockStrength, accountID, filter) +} + +// DeleteOldAgentNetworkAccessLogs mocks base method. +func (m *MockStore) DeleteOldAgentNetworkAccessLogs(ctx context.Context, accountID string, olderThan time.Time) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteOldAgentNetworkAccessLogs", ctx, accountID, olderThan) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteOldAgentNetworkAccessLogs indicates an expected call of DeleteOldAgentNetworkAccessLogs. +func (mr *MockStoreMockRecorder) DeleteOldAgentNetworkAccessLogs(ctx, accountID, olderThan interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldAgentNetworkAccessLogs", reflect.TypeOf((*MockStore)(nil).DeleteOldAgentNetworkAccessLogs), ctx, accountID, olderThan) +} + +// GetAllAgentNetworkSettings mocks base method. +func (m *MockStore) GetAllAgentNetworkSettings(ctx context.Context, lockStrength LockingStrength) ([]*agentNetworkTypes.Settings, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllAgentNetworkSettings", ctx, lockStrength) + ret0, _ := ret[0].([]*agentNetworkTypes.Settings) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllAgentNetworkSettings indicates an expected call of GetAllAgentNetworkSettings. +func (mr *MockStoreMockRecorder) GetAllAgentNetworkSettings(ctx, lockStrength interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllAgentNetworkSettings", reflect.TypeOf((*MockStore)(nil).GetAllAgentNetworkSettings), ctx, lockStrength) +} diff --git a/proxy/inbound.go b/proxy/inbound.go index d729ba9ae..e8f93fbe2 100644 --- a/proxy/inbound.go +++ b/proxy/inbound.go @@ -466,15 +466,20 @@ func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp. _ = ln.Close() }() + var backoff nbtcp.AcceptBackoff for { conn, err := ln.Accept() if err != nil { - if ctx.Err() != nil || errors.Is(err, net.ErrClosed) { + if ctx.Err() != nil || nbtcp.IsClosedListenerErr(err) { + return + } + logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v; backing off", err) + if !backoff.Backoff(ctx) { return } - logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err) continue } + backoff.Reset() router.HandleConn(ctx, conn) } } diff --git a/proxy/inbound_test.go b/proxy/inbound_test.go index 584a04238..0e6081802 100644 --- a/proxy/inbound_test.go +++ b/proxy/inbound_test.go @@ -533,3 +533,125 @@ MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49 AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== -----END EC PRIVATE KEY-----`) + +// scriptedAcceptListener returns pre-scripted errors from Accept(). Used +// to drive the feedRouterFromListener tests without binding a real +// socket — the production code path is a netstack-backed listener that +// returns gVisor's "endpoint is in invalid state" forever after its +// endpoint is destroyed. +type scriptedAcceptListener struct { + errs chan error + closed chan struct{} +} + +func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener { + s := &scriptedAcceptListener{ + errs: make(chan error, len(errs)+1), + closed: make(chan struct{}), + } + for _, e := range errs { + s.errs <- e + } + return s +} + +func (s *scriptedAcceptListener) Accept() (net.Conn, error) { + select { + case <-s.closed: + return nil, net.ErrClosed + case err := <-s.errs: + return nil, err + } +} + +func (s *scriptedAcceptListener) Close() error { + select { + case <-s.closed: + default: + close(s.closed) + } + return nil +} + +func (s *scriptedAcceptListener) Addr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +// errSentinel carries a literal error message so tests can synthesise +// the exact gVisor text without importing the netstack package. +type errSentinel string + +func (e errSentinel) Error() string { return string(e) } + +// TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint is the +// regression guard for the inbound side of the tight-loop bug. The +// per-account plain-HTTP feeder must recognise gVisor's "endpoint is in +// invalid state" and exit, otherwise it pegs a CPU core and floods the +// account-scoped log with the same accept error every iteration. +func TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80} + router := nbtcp.NewRouter(logger, nil, addr) + + gvisorErr := &net.OpError{ + Op: "accept", + Net: "tcp", + Addr: addr, + Err: errSentinel("endpoint is in invalid state"), + } + ln := newScriptedAcceptListener(gvisorErr) + defer ln.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + feedRouterFromListener(context.Background(), ln, router, logger, "acct-1") + }() + + select { + case <-done: + // Expected: loop recognised the gVisor error and returned. + case <-time.After(2 * time.Second): + t.Fatal("feedRouterFromListener did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning") + } +} + +// TestFeedRouterFromListener_BacksOffOnTransientError asserts the +// defence-in-depth path: an unknown sticky Accept error must NOT cause +// CPU spin. The loop backs off and exits cleanly when ctx is cancelled. +func TestFeedRouterFromListener_BacksOffOnTransientError(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80} + router := nbtcp.NewRouter(logger, nil, addr) + + const transientCount = 5 + errs := make([]error, transientCount) + for i := range errs { + errs[i] = errSentinel("transient: temporary network error") + } + ln := newScriptedAcceptListener(errs...) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + start := time.Now() + done := make(chan struct{}) + go func() { + defer close(done) + feedRouterFromListener(ctx, ln, router, logger, "acct-1") + }() + time.AfterFunc(150*time.Millisecond, cancel) + + select { + case <-done: + // Expected. + case <-time.After(2 * time.Second): + t.Fatal("feedRouterFromListener did not exit on ctx cancellation — backoff or exit path broken") + } + + // Without backoff the 5 scripted errors would burn in microseconds. + // With backoff the first delay alone is 5ms, so the loop must take + // at least that long even though ctx fires at 150ms. + elapsed := time.Since(start) + assert.GreaterOrEqual(t, elapsed, 5*time.Millisecond, + "loop ran without backing off — would burn CPU in production") +} diff --git a/proxy/internal/accesslog/logger.go b/proxy/internal/accesslog/logger.go index 3283f61db..db868b4e0 100644 --- a/proxy/internal/accesslog/logger.go +++ b/proxy/internal/accesslog/logger.go @@ -128,6 +128,7 @@ type logEntry struct { BytesDownload int64 Protocol Protocol Metadata map[string]string + AgentNetwork bool } // Protocol identifies the transport protocol of an access log entry. @@ -214,6 +215,54 @@ func (l *Logger) allowDenyLog(serviceID types.ServiceID, reason string) bool { return false } +// usageMetadataKeys is the allowlist of metadata retained on a stripped, +// usage-only agent-network entry. Mirrors the llm.* / cost.* keys in +// proxy/internal/middleware/keys.go — only the dimensions management needs to +// record a usage row (provider / model / tokens / cost / groups). +var usageMetadataKeys = map[string]struct{}{ + "llm.provider": {}, + "llm.model": {}, + "llm.resolved_provider_id": {}, + "llm.input_tokens": {}, + "llm.output_tokens": {}, + "llm.total_tokens": {}, + "cost.usd_total": {}, + "llm.authorising_groups": {}, +} + +// stripAgentNetworkEntryForUsage returns the entry reduced to what's needed to +// record usage/cost: it drops request detail (host / path / source IP) and any +// prompt capture, keeping the LLM usage metadata plus the caller identity +// (user / auth mechanism) needed for attribution. Shipped when an +// agent-network account has log collection disabled but usage must still be +// collected. logEntry is passed by value, so mutating it here is safe; Metadata +// is replaced with a fresh map rather than mutated in place. +func stripAgentNetworkEntryForUsage(entry logEntry) logEntry { + entry.Host = "" + entry.Path = "" + entry.SourceIP = netip.Addr{} + // Drop the rest of the per-request telemetry too — a usage-only entry + // must carry the LLM usage metadata and caller identity, nothing that + // describes the individual request. + entry.Method = "" + entry.ResponseCode = 0 + entry.DurationMs = 0 + entry.BytesUpload = 0 + entry.BytesDownload = 0 + entry.Protocol = "" + + if len(entry.Metadata) > 0 { + stripped := make(map[string]string, len(usageMetadataKeys)) + for k := range usageMetadataKeys { + if v, ok := entry.Metadata[k]; ok { + stripped[k] = v + } + } + entry.Metadata = stripped + } + return entry +} + func (l *Logger) log(entry logEntry) { // Fire off the log request in a separate routine. // This increases the possibility of losing a log message @@ -264,6 +313,7 @@ func (l *Logger) log(entry logEntry) { BytesDownload: entry.BytesDownload, Protocol: string(entry.Protocol), Metadata: entry.Metadata, + AgentNetwork: entry.AgentNetwork, }, }); err != nil { l.logger.WithFields(log.Fields{ diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index 5a0684c19..9c644418e 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -83,11 +83,23 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { BytesDownload: bytesDownload, Protocol: ProtocolHTTP, Metadata: capturedData.GetMetadata(), + AgentNetwork: capturedData.GetAgentNetwork(), } l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s", requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceID(), capturedData.GetAccountID()) - l.log(entry) + // Emit the access log unless the matched target opted out + // (agent-network synth targets do this when the account's + // EnableLogCollection toggle is off). For agent-network entries we + // still ship a stripped, usage-only record even when suppressed, so + // usage/cost is collected regardless of the log-collection toggle; + // request detail and prompt capture are dropped before sending. + switch { + case !capturedData.GetSuppressAccessLog(): + l.log(entry) + case entry.AgentNetwork: + l.log(stripAgentNetworkEntryForUsage(entry)) + } // Track usage for cost monitoring (upload + download) by domain l.trackUsage(host, bytesUpload+bytesDownload) diff --git a/proxy/internal/accesslog/middleware_test.go b/proxy/internal/accesslog/middleware_test.go new file mode 100644 index 000000000..cf91957a8 --- /dev/null +++ b/proxy/internal/accesslog/middleware_test.go @@ -0,0 +1,185 @@ +package accesslog + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/proxy/internal/proxy" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// recorderClient is a minimal stub for the access-log gRPCClient interface. It +// counts SendAccessLog invocations and signals on every call so tests can +// deterministically wait for the goroutine inside Logger.log without sleeping. +type recorderClient struct { + mu sync.Mutex + calls int64 + lastEntry *proto.AccessLog + called chan struct{} +} + +func newRecorderClient() *recorderClient { + return &recorderClient{called: make(chan struct{}, 16)} +} + +func (r *recorderClient) SendAccessLog(_ context.Context, in *proto.SendAccessLogRequest, _ ...grpc.CallOption) (*proto.SendAccessLogResponse, error) { + r.mu.Lock() + r.calls++ + r.lastEntry = in.GetLog() + r.mu.Unlock() + select { + case r.called <- struct{}{}: + default: + } + return &proto.SendAccessLogResponse{}, nil +} + +func (r *recorderClient) callCount() int64 { + r.mu.Lock() + defer r.mu.Unlock() + return r.calls +} + +// newTestLogger builds a Logger backed by the supplied recorderClient. It is +// the same constructor production uses, just with a stub gRPC client — no +// mocks, no interface re-implementations. +func newTestLogger(t *testing.T, client *recorderClient) *Logger { + t.Helper() + logger := NewLogger(client, nil, nil) + t.Cleanup(logger.Close) + return logger +} + +// TestMiddleware_SuppressAccessLog_SkipsLogSink asserts the suppression gate. +// When the inner handler stamps SuppressAccessLog=true on CapturedData (mirrors +// what reverseproxy does when the matched target's DisableAccessLog flag is +// set), the middleware must NOT invoke the access-log sink. Bandwidth telemetry +// (trackUsage) keeps running — it's the call to SendAccessLog that we gate. +func TestMiddleware_SuppressAccessLog_SkipsLogSink(t *testing.T) { + client := newRecorderClient() + l := newTestLogger(t, client) + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cd := proxy.CapturedDataFromContext(r.Context()) + require.NotNil(t, cd, "middleware must inject CapturedData into the request context") + cd.SetSuppressAccessLog(true) + w.WriteHeader(http.StatusOK) + }) + + srv := httptest.NewServer(l.Middleware(inner)) + t.Cleanup(srv.Close) + + resp, err := http.Get(srv.URL + "/agent-network/v1/chat/completions") + require.NoError(t, err, "GET against suppressed target must succeed") + require.NoError(t, resp.Body.Close()) + require.Equal(t, http.StatusOK, resp.StatusCode, "inner handler must run normally") + + // Give the goroutine fence a beat (Logger.log dispatches in a goroutine). + // The negative assertion needs a small window: if a send is going to + // happen, it happens promptly. + select { + case <-client.called: + t.Fatalf("access-log sink must not be invoked when SuppressAccessLog=true (got %d call(s))", client.callCount()) + case <-time.After(150 * time.Millisecond): + } + + assert.Equal(t, int64(0), client.callCount(), + "SendAccessLog must not be called for suppressed requests") +} + +// TestMiddleware_SuppressAccessLog_DefaultEmitsLog is the regression sanity: +// when nothing sets SuppressAccessLog (the universal default for every +// non-agent-network target), the middleware MUST still emit the access-log +// entry. This is the guarantee that wires-through to the EnableLogCollection +// gate without breaking anyone who isn't opted in. +func TestMiddleware_SuppressAccessLog_DefaultEmitsLog(t *testing.T) { + client := newRecorderClient() + l := newTestLogger(t, client) + + var innerRan atomic.Bool + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + innerRan.Store(true) + // Intentionally DO NOT touch SuppressAccessLog — mirrors every + // non-agent-network target. + w.WriteHeader(http.StatusOK) + }) + + srv := httptest.NewServer(l.Middleware(inner)) + t.Cleanup(srv.Close) + + resp, err := http.Get(srv.URL + "/service/healthz") + require.NoError(t, err, "GET against default target must succeed") + require.NoError(t, resp.Body.Close()) + require.Equal(t, http.StatusOK, resp.StatusCode, "inner handler must run normally") + require.True(t, innerRan.Load(), "inner handler must have run") + + select { + case <-client.called: + case <-time.After(2 * time.Second): + t.Fatalf("SendAccessLog must be invoked for non-suppressed requests, none observed (calls=%d)", client.callCount()) + } + + assert.Equal(t, int64(1), client.callCount(), + "non-suppressed request must produce exactly one access-log send") +} + +// TestMiddleware_SuppressAccessLog_PreservesUsageTracking proves the gate is +// surgical: with SuppressAccessLog=true the access-log send is skipped, but +// the per-domain usage tracker still records the bytes transferred. This is +// the cost-monitoring guarantee called out in the gate's comment. +func TestMiddleware_SuppressAccessLog_PreservesUsageTracking(t *testing.T) { + client := newRecorderClient() + l := newTestLogger(t, client) + + payload := []byte("ok") + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cd := proxy.CapturedDataFromContext(r.Context()) + require.NotNil(t, cd, "middleware must inject CapturedData") + cd.SetSuppressAccessLog(true) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(payload) + }) + + srv := httptest.NewServer(l.Middleware(inner)) + t.Cleanup(srv.Close) + + resp, err := http.Get(srv.URL + "/agent-network/v1/chat/completions") + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + // Allow trackUsage to land — it runs synchronously after l.log(entry) is + // (would have been) called. + time.Sleep(50 * time.Millisecond) + + l.usageMux.Lock() + usage, present := l.domainUsage[hostNoPort(srv.URL)] + l.usageMux.Unlock() + require.True(t, present, "domain usage must be tracked even when the access-log is suppressed") + assert.Greater(t, usage.bytesTransferred, int64(0), "bytesTransferred must include the response payload") + assert.Equal(t, int64(0), client.callCount(), + "SendAccessLog must remain suppressed across the response write") +} + +// hostNoPort extracts the host name from an httptest server URL. The +// middleware strips the port before keying domain usage, so the test mirrors +// that to look the entry up. +func hostNoPort(url string) string { + // httptest URLs are always "http://127.0.0.1:PORT". + const prefix = "http://" + host := url[len(prefix):] + for i := 0; i < len(host); i++ { + if host[i] == ':' || host[i] == '/' { + return host[:i] + } + } + return host +} diff --git a/proxy/internal/auth/middleware_test.go b/proxy/internal/auth/middleware_test.go index c0ec5c94c..6608c2b22 100644 --- a/proxy/internal/auth/middleware_test.go +++ b/proxy/internal/auth/middleware_test.go @@ -297,6 +297,109 @@ func TestProtect_SessionCookieGroupsPropagate(t *testing.T) { assert.Equal(t, groups, capturedData.GetUserGroups(), "CapturedData groups must be retained after handler completes") } +// stubTunnelValidator implements SessionValidator for the tunnel-peer +// path. ValidateTunnelPeer returns a fixed response so tests can assert +// how the proxy maps it onto CapturedData, and records whether the +// fast-path actually reached management. +type stubTunnelValidator struct { + called bool + resp *proto.ValidateTunnelPeerResponse +} + +func (s *stubTunnelValidator) ValidateSession(context.Context, *proto.ValidateSessionRequest, ...grpc.CallOption) (*proto.ValidateSessionResponse, error) { + return nil, errors.New("not used in this test") +} + +func (s *stubTunnelValidator) ValidateTunnelPeer(context.Context, *proto.ValidateTunnelPeerRequest, ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) { + s.called = true + return s.resp, nil +} + +// TestProtect_PrivateService_TunnelPeerGroupsPropagate locks the agent-network +// auth path end-to-end at the proxy edge: a Private service must route through +// ValidateTunnelPeer and lift the returned peer_group_ids onto CapturedData so +// the llm_router group-authorisation pass can see them. Regression guard for +// the failure that surfaces downstream as llm_policy.no_authorised_provider — +// i.e. a synthesised service that reaches the proxy without private=true (so +// this path is skipped) leaves UserGroups empty and every request is denied. +func TestProtect_PrivateService_TunnelPeerGroupsPropagate(t *testing.T) { + groups := []string{"grp-admins", "grp-users"} + names := []string{"Admins", "Users"} + validator := &stubTunnelValidator{resp: &proto.ValidateTunnelPeerResponse{ + Valid: true, + UserId: "user-1", + UserEmail: "user@example.com", + SessionToken: "tunnel-session-token", + PeerGroupIds: groups, + PeerGroupNames: names, + }} + mw := NewMiddleware(log.StandardLogger(), validator, nil) + kp := generateTestKeyPair(t) + + // Private service: no operator schemes — auth gates solely on the tunnel peer. + require.NoError(t, mw.AddDomain("agent.example.com", nil, kp.PublicKey, time.Hour, "acct-1", "svc-1", nil, true)) + + cd := proxy.NewCapturedData("") + cd.SetClientIP(netip.MustParseAddr("100.90.1.14")) // CGNAT tunnel source + + var seenGroups []string + var seenUser string + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c := proxy.CapturedDataFromContext(r.Context()) + require.NotNil(t, c, "captured data must be present in request context") + seenGroups = c.GetUserGroups() + seenUser = c.GetUserID() + w.WriteHeader(http.StatusOK) + })) + + lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) { + return PeerIdentity{}, true + }) + req := httptest.NewRequest(http.MethodPost, "http://agent.example.com/v1/chat/completions", nil) + req.RemoteAddr = "100.90.1.14:5000" + req = req.WithContext(WithTunnelLookup(proxy.WithCapturedData(req.Context(), cd), lookup)) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code, "private service must authorise a tunnel peer the validator accepts") + assert.Equal(t, groups, seenGroups, "ValidateTunnelPeer peer_group_ids must reach CapturedData.UserGroups for llm_router authorisation") + assert.Equal(t, "user-1", seenUser, "tunnel-peer principal must reach CapturedData") + assert.Equal(t, groups, cd.GetUserGroups(), "groups must persist on CapturedData after the handler returns") +} + +// TestProtect_PrivateService_TunnelPeerDenied verifies the deny path: when +// ValidateTunnelPeer rejects the peer, a Private service 403s and never reaches +// the upstream handler (no fall-through to unauthenticated pass-through). +func TestProtect_PrivateService_TunnelPeerDenied(t *testing.T) { + validator := &stubTunnelValidator{resp: &proto.ValidateTunnelPeerResponse{ + Valid: false, + DeniedReason: "not_in_group", + }} + mw := NewMiddleware(log.StandardLogger(), validator, nil) + kp := generateTestKeyPair(t) + require.NoError(t, mw.AddDomain("agent.example.com", nil, kp.PublicKey, time.Hour, "acct-1", "svc-1", nil, true)) + + cd := proxy.NewCapturedData("") + cd.SetClientIP(netip.MustParseAddr("100.90.1.14")) + + reached := false + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reached = true + w.WriteHeader(http.StatusOK) + })) + lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) { + return PeerIdentity{}, true + }) + req := httptest.NewRequest(http.MethodPost, "http://agent.example.com/v1/chat/completions", nil) + req.RemoteAddr = "100.90.1.14:5000" + req = req.WithContext(WithTunnelLookup(proxy.WithCapturedData(req.Context(), cd), lookup)) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusForbidden, rec.Code, "private service must 403 when the tunnel peer is rejected") + assert.False(t, reached, "denied private request must not reach the upstream handler") +} + func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) @@ -1228,22 +1331,6 @@ func TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked(t *testing.T) { assert.Equal(t, http.StatusUnauthorized, rec.Code, "PIN-only domain should serve the login page on plain HTTP") } -// stubTunnelValidator records ValidateTunnelPeer calls so a test can -// assert whether the fast-path reached management. -type stubTunnelValidator struct { - called bool - resp *proto.ValidateTunnelPeerResponse -} - -func (s *stubTunnelValidator) ValidateSession(context.Context, *proto.ValidateSessionRequest, ...grpc.CallOption) (*proto.ValidateSessionResponse, error) { - return nil, errors.New("not used in this test") -} - -func (s *stubTunnelValidator) ValidateTunnelPeer(context.Context, *proto.ValidateTunnelPeerRequest, ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) { - s.called = true - return s.resp, nil -} - // TestProtect_TunnelPeerFastPath_RequiresInboundMarker guards the // anti-spoof gate: a request with an RFC1918 source IP arriving on the // public listener (no TunnelLookupFromContext attached) must not be diff --git a/proxy/internal/llm/anthropic.go b/proxy/internal/llm/anthropic.go new file mode 100644 index 000000000..523731fbd --- /dev/null +++ b/proxy/internal/llm/anthropic.go @@ -0,0 +1,196 @@ +package llm + +import ( + "encoding/json" + "fmt" + "strings" +) + +// AnthropicParser implements the Parser interface for the Anthropic Messages +// and Completions APIs. Detection is substring-based to tolerate upstream +// path rewrites. +type AnthropicParser struct{} + +var anthropicPathHints = []string{ + "/v1/messages", + "/v1/complete", +} + +// Provider returns ProviderAnthropic. +func (AnthropicParser) Provider() Provider { return ProviderAnthropic } + +// ProviderName returns the stable label used for metrics and metadata. +func (AnthropicParser) ProviderName() string { return "anthropic" } + +// DetectFromURL reports whether the given request path looks like an +// Anthropic API endpoint. The match is case-insensitive and substring-based. +func (AnthropicParser) DetectFromURL(path string) bool { + lower := strings.ToLower(path) + for _, hint := range anthropicPathHints { + if strings.Contains(lower, hint) { + return true + } + } + return false +} + +type anthropicRequest struct { + Model string `json:"model"` + Stream *bool `json:"stream"` + System json.RawMessage `json:"system"` + Messages []anthropicMessage `json:"messages"` + // Legacy /v1/complete endpoint. + Prompt string `json:"prompt"` +} + +type anthropicMessage struct { + Role string `json:"role"` + Content json.RawMessage `json:"content"` +} + +// ParseRequest extracts the model name and streaming flag from an Anthropic +// request body. Unknown or missing fields leave the corresponding struct +// members zero-valued. +func (AnthropicParser) ParseRequest(body []byte) (RequestFacts, error) { + var req anthropicRequest + if err := json.Unmarshal(body, &req); err != nil { + return RequestFacts{}, fmt.Errorf("decode anthropic request: %w: %v", ErrMalformedRequest, err) + } + return RequestFacts{ + Model: req.Model, + Stream: ptrDeref(req.Stream), + }, nil +} + +type anthropicResponse struct { + Usage struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + // CacheReadInputTokens and CacheCreationInputTokens are + // ADDITIVE to InputTokens (not subset), each billed at its + // own rate by the cost meter. cache_read is the cheaper + // read-from-cache rate, cache_creation is the more + // expensive write-to-cache rate. + CacheReadInputTokens int64 `json:"cache_read_input_tokens"` + CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"` + } `json:"usage"` +} + +// ParseResponse decodes the non-streaming Anthropic response envelope. Status +// codes other than 200 are treated as non-LLM responses so the caller can +// skip cost accounting without aborting the request. +func (AnthropicParser) ParseResponse(status int, contentType string, body []byte) (Usage, error) { + if status != 200 { + return Usage{}, fmt.Errorf("anthropic status %d: %w", status, ErrNotLLMResponse) + } + if isEventStream(contentType) { + return Usage{}, ErrStreamingUnsupported + } + if !isJSON(contentType) { + return Usage{}, fmt.Errorf("anthropic content-type %q: %w", contentType, ErrNotLLMResponse) + } + + var resp anthropicResponse + if err := json.Unmarshal(body, &resp); err != nil { + return Usage{}, fmt.Errorf("decode anthropic response: %w: %v", ErrMalformedResponse, err) + } + return Usage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens + resp.Usage.CacheReadInputTokens + resp.Usage.CacheCreationInputTokens, + CachedInputTokens: resp.Usage.CacheReadInputTokens, + CacheCreationTokens: resp.Usage.CacheCreationInputTokens, + }, nil +} + +// ExtractPrompt returns the user-visible prompt text from an Anthropic +// request body. Handles the Messages API (system + messages[]) and the +// legacy /v1/complete prompt string. Returns "" on any decode failure. +func (AnthropicParser) ExtractPrompt(body []byte) string { + var req anthropicRequest + if err := json.Unmarshal(body, &req); err != nil { + return "" + } + var b strings.Builder + if len(req.System) > 0 { + if s := decodeStringOrJoin(req.System); s != "" { + b.WriteString("system: ") + b.WriteString(s) + } + } + for _, m := range req.Messages { + if b.Len() > 0 { + b.WriteByte('\n') + } + if m.Role != "" { + b.WriteString(m.Role) + b.WriteString(": ") + } + b.WriteString(decodeStringOrJoin(m.Content)) + } + if b.Len() == 0 && req.Prompt != "" { + b.WriteString(req.Prompt) + } + return b.String() +} + +// ExtractSessionID is the body-side fallback for Anthropic. Claude Code's +// authoritative session marker is the X-Claude-Code-Session-Id request +// header (handled by the request-parser middleware); this only mines the +// optional metadata.user_id for an embedded "...session_" marker. +// metadata.user_id on its own is a USER identifier, not a session, so the +// whole value is deliberately NOT used — returning it would mislabel every +// request from a user as one session. Returns "" when no session marker is +// present. +func (AnthropicParser) ExtractSessionID(body []byte) string { + var req struct { + Metadata struct { + UserID string `json:"user_id"` + } `json:"metadata"` + } + if err := json.Unmarshal(body, &req); err != nil { + return "" + } + if idx := strings.LastIndex(req.Metadata.UserID, "session_"); idx >= 0 { + if session := req.Metadata.UserID[idx+len("session_"):]; session != "" { + return session + } + } + return "" +} + +type anthropicMessageResponse struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + // Legacy /v1/complete response. + Completion string `json:"completion"` +} + +// ExtractCompletion returns the assistant text from a non-streaming Anthropic +// Messages or Completions response. Returns "" when status/content-type +// indicate the body is not parseable or no text part is present. +func (AnthropicParser) ExtractCompletion(status int, contentType string, body []byte) string { + if status != 200 || isEventStream(contentType) || !isJSON(contentType) { + return "" + } + var resp anthropicMessageResponse + if err := json.Unmarshal(body, &resp); err != nil { + return "" + } + var b strings.Builder + for _, part := range resp.Content { + if part.Text == "" { + continue + } + if b.Len() > 0 { + b.WriteByte('\n') + } + b.WriteString(part.Text) + } + if b.Len() == 0 { + return resp.Completion + } + return b.String() +} diff --git a/proxy/internal/llm/anthropic_test.go b/proxy/internal/llm/anthropic_test.go new file mode 100644 index 000000000..a0c1f7896 --- /dev/null +++ b/proxy/internal/llm/anthropic_test.go @@ -0,0 +1,169 @@ +package llm + +import ( + "errors" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAnthropicDetectFromURL(t *testing.T) { + p := AnthropicParser{} + + cases := map[string]bool{ + "/v1/messages": true, + "/v1/complete": true, + "/V1/Messages": true, + "/proxy/v1/messages?x": true, + "/v1/chat/completions": false, + "": false, + } + for path, want := range cases { + assert.Equal(t, want, p.DetectFromURL(path), "DetectFromURL(%q)", path) + } +} + +func TestAnthropicParseRequest(t *testing.T) { + p := AnthropicParser{} + + t.Run("stream true", func(t *testing.T) { + facts, err := p.ParseRequest([]byte(`{"model":"claude-sonnet-4-5","stream":true}`)) + require.NoError(t, err) + assert.Equal(t, "claude-sonnet-4-5", facts.Model, "model extracted") + assert.True(t, facts.Stream, "stream flag honoured") + }) + + t.Run("stream default", func(t *testing.T) { + facts, err := p.ParseRequest([]byte(`{"model":"claude-sonnet-4-5"}`)) + require.NoError(t, err) + assert.False(t, facts.Stream, "missing stream flag defaults to false") + }) + + t.Run("malformed", func(t *testing.T) { + _, err := p.ParseRequest([]byte(`{"model":`)) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrMalformedRequest), "sentinel wrapped") + }) +} + +func TestAnthropicParseResponse(t *testing.T) { + p := AnthropicParser{} + + t.Run("happy fixture", func(t *testing.T) { + body, err := os.ReadFile(filepath.Join("fixtures", "anthropic_messages.json")) + require.NoError(t, err, "fixture must be readable") + + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Equal(t, int64(123), usage.InputTokens, "input tokens extracted") + assert.Equal(t, int64(45), usage.OutputTokens, "output tokens extracted") + assert.Equal(t, int64(168), usage.TotalTokens, "total computed as sum") + }) + + t.Run("streaming rejected", func(t *testing.T) { + _, err := p.ParseResponse(200, "text/event-stream", []byte("")) + require.ErrorIs(t, err, ErrStreamingUnsupported, "SSE responses must use the scanner") + }) + + t.Run("non-200", func(t *testing.T) { + _, err := p.ParseResponse(429, "application/json", []byte(`{}`)) + require.ErrorIs(t, err, ErrNotLLMResponse, "non-200 rejected as non-LLM") + }) + + t.Run("non-json content type", func(t *testing.T) { + _, err := p.ParseResponse(200, "text/html", []byte(`{}`)) + require.ErrorIs(t, err, ErrNotLLMResponse, "text/html treated as non-LLM") + }) + + t.Run("malformed body", func(t *testing.T) { + _, err := p.ParseResponse(200, "application/json", []byte(`{`)) + require.ErrorIs(t, err, ErrMalformedResponse, "bad JSON yields malformed error") + }) + + // Anthropic's two cache fields are ADDITIVE to input_tokens (not + // subset). The parser must surface them so the cost meter can + // bill each bucket at its own configured rate. Total includes + // every bucket so downstream attribution sees the full token + // volume the request consumed. + t.Run("cache_read_input_tokens surfaces as CachedInputTokens (additive)", func(t *testing.T) { + body := []byte(`{"usage":{"input_tokens":256,"output_tokens":200,"cache_read_input_tokens":768}}`) + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Equal(t, int64(256), usage.InputTokens, "regular input remains separate from cache buckets") + assert.Equal(t, int64(768), usage.CachedInputTokens, "cache_read maps onto CachedInputTokens — same field carries OpenAI cached subset and Anthropic cache reads") + assert.Zero(t, usage.CacheCreationTokens) + assert.Equal(t, int64(256+200+768), usage.TotalTokens, "total includes every input bucket plus output — cache reads are billable tokens") + }) + + t.Run("cache_creation_input_tokens surfaces as CacheCreationTokens (additive)", func(t *testing.T) { + body := []byte(`{"usage":{"input_tokens":256,"output_tokens":200,"cache_creation_input_tokens":512}}`) + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Equal(t, int64(256), usage.InputTokens) + assert.Zero(t, usage.CachedInputTokens) + assert.Equal(t, int64(512), usage.CacheCreationTokens, "cache_creation surfaces — meter applies the write-rate multiplier") + assert.Equal(t, int64(256+200+512), usage.TotalTokens) + }) + + t.Run("both cache buckets present", func(t *testing.T) { + body := []byte(`{"usage":{"input_tokens":256,"output_tokens":200,"cache_read_input_tokens":768,"cache_creation_input_tokens":512}}`) + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Equal(t, int64(768), usage.CachedInputTokens) + assert.Equal(t, int64(512), usage.CacheCreationTokens) + assert.Equal(t, int64(256+200+768+512), usage.TotalTokens, "all four buckets sum into total") + }) + + t.Run("absent cache fields leave counts at zero", func(t *testing.T) { + body := []byte(`{"usage":{"input_tokens":100,"output_tokens":50}}`) + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Zero(t, usage.CachedInputTokens, "no cache_read field = no cached count") + assert.Zero(t, usage.CacheCreationTokens, "no cache_creation field = no creation count") + assert.Equal(t, int64(150), usage.TotalTokens, "back to the simple in+out total when no cache buckets present") + }) +} + +func TestAnthropicExtractPrompt_Messages(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-7","system":"be brief","messages":[{"role":"user","content":"hi"},{"role":"assistant","content":"yes"}]}`) + got := AnthropicParser{}.ExtractPrompt(body) + require.Contains(t, got, "system: be brief", "system surfaces with role label") + require.Contains(t, got, "user: hi", "user message surfaces") + require.Contains(t, got, "assistant: yes", "assistant message surfaces") +} + +func TestAnthropicExtractPrompt_LegacyComplete(t *testing.T) { + body := []byte(`{"model":"claude-2","prompt":"\n\nHuman: hi\n\nAssistant:"}`) + got := AnthropicParser{}.ExtractPrompt(body) + require.Contains(t, got, "Human: hi", "legacy prompt string surfaces") +} + +func TestAnthropicExtractSessionID(t *testing.T) { + t.Run("claude code session suffix", func(t *testing.T) { + body := []byte(`{"model":"claude-opus-4-8","metadata":{"user_id":"user_abc123_account_def456_session_9f8e7d6c"},"messages":[]}`) + assert.Equal(t, "9f8e7d6c", AnthropicParser{}.ExtractSessionID(body), "session_ suffix must be extracted from metadata.user_id") + }) + t.Run("plain user_id is not treated as a session", func(t *testing.T) { + body := []byte(`{"model":"claude-opus-4-8","metadata":{"user_id":"acme-team"},"messages":[]}`) + assert.Equal(t, "", AnthropicParser{}.ExtractSessionID(body), "a user identifier without a session marker must NOT be used as a session id") + }) + t.Run("no metadata yields empty", func(t *testing.T) { + body := []byte(`{"model":"claude-opus-4-8","messages":[{"role":"user","content":"hi"}]}`) + assert.Equal(t, "", AnthropicParser{}.ExtractSessionID(body), "absent metadata.user_id yields no session id") + }) +} + +func TestAnthropicExtractCompletion_Messages(t *testing.T) { + body, err := os.ReadFile(filepath.Join("fixtures", "anthropic_messages.json")) + require.NoError(t, err) + got := AnthropicParser{}.ExtractCompletion(200, "application/json", body) + require.NotEmpty(t, got, "anthropic fixture has assistant text") +} + +func TestAnthropicExtractCompletion_Streaming(t *testing.T) { + got := AnthropicParser{}.ExtractCompletion(200, "text/event-stream", []byte("")) + require.Empty(t, got, "streaming responses are skipped") +} diff --git a/proxy/internal/llm/bedrock.go b/proxy/internal/llm/bedrock.go new file mode 100644 index 000000000..f7802beb2 --- /dev/null +++ b/proxy/internal/llm/bedrock.go @@ -0,0 +1,189 @@ +package llm + +import ( + "encoding/json" + "fmt" + "strings" +) + +// ProviderNameBedrock is the stable label for the AWS Bedrock parser, used as +// the llm.provider metadata value and the cost-meter formula selector. +const ProviderNameBedrock = "bedrock" + +// BedrockParser implements the Parser interface for the AWS Bedrock runtime. +// Bedrock carries the model in the URL path (/model/{id}/{action}); the request +// middleware extracts it there, so this parser focuses on the response shapes: +// the vendor-native InvokeModel body (e.g. Anthropic's snake_case usage) and the +// unified Converse body (camelCase usage). +type BedrockParser struct{} + +var bedrockPathHints = []string{"/invoke", "/converse"} + +// Provider returns ProviderBedrock. +func (BedrockParser) Provider() Provider { return ProviderBedrock } + +// ProviderName returns the stable label used for metrics and metadata. +func (BedrockParser) ProviderName() string { return ProviderNameBedrock } + +// DetectFromURL reports whether the path is a Bedrock runtime model endpoint. +func (BedrockParser) DetectFromURL(path string) bool { + lower := strings.ToLower(path) + if !strings.HasPrefix(lower, "/model/") { + return false + } + for _, hint := range bedrockPathHints { + if strings.Contains(lower, hint) { + return true + } + } + return false +} + +// ParseRequest is a no-op for Bedrock: the model lives in the URL path, not the +// body, and the streaming flag is derived from the path action. The request +// middleware handles both via parseBedrockPath, so this returns empty facts. +func (BedrockParser) ParseRequest([]byte) (RequestFacts, error) { + return RequestFacts{}, nil +} + +// bedrockResponse captures token usage from both Bedrock response shapes: +// InvokeModel (vendor-native; Anthropic uses snake_case + additive cache +// buckets) and Converse (camelCase, with a precomputed total). +type bedrockResponse struct { + Usage struct { + // InvokeModel (Anthropic-on-Bedrock) — snake_case. + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheReadInputTokens int64 `json:"cache_read_input_tokens"` + CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"` + // Converse — camelCase. + InputTokensCamel int64 `json:"inputTokens"` + OutputTokensCamel int64 `json:"outputTokens"` + TotalTokensCamel int64 `json:"totalTokens"` + } `json:"usage"` +} + +// ParseResponse decodes the non-streaming Bedrock response envelope, handling +// both the InvokeModel and Converse usage shapes. Non-200 / non-JSON bodies are +// treated as non-LLM responses so the caller skips cost accounting. +func (BedrockParser) ParseResponse(status int, contentType string, body []byte) (Usage, error) { + if status != 200 { + return Usage{}, fmt.Errorf("bedrock status %d: %w", status, ErrNotLLMResponse) + } + if isAWSEventStream(contentType) || isEventStream(contentType) { + return Usage{}, ErrStreamingUnsupported + } + if !isJSON(contentType) { + return Usage{}, fmt.Errorf("bedrock content-type %q: %w", contentType, ErrNotLLMResponse) + } + + var resp bedrockResponse + if err := json.Unmarshal(body, &resp); err != nil { + return Usage{}, fmt.Errorf("decode bedrock response: %w: %v", ErrMalformedResponse, err) + } + inTok := firstNonZero(resp.Usage.InputTokens, resp.Usage.InputTokensCamel) + outTok := firstNonZero(resp.Usage.OutputTokens, resp.Usage.OutputTokensCamel) + total := resp.Usage.TotalTokensCamel + if total == 0 { + total = inTok + outTok + resp.Usage.CacheReadInputTokens + resp.Usage.CacheCreationInputTokens + } + return Usage{ + InputTokens: inTok, + OutputTokens: outTok, + TotalTokens: total, + CachedInputTokens: resp.Usage.CacheReadInputTokens, + CacheCreationTokens: resp.Usage.CacheCreationInputTokens, + }, nil +} + +// ExtractPrompt returns the user-visible prompt from a Bedrock request body, +// handling both the InvokeModel (Anthropic Messages: system + messages[]) and +// Converse (messages[].content[].text) shapes. Returns "" on decode failure. +func (BedrockParser) ExtractPrompt(body []byte) string { + var req struct { + System json.RawMessage `json:"system"` + Messages []struct { + Role string `json:"role"` + Content json.RawMessage `json:"content"` + } `json:"messages"` + } + if err := json.Unmarshal(body, &req); err != nil { + return "" + } + var b strings.Builder + if s := decodeStringOrJoin(req.System); s != "" { + b.WriteString("system: ") + b.WriteString(s) + } + for _, m := range req.Messages { + if b.Len() > 0 { + b.WriteByte('\n') + } + if m.Role != "" { + b.WriteString(m.Role) + b.WriteString(": ") + } + b.WriteString(decodeStringOrJoin(m.Content)) + } + return b.String() +} + +// ExtractCompletion returns the assistant text from a non-streaming Bedrock +// response, handling InvokeModel (Anthropic content[].text) and Converse +// (output.message.content[].text). +func (BedrockParser) ExtractCompletion(status int, contentType string, body []byte) string { + if status != 200 || isAWSEventStream(contentType) || !isJSON(contentType) { + return "" + } + var resp struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + Output struct { + Message struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + } `json:"message"` + } `json:"output"` + } + if err := json.Unmarshal(body, &resp); err != nil { + return "" + } + var b strings.Builder + appendText := func(text string) { + if text == "" { + return + } + if b.Len() > 0 { + b.WriteByte('\n') + } + b.WriteString(text) + } + for _, p := range resp.Content { + appendText(p.Text) + } + for _, p := range resp.Output.Message.Content { + appendText(p.Text) + } + return b.String() +} + +// ExtractSessionID has no Bedrock-native marker; session grouping relies on the +// request headers handled by the middleware. Returns "". +func (BedrockParser) ExtractSessionID([]byte) string { return "" } + +// firstNonZero returns a when non-zero, else b. Folds the snake_case and +// camelCase usage variants into a single value. +func firstNonZero(a, b int64) int64 { + if a != 0 { + return a + } + return b +} + +// isAWSEventStream reports whether contentType is the AWS binary event-stream +// framing used by Bedrock's streaming endpoints. +func isAWSEventStream(contentType string) bool { + return strings.Contains(strings.ToLower(contentType), "application/vnd.amazon.eventstream") +} diff --git a/proxy/internal/llm/bedrock_test.go b/proxy/internal/llm/bedrock_test.go new file mode 100644 index 000000000..ca6f092f3 --- /dev/null +++ b/proxy/internal/llm/bedrock_test.go @@ -0,0 +1,65 @@ +package llm + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBedrockParser_ParseResponse_Invoke(t *testing.T) { + body := []byte(`{"usage":{"input_tokens":13,"output_tokens":5,"cache_read_input_tokens":2,"cache_creation_input_tokens":4}}`) + u, err := BedrockParser{}.ParseResponse(200, "application/json", body) + require.NoError(t, err) + require.Equal(t, int64(13), u.InputTokens, "invoke input tokens") + require.Equal(t, int64(5), u.OutputTokens, "invoke output tokens") + require.Equal(t, int64(2), u.CachedInputTokens, "invoke cache-read tokens") + require.Equal(t, int64(4), u.CacheCreationTokens, "invoke cache-creation tokens") + require.Equal(t, int64(13+5+2+4), u.TotalTokens, "invoke total is additive") +} + +func TestBedrockParser_ParseResponse_Converse(t *testing.T) { + body := []byte(`{"output":{"message":{"content":[{"text":"pong"}]}},"usage":{"inputTokens":11,"outputTokens":3,"totalTokens":14}}`) + u, err := BedrockParser{}.ParseResponse(200, "application/json", body) + require.NoError(t, err) + require.Equal(t, int64(11), u.InputTokens, "converse camelCase input tokens") + require.Equal(t, int64(3), u.OutputTokens, "converse camelCase output tokens") + require.Equal(t, int64(14), u.TotalTokens, "converse uses provider total") +} + +func TestBedrockParser_ParseResponse_StreamingUnsupported(t *testing.T) { + _, err := BedrockParser{}.ParseResponse(200, "application/vnd.amazon.eventstream", []byte("binary")) + require.ErrorIs(t, err, ErrStreamingUnsupported, "event-stream must route to the streaming accumulator") +} + +func TestBedrockParser_ParseResponse_NonSuccess(t *testing.T) { + _, err := BedrockParser{}.ParseResponse(404, "application/json", []byte(`{"message":"gated"}`)) + require.ErrorIs(t, err, ErrNotLLMResponse, "non-200 is not an LLM response") +} + +func TestBedrockParser_ExtractCompletion(t *testing.T) { + invoke := BedrockParser{}.ExtractCompletion(200, "application/json", []byte(`{"content":[{"text":"a"},{"text":"b"}]}`)) + require.Equal(t, "a\nb", invoke, "invoke completion joins content parts") + + converse := BedrockParser{}.ExtractCompletion(200, "application/json", []byte(`{"output":{"message":{"content":[{"text":"x"}]}}}`)) + require.Equal(t, "x", converse, "converse completion reads output.message.content") +} + +func TestBedrockParser_ExtractPrompt(t *testing.T) { + invoke := BedrockParser{}.ExtractPrompt([]byte(`{"messages":[{"role":"user","content":"hi"}]}`)) + require.Equal(t, "user: hi", invoke, "invoke prompt reads anthropic content string") + + converse := BedrockParser{}.ExtractPrompt([]byte(`{"messages":[{"role":"user","content":[{"text":"hello"}]}]}`)) + require.Equal(t, "user: hello", converse, "converse prompt reads content parts") +} + +func TestBedrockParser_DetectFromURL(t *testing.T) { + require.True(t, BedrockParser{}.DetectFromURL("/model/eu.anthropic.claude/invoke"), "invoke path") + require.True(t, BedrockParser{}.DetectFromURL("/model/x/converse-stream"), "converse-stream path") + require.False(t, BedrockParser{}.DetectFromURL("/v1/chat/completions"), "openai path is not bedrock") +} + +func TestBedrockParser_RegisteredByName(t *testing.T) { + p, ok := ParserByName(ProviderNameBedrock) + require.True(t, ok, "bedrock parser is registered") + require.Equal(t, ProviderNameBedrock, p.ProviderName()) +} diff --git a/proxy/internal/llm/errors.go b/proxy/internal/llm/errors.go new file mode 100644 index 000000000..09019fbbe --- /dev/null +++ b/proxy/internal/llm/errors.go @@ -0,0 +1,31 @@ +package llm + +import "errors" + +// Sentinel errors returned by parsers and the pricing loader. Callers use +// errors.Is to branch on a condition without coupling to parser internals. +var ( + // ErrUnknownProvider indicates no parser claimed the request path. + ErrUnknownProvider = errors.New("llmobs: unknown provider") + + // ErrUnsupportedModel indicates the response parsed successfully but the + // model is absent from the pricing table. Token counts are still valid. + ErrUnsupportedModel = errors.New("llmobs: unsupported model") + + // ErrNotLLMResponse indicates the response is not a JSON success body + // that a non-streaming parser can consume (non-200 or wrong content type). + ErrNotLLMResponse = errors.New("llmobs: not an LLM response") + + // ErrStreamingUnsupported indicates the caller passed an SSE response to + // a non-streaming parser. Streaming is handled separately via the SSE + // scanner. + ErrStreamingUnsupported = errors.New("llmobs: streaming response requires SSE scanner") + + // ErrMalformedResponse indicates the response body could not be decoded + // as the provider-specific JSON schema. + ErrMalformedResponse = errors.New("llmobs: malformed response body") + + // ErrMalformedRequest indicates the request body could not be decoded as + // the provider-specific JSON schema. + ErrMalformedRequest = errors.New("llmobs: malformed request body") +) diff --git a/proxy/internal/llm/fixtures/anthropic_messages.json b/proxy/internal/llm/fixtures/anthropic_messages.json new file mode 100644 index 000000000..2c9bb663a --- /dev/null +++ b/proxy/internal/llm/fixtures/anthropic_messages.json @@ -0,0 +1,17 @@ +{ + "id": "msg_abc", + "type": "message", + "role": "assistant", + "model": "claude-sonnet-4-5", + "content": [ + { + "type": "text", + "text": "Hello, world!" + } + ], + "stop_reason": "end_turn", + "usage": { + "input_tokens": 123, + "output_tokens": 45 + } +} diff --git a/proxy/internal/llm/fixtures/anthropic_stream.txt b/proxy/internal/llm/fixtures/anthropic_stream.txt new file mode 100644 index 000000000..2b8bb889c --- /dev/null +++ b/proxy/internal/llm/fixtures/anthropic_stream.txt @@ -0,0 +1,21 @@ +event: message_start +data: {"type":"message_start","message":{"id":"msg_abc","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[],"stop_reason":null,"usage":{"input_tokens":123,"output_tokens":1}}} + +event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":", world!"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":0} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":45}} + +event: message_stop +data: {"type":"message_stop"} + diff --git a/proxy/internal/llm/fixtures/openai_chat_completion.json b/proxy/internal/llm/fixtures/openai_chat_completion.json new file mode 100644 index 000000000..d0e25337b --- /dev/null +++ b/proxy/internal/llm/fixtures/openai_chat_completion.json @@ -0,0 +1,21 @@ +{ + "id": "chatcmpl-abc", + "object": "chat.completion", + "created": 1700000000, + "model": "gpt-4o-mini", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello, world!" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 123, + "completion_tokens": 45, + "total_tokens": 168 + } +} diff --git a/proxy/internal/llm/fixtures/openai_responses.json b/proxy/internal/llm/fixtures/openai_responses.json new file mode 100644 index 000000000..f998fcd33 --- /dev/null +++ b/proxy/internal/llm/fixtures/openai_responses.json @@ -0,0 +1,24 @@ +{ + "id": "resp_abc", + "object": "response", + "created_at": 1700000000, + "model": "gpt-5.4", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "ok"}] + } + ], + "usage": { + "input_tokens": 15, + "input_tokens_details": { + "cached_tokens": 0 + }, + "output_tokens": 414, + "output_tokens_details": { + "reasoning_tokens": 0 + }, + "total_tokens": 429 + } +} diff --git a/proxy/internal/llm/fixtures/openai_responses_stream.txt b/proxy/internal/llm/fixtures/openai_responses_stream.txt new file mode 100644 index 000000000..2801fa99d --- /dev/null +++ b/proxy/internal/llm/fixtures/openai_responses_stream.txt @@ -0,0 +1,24 @@ +event: response.created +data: {"type":"response.created","response":{"id":"resp_abc","object":"response","model":"gpt-5.5","usage":null}} + +event: response.in_progress +data: {"type":"response.in_progress","response":{"id":"resp_abc","usage":null}} + +event: response.output_item.added +data: {"type":"response.output_item.added","output_index":0,"item":{"type":"message","role":"assistant","content":[]}} + +event: response.content_part.added +data: {"type":"response.content_part.added","item_id":"msg_1","output_index":0,"content_index":0,"part":{"type":"output_text","text":""}} + +event: response.output_text.delta +data: {"type":"response.output_text.delta","item_id":"msg_1","output_index":0,"content_index":0,"delta":"Hello"} + +event: response.output_text.delta +data: {"type":"response.output_text.delta","item_id":"msg_1","output_index":0,"content_index":0,"delta":", world!"} + +event: response.output_text.done +data: {"type":"response.output_text.done","item_id":"msg_1","output_index":0,"content_index":0,"text":"Hello, world!"} + +event: response.completed +data: {"type":"response.completed","response":{"id":"resp_abc","object":"response","model":"gpt-5.5","usage":{"input_tokens":123,"input_tokens_details":{"cached_tokens":40},"output_tokens":45,"output_tokens_details":{"reasoning_tokens":12},"total_tokens":168}}} + diff --git a/proxy/internal/llm/fixtures/openai_stream.txt b/proxy/internal/llm/fixtures/openai_stream.txt new file mode 100644 index 000000000..058b7ce22 --- /dev/null +++ b/proxy/internal/llm/fixtures/openai_stream.txt @@ -0,0 +1,8 @@ +data: {"id":"chatcmpl-abc","object":"chat.completion.chunk","created":1700000000,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} + +data: {"id":"chatcmpl-abc","object":"chat.completion.chunk","created":1700000000,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"content":", world!"},"finish_reason":null}]} + +data: {"id":"chatcmpl-abc","object":"chat.completion.chunk","created":1700000000,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":123,"completion_tokens":45,"total_tokens":168}} + +data: [DONE] + diff --git a/proxy/internal/llm/fixtures/pricing.yaml b/proxy/internal/llm/fixtures/pricing.yaml new file mode 100644 index 000000000..3d26ff803 --- /dev/null +++ b/proxy/internal/llm/fixtures/pricing.yaml @@ -0,0 +1,59 @@ +# Realistic-pricing starter for llm_observability. Drop this into the +# directory you point the proxy at via --plugin-data-dir, then reference it +# from the target's plugin config: +# +# plugins: +# - id: llm_observability +# enabled: true +# params: +# pricing_path: pricing.yaml +# +# Values are USD per 1_000 tokens. Public list prices drift; treat this as a +# starting point and keep your production copy current. + +openai: + # GPT-5 family + gpt-5: + input_per_1k: 0.00125 + output_per_1k: 0.01 + gpt-5-mini: + input_per_1k: 0.00025 + output_per_1k: 0.002 + gpt-5-nano: + input_per_1k: 0.00005 + output_per_1k: 0.0004 + gpt-5.4: + input_per_1k: 0.00125 + output_per_1k: 0.01 + # GPT-4o family + gpt-4o: + input_per_1k: 0.0025 + output_per_1k: 0.01 + gpt-4o-mini: + input_per_1k: 0.00015 + output_per_1k: 0.0006 + # Embeddings + text-embedding-3-large: + input_per_1k: 0.00013 + output_per_1k: 0 + text-embedding-3-small: + input_per_1k: 0.00002 + output_per_1k: 0 + +anthropic: + # Claude 4.x family + claude-opus-4-7: + input_per_1k: 0.015 + output_per_1k: 0.075 + claude-sonnet-4-7: + input_per_1k: 0.003 + output_per_1k: 0.015 + claude-sonnet-4-6: + input_per_1k: 0.003 + output_per_1k: 0.015 + claude-sonnet-4-5: + input_per_1k: 0.003 + output_per_1k: 0.015 + claude-haiku-4-5: + input_per_1k: 0.0008 + output_per_1k: 0.004 diff --git a/proxy/internal/llm/openai.go b/proxy/internal/llm/openai.go new file mode 100644 index 000000000..86ee30797 --- /dev/null +++ b/proxy/internal/llm/openai.go @@ -0,0 +1,412 @@ +package llm + +import ( + "encoding/json" + "fmt" + "strings" +) + +// OpenAIParser implements the Parser interface for OpenAI-compatible APIs. +// It recognizes chat.completions, completions, embeddings, and the newer +// responses endpoint; any proxy path-prefix stripping is tolerated by the +// substring match in DetectFromURL. +type OpenAIParser struct{} + +// openAIPathHints are substring patterns that mark a request as +// OpenAI-shaped. The bare `/chat/completions` is listed alongside +// `/v1/chat/completions` because gateways like Cloudflare AI +// Gateway place their own version segment before the provider +// slug (gateway/v1/{account}/{gateway}/openai/chat/completions) — +// the canonical `/v1/` ends up nowhere near `/chat/completions`, +// so the `/v1/chat/completions` hint misses. `/chat/completions` +// is OpenAI's API contract: any service accepting OpenAI bodies +// serves at this path, so false-positive risk is negligible. +// `/completions` (legacy), `/embeddings`, and `/responses` are +// kept on the canonical-only path because their bare forms are +// too generic to be safe substrings. +var openAIPathHints = []string{ + "/v1/chat/completions", + "/v1/completions", + "/v1/embeddings", + "/v1/responses", + "/chat/completions", +} + +// Provider returns ProviderOpenAI. +func (OpenAIParser) Provider() Provider { return ProviderOpenAI } + +// ProviderName returns the stable label used for metrics and metadata. +func (OpenAIParser) ProviderName() string { return "openai" } + +// DetectFromURL reports whether the given request path looks like an OpenAI +// API endpoint. The match is case-insensitive and substring-based so that a +// reverse proxy prefix strip or rewrite does not defeat detection. +func (OpenAIParser) DetectFromURL(path string) bool { + lower := strings.ToLower(path) + for _, hint := range openAIPathHints { + if strings.Contains(lower, hint) { + return true + } + } + return false +} + +type openAIRequest struct { + Model string `json:"model"` + Stream *bool `json:"stream"` + StreamOptions *struct { + IncludeUsage *bool `json:"include_usage"` + } `json:"stream_options"` + // Chat Completions / Completions: messages[].content (string or array of + // content parts). Responses API: input is either a string or an array of + // items with content parts. We use json.RawMessage to defer parsing each + // shape independently. + Messages []openAIMessage `json:"messages"` + Prompt json.RawMessage `json:"prompt"` + Input json.RawMessage `json:"input"` +} + +type openAIMessage struct { + Role string `json:"role"` + Content json.RawMessage `json:"content"` +} + +// ParseRequest extracts the model name and streaming flag from an OpenAI +// request body. Unknown or missing fields leave the corresponding struct +// members zero-valued. +func (OpenAIParser) ParseRequest(body []byte) (RequestFacts, error) { + var req openAIRequest + if err := json.Unmarshal(body, &req); err != nil { + return RequestFacts{}, fmt.Errorf("decode openai request: %w: %v", ErrMalformedRequest, err) + } + return RequestFacts{ + Model: req.Model, + Stream: ptrDeref(req.Stream), + }, nil +} + +// openAIResponse accepts both naming conventions in a single struct because +// OpenAI's older Chat Completions API uses prompt_tokens/completion_tokens +// while the newer Responses API (/v1/responses) uses input_tokens/output_tokens +// (aligned with Anthropic). Pointer fields let us tell "absent" from "zero". +// +// PromptTokensDetails.CachedTokens (Chat Completions) and +// InputTokensDetails.CachedTokens (Responses API) carry the SUBSET of +// prompt/input tokens that hit the prompt cache. Cost-meter applies the +// discount rate to that subset and the regular rate to the remainder so +// we never double-bill the cached portion. +type openAIResponse struct { + Usage struct { + PromptTokens *int64 `json:"prompt_tokens"` + CompletionTokens *int64 `json:"completion_tokens"` + InputTokens *int64 `json:"input_tokens"` + OutputTokens *int64 `json:"output_tokens"` + TotalTokens *int64 `json:"total_tokens"` + PromptTokensDetails *struct { + CachedTokens *int64 `json:"cached_tokens"` + } `json:"prompt_tokens_details"` + InputTokensDetails *struct { + CachedTokens *int64 `json:"cached_tokens"` + } `json:"input_tokens_details"` + } `json:"usage"` +} + +// ParseResponse decodes the non-streaming OpenAI response envelope. Status +// codes other than 200 are treated as non-LLM responses so the caller can +// skip cost accounting without aborting the request. +func (OpenAIParser) ParseResponse(status int, contentType string, body []byte) (Usage, error) { + if status != 200 { + return Usage{}, fmt.Errorf("openai status %d: %w", status, ErrNotLLMResponse) + } + if isEventStream(contentType) { + return Usage{}, ErrStreamingUnsupported + } + if !isJSON(contentType) { + return Usage{}, fmt.Errorf("openai content-type %q: %w", contentType, ErrNotLLMResponse) + } + + var resp openAIResponse + if err := json.Unmarshal(body, &resp); err != nil { + return Usage{}, fmt.Errorf("decode openai response: %w: %v", ErrMalformedResponse, err) + } + + // Responses-API names take precedence when present; fall back to the older + // Chat Completions names. This handles both endpoints transparently + // without forcing a per-route configuration. + u := Usage{ + InputTokens: pickInt64(resp.Usage.InputTokens, resp.Usage.PromptTokens), + OutputTokens: pickInt64(resp.Usage.OutputTokens, resp.Usage.CompletionTokens), + TotalTokens: derefInt64(resp.Usage.TotalTokens), + CachedInputTokens: openAICachedTokens(resp), + } + if u.TotalTokens == 0 && (u.InputTokens > 0 || u.OutputTokens > 0) { + u.TotalTokens = u.InputTokens + u.OutputTokens + } + return u, nil +} + +// openAICachedTokens returns the cached-prompt subset reported by +// either the Responses-API (input_tokens_details.cached_tokens) or +// the Chat-Completions API (prompt_tokens_details.cached_tokens). +// Responses-API takes precedence when both are populated. +func openAICachedTokens(resp openAIResponse) int64 { + // Responses-API details are authoritative when present: an explicit + // cached_tokens of 0 must be honored, not treated as missing and + // overridden by the Chat-Completions field (which would overstate cache). + if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens != nil { + return derefInt64(resp.Usage.InputTokensDetails.CachedTokens) + } + if resp.Usage.PromptTokensDetails != nil { + return derefInt64(resp.Usage.PromptTokensDetails.CachedTokens) + } + return 0 +} + +// ExtractPrompt returns the user-visible prompt text from an OpenAI request. +// Handles chat.completions (messages[].content), legacy completions (prompt +// string), and the Responses API (input as string or content-part array). +// Returns "" when nothing extractable is found. +func (OpenAIParser) ExtractPrompt(body []byte) string { + var req openAIRequest + if err := json.Unmarshal(body, &req); err != nil { + return "" + } + if len(req.Messages) > 0 { + return joinMessages(req.Messages) + } + if len(req.Input) > 0 { + return extractResponsesInput(req.Input) + } + if len(req.Prompt) > 0 { + return decodeStringOrJoin(req.Prompt) + } + return "" +} + +// extractResponsesInput handles the Responses API `input` field. It is one +// of three shapes: a plain string, an array of message items +// ({role, content: string | [parts]}) as sent by Codex and the Responses +// SDK, or a flat array of content parts ({type, text/input_text}). Message +// items are flattened to "role: text" lines; items without extractable text +// (reasoning blocks, tool calls) are skipped. +func extractResponsesInput(raw json.RawMessage) string { + if s, ok := tryDecodeString(raw); ok { + return s + } + var items []struct { + Role string `json:"role"` + Content json.RawMessage `json:"content"` + Text string `json:"text"` + InputText string `json:"input_text"` + } + if err := json.Unmarshal(raw, &items); err != nil { + return extractContentParts(raw) + } + var b strings.Builder + for _, it := range items { + var text string + switch { + case len(it.Content) > 0: + text = decodeStringOrJoin(it.Content) + case it.Text != "": + text = it.Text + case it.InputText != "": + text = it.InputText + } + if text == "" { + continue + } + if b.Len() > 0 { + b.WriteByte('\n') + } + if it.Role != "" { + b.WriteString(it.Role) + b.WriteString(": ") + } + b.WriteString(text) + } + return b.String() +} + +// ExtractSessionID reads the OpenAI session marker. Codex (the Responses +// API client) stamps client_metadata.session_id on every request body; +// plain chat.completions traffic carries no session id and yields "". +func (OpenAIParser) ExtractSessionID(body []byte) string { + var req struct { + ClientMetadata struct { + SessionID string `json:"session_id"` + } `json:"client_metadata"` + } + if err := json.Unmarshal(body, &req); err != nil { + return "" + } + return req.ClientMetadata.SessionID +} + +type openAIChatChoice struct { + Message struct { + Role string `json:"role"` + Content json.RawMessage `json:"content"` + } `json:"message"` + Text string `json:"text"` +} + +type openAIChatResponse struct { + Choices []openAIChatChoice `json:"choices"` + // Responses API: output[].content[].text + Output []struct { + Type string `json:"type"` + Content json.RawMessage `json:"content"` + Text string `json:"text"` + } `json:"output"` + OutputText string `json:"output_text"` +} + +// ExtractCompletion returns the assistant text from a non-streaming OpenAI +// response. Handles chat.completions (choices[].message.content), legacy +// completions (choices[].text), and Responses API (output[].content[].text +// or the convenience output_text field). +func (OpenAIParser) ExtractCompletion(status int, contentType string, body []byte) string { + if status != 200 || isEventStream(contentType) || !isJSON(contentType) { + return "" + } + var resp openAIChatResponse + if err := json.Unmarshal(body, &resp); err != nil { + return "" + } + if resp.OutputText != "" { + return resp.OutputText + } + for _, c := range resp.Choices { + if len(c.Message.Content) > 0 { + if s := decodeStringOrJoin(c.Message.Content); s != "" { + return s + } + } + if c.Text != "" { + return c.Text + } + } + for _, o := range resp.Output { + if o.Text != "" { + return o.Text + } + if len(o.Content) > 0 { + if s := extractContentParts(o.Content); s != "" { + return s + } + } + } + return "" +} + +// joinMessages flattens a chat.completions messages array into a single +// "role: content" string per message, separated by newlines. Roles surface +// system/user/assistant context which is useful for log review. +func joinMessages(msgs []openAIMessage) string { + var b strings.Builder + for i, m := range msgs { + if i > 0 { + b.WriteByte('\n') + } + if m.Role != "" { + b.WriteString(m.Role) + b.WriteString(": ") + } + b.WriteString(decodeStringOrJoin(m.Content)) + } + return b.String() +} + +// extractContentParts handles the Responses-API content shape, which is +// either a single string or an array of {type, text} parts. text and +// input_text both carry user-facing content. +func extractContentParts(raw json.RawMessage) string { + if s, ok := tryDecodeString(raw); ok { + return s + } + var parts []struct { + Type string `json:"type"` + Text string `json:"text"` + InputText string `json:"input_text"` + } + if err := json.Unmarshal(raw, &parts); err != nil { + // Last-ditch: array of strings. + var arr []string + if json.Unmarshal(raw, &arr) == nil { + return strings.Join(arr, "\n") + } + return "" + } + var b strings.Builder + for _, p := range parts { + var text string + switch { + case p.Text != "": + text = p.Text + case p.InputText != "": + text = p.InputText + } + if text == "" { + continue + } + if b.Len() > 0 { + b.WriteByte('\n') + } + b.WriteString(text) + } + return b.String() +} + +// decodeStringOrJoin accepts either a JSON string or a content-parts array +// (chat.completions multimodal) and returns a flat string. Multimodal parts +// are separated by newlines; non-text parts are skipped. +func decodeStringOrJoin(raw json.RawMessage) string { + if s, ok := tryDecodeString(raw); ok { + return s + } + return extractContentParts(raw) +} + +func tryDecodeString(raw json.RawMessage) (string, bool) { + if len(raw) == 0 { + return "", false + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s, true + } + return "", false +} + +// pickInt64 returns the first non-nil pointer's value. Used to prefer one +// naming convention while transparently falling back to another. +func pickInt64(preferred, fallback *int64) int64 { + if preferred != nil { + return *preferred + } + return derefInt64(fallback) +} + +func derefInt64(v *int64) int64 { + if v == nil { + return 0 + } + return *v +} + +func ptrDeref(b *bool) bool { + if b == nil { + return false + } + return *b +} + +func isEventStream(contentType string) bool { + return strings.Contains(strings.ToLower(contentType), "text/event-stream") +} + +func isJSON(contentType string) bool { + lower := strings.ToLower(contentType) + return strings.Contains(lower, "application/json") || strings.Contains(lower, "+json") +} diff --git a/proxy/internal/llm/openai_test.go b/proxy/internal/llm/openai_test.go new file mode 100644 index 000000000..7a5fca4fb --- /dev/null +++ b/proxy/internal/llm/openai_test.go @@ -0,0 +1,255 @@ +package llm + +import ( + "errors" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOpenAIDetectFromURL(t *testing.T) { + p := OpenAIParser{} + + cases := map[string]bool{ + "/v1/chat/completions": true, + "/v1/completions": true, + "/v1/embeddings": true, + "/v1/responses": true, + "/API/V1/Chat/Completions": true, + "/upstream/v1/chat/completions?trace=1": true, + // Cloudflare AI Gateway puts its own /v1/{account}/{gateway} + // segment between the canonical /v1/ and the provider slug, + // so the /v1/chat/completions substring no longer appears + // adjacent in the path. The bare /chat/completions hint + // catches Cloudflare's OpenAI direct path + // (/v1/{account}/{gateway}/openai/chat/completions) and + // compat path (/v1/{account}/{gateway}/compat/chat/completions). + "/v1/{account}/{gateway}/openai/chat/completions": true, + "/v1/{account}/{gateway}/compat/chat/completions": true, + "/chat/completions": true, + "/v1/messages": false, + "/healthz": false, + "": false, + } + for path, want := range cases { + assert.Equal(t, want, p.DetectFromURL(path), "DetectFromURL(%q)", path) + } +} + +func TestOpenAIParseRequest(t *testing.T) { + p := OpenAIParser{} + + t.Run("stream true", func(t *testing.T) { + facts, err := p.ParseRequest([]byte(`{"model":"gpt-4o","stream":true,"stream_options":{"include_usage":true}}`)) + require.NoError(t, err) + assert.Equal(t, "gpt-4o", facts.Model, "request model extracted") + assert.True(t, facts.Stream, "request marked as streaming") + }) + + t.Run("stream default", func(t *testing.T) { + facts, err := p.ParseRequest([]byte(`{"model":"gpt-4o-mini"}`)) + require.NoError(t, err) + assert.Equal(t, "gpt-4o-mini", facts.Model, "request model extracted") + assert.False(t, facts.Stream, "missing stream flag defaults to false") + }) + + t.Run("malformed", func(t *testing.T) { + _, err := p.ParseRequest([]byte(`{not json}`)) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrMalformedRequest), "sentinel error wrapped") + }) +} + +func TestOpenAIParseResponse(t *testing.T) { + p := OpenAIParser{} + + t.Run("happy fixture", func(t *testing.T) { + body, err := os.ReadFile(filepath.Join("fixtures", "openai_chat_completion.json")) + require.NoError(t, err, "fixture must be readable") + + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Equal(t, int64(123), usage.InputTokens, "prompt tokens become input") + assert.Equal(t, int64(45), usage.OutputTokens, "completion tokens become output") + assert.Equal(t, int64(168), usage.TotalTokens, "total_tokens carried through") + }) + + t.Run("total computed when missing", func(t *testing.T) { + body := []byte(`{"usage":{"prompt_tokens":10,"completion_tokens":5}}`) + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Equal(t, int64(15), usage.TotalTokens, "total computed from in+out") + }) + + t.Run("streaming rejected", func(t *testing.T) { + _, err := p.ParseResponse(200, "text/event-stream", []byte("")) + require.ErrorIs(t, err, ErrStreamingUnsupported, "SSE responses must use the scanner") + }) + + t.Run("non-200", func(t *testing.T) { + _, err := p.ParseResponse(500, "application/json", []byte(`{"error":"x"}`)) + require.ErrorIs(t, err, ErrNotLLMResponse, "non-200 rejected as non-LLM") + }) + + t.Run("non-json content type", func(t *testing.T) { + _, err := p.ParseResponse(200, "text/plain", []byte(`{}`)) + require.ErrorIs(t, err, ErrNotLLMResponse, "text/plain treated as non-LLM") + }) + + t.Run("malformed body", func(t *testing.T) { + _, err := p.ParseResponse(200, "application/json", []byte(`{not json`)) + require.ErrorIs(t, err, ErrMalformedResponse, "bad JSON yields malformed error") + }) + + // Responses-API fixture: /v1/responses returns input_tokens/output_tokens + // (Anthropic-style) instead of prompt_tokens/completion_tokens. The parser + // must accept both. + t.Run("responses api fixture", func(t *testing.T) { + body, err := os.ReadFile(filepath.Join("fixtures", "openai_responses.json")) + require.NoError(t, err, "fixture must be readable") + + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Equal(t, int64(15), usage.InputTokens, "input_tokens should map directly") + assert.Equal(t, int64(414), usage.OutputTokens, "output_tokens should map directly") + assert.Equal(t, int64(429), usage.TotalTokens, "total_tokens carried through") + }) + + t.Run("responses api naming preferred over chat-completions when both present", func(t *testing.T) { + body := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"input_tokens":15,"output_tokens":414,"total_tokens":429}}`) + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Equal(t, int64(15), usage.InputTokens, "responses-api names take precedence") + assert.Equal(t, int64(414), usage.OutputTokens, "responses-api names take precedence") + }) + + t.Run("chat-completions naming still works alone", func(t *testing.T) { + body := []byte(`{"usage":{"prompt_tokens":15,"completion_tokens":414,"total_tokens":429}}`) + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Equal(t, int64(15), usage.InputTokens, "prompt_tokens fallback") + assert.Equal(t, int64(414), usage.OutputTokens, "completion_tokens fallback") + }) + + // Cached-prompt accounting. cached_tokens is a SUBSET of + // prompt_tokens — input_tokens carries the full prompt count and + // the cached subset is reported separately so the cost meter can + // apply the discount rate to that portion. + t.Run("chat-completions cached_tokens subset surfaces", func(t *testing.T) { + body := []byte(`{"usage":{"prompt_tokens":1024,"completion_tokens":200,"total_tokens":1224,"prompt_tokens_details":{"cached_tokens":768}}}`) + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Equal(t, int64(1024), usage.InputTokens, "input remains the full prompt count — cached is a subset, not a separate bucket") + assert.Equal(t, int64(768), usage.CachedInputTokens, "cached_tokens must surface so cost meter can discount the cached subset") + assert.Zero(t, usage.CacheCreationTokens, "OpenAI has no cache_creation analogue") + }) + + t.Run("responses-api input_tokens_details.cached_tokens surfaces", func(t *testing.T) { + body := []byte(`{"usage":{"input_tokens":2048,"output_tokens":100,"total_tokens":2148,"input_tokens_details":{"cached_tokens":1500}}}`) + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Equal(t, int64(2048), usage.InputTokens) + assert.Equal(t, int64(1500), usage.CachedInputTokens, "Responses-API input_tokens_details.cached_tokens path must surface too") + }) + + t.Run("responses-api cached takes precedence over chat-completions when both present", func(t *testing.T) { + body := []byte(`{"usage":{"prompt_tokens":1,"input_tokens":2,"output_tokens":3,"prompt_tokens_details":{"cached_tokens":50},"input_tokens_details":{"cached_tokens":99}}}`) + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Equal(t, int64(99), usage.CachedInputTokens, "Responses-API field wins when both naming conventions are present") + }) + + t.Run("absent cached_tokens leaves cached counts at zero", func(t *testing.T) { + body := []byte(`{"usage":{"prompt_tokens":15,"completion_tokens":414,"total_tokens":429}}`) + usage, err := p.ParseResponse(200, "application/json", body) + require.NoError(t, err) + assert.Zero(t, usage.CachedInputTokens, "no prompt_tokens_details = no cached subset") + }) +} + +func TestOpenAIExtractPrompt_ChatCompletions(t *testing.T) { + body := []byte(`{"model":"gpt-4o-mini","messages":[{"role":"system","content":"be brief"},{"role":"user","content":"ping"}]}`) + got := OpenAIParser{}.ExtractPrompt(body) + require.NotEmpty(t, got, "messages array must extract") + require.Contains(t, got, "system: be brief", "system role and content surface") + require.Contains(t, got, "user: ping", "user role and content surface") +} + +func TestOpenAIExtractPrompt_ResponsesAPIStringInput(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","input":"Hello there"}`) + got := OpenAIParser{}.ExtractPrompt(body) + require.Equal(t, "Hello there", got, "string input field should pass through") +} + +func TestOpenAIExtractPrompt_ResponsesAPIInputParts(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","input":[{"type":"input_text","input_text":"first"},{"type":"input_text","input_text":"second"}]}`) + got := OpenAIParser{}.ExtractPrompt(body) + require.Contains(t, got, "first", "first content part surfaces") + require.Contains(t, got, "second", "second content part surfaces") +} + +// TestOpenAIExtractPrompt_ResponsesAPIMessageItems guards the live Codex +// shape: input is an array of message items whose text is nested under +// content[].text, not flat content parts. The old code fed the outer array +// to the content-part decoder and extracted nothing, so the stored prompt +// was empty. +func TestOpenAIExtractPrompt_ResponsesAPIMessageItems(t *testing.T) { + body := []byte(`{"model":"gpt-5.5","input":[` + + `{"type":"message","role":"developer","content":[{"type":"input_text","text":"system rules"}]},` + + `{"type":"message","role":"user","content":[{"type":"input_text","text":"hello there"}]},` + + `{"type":"reasoning","encrypted_content":"opaque","summary":[]},` + + `{"type":"message","role":"assistant","content":[{"type":"output_text","text":"prior reply"}]}` + + `]}`) + got := OpenAIParser{}.ExtractPrompt(body) + require.Contains(t, got, "system rules", "developer message content must surface") + require.Contains(t, got, "hello there", "user message content must surface") + require.Contains(t, got, "developer:", "role labels must prefix each message") + require.NotContains(t, got, "opaque", "reasoning items without text must be skipped") +} + +func TestOpenAIExtractPrompt_LegacyCompletion(t *testing.T) { + body := []byte(`{"model":"text-davinci-003","prompt":"once upon a time"}`) + got := OpenAIParser{}.ExtractPrompt(body) + require.Equal(t, "once upon a time", got, "string prompt field should pass through") +} + +func TestOpenAIExtractSessionID(t *testing.T) { + t.Run("codex client_metadata.session_id", func(t *testing.T) { + body := []byte(`{"model":"gpt-5.5","client_metadata":{"session_id":"019eeb72-ab7c-7cd2","thread_id":"t1"},"input":[]}`) + assert.Equal(t, "019eeb72-ab7c-7cd2", OpenAIParser{}.ExtractSessionID(body), "Codex session id must come from client_metadata.session_id") + }) + t.Run("plain chat has no session", func(t *testing.T) { + body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`) + assert.Equal(t, "", OpenAIParser{}.ExtractSessionID(body), "plain chat.completions carries no session id") + }) + t.Run("non-JSON yields empty", func(t *testing.T) { + assert.Equal(t, "", OpenAIParser{}.ExtractSessionID([]byte("not json")), "malformed body must not error") + }) +} + +func TestOpenAIExtractCompletion_ChatCompletions(t *testing.T) { + body, err := os.ReadFile(filepath.Join("fixtures", "openai_chat_completion.json")) + require.NoError(t, err) + got := OpenAIParser{}.ExtractCompletion(200, "application/json", body) + require.NotEmpty(t, got, "fixture has assistant content") +} + +func TestOpenAIExtractCompletion_ResponsesAPI(t *testing.T) { + body, err := os.ReadFile(filepath.Join("fixtures", "openai_responses.json")) + require.NoError(t, err) + got := OpenAIParser{}.ExtractCompletion(200, "application/json", body) + require.NotEmpty(t, got, "responses-api fixture has output content") +} + +func TestOpenAIExtractCompletion_Streaming(t *testing.T) { + got := OpenAIParser{}.ExtractCompletion(200, "text/event-stream", []byte("")) + require.Empty(t, got, "streaming responses are skipped") +} + +func TestOpenAIExtractCompletion_NonOK(t *testing.T) { + got := OpenAIParser{}.ExtractCompletion(500, "application/json", []byte(`{"choices":[{"message":{"content":"x"}}]}`)) + require.Empty(t, got, "non-200 returns empty") +} diff --git a/proxy/internal/llm/parser.go b/proxy/internal/llm/parser.go new file mode 100644 index 000000000..81fa11f97 --- /dev/null +++ b/proxy/internal/llm/parser.go @@ -0,0 +1,112 @@ +// Package llm provides the shared LLM request and response parsing +// library consumed by proxy middleware. It is runtime agnostic: the same +// package is used by the native built-in executor now and will be reused +// by the WASM adapter later. +package llm + +// Provider identifies an LLM API provider. +type Provider int + +const ( + // ProviderUnknown signals that no parser matched the request. + ProviderUnknown Provider = 0 + // ProviderOpenAI identifies the OpenAI API surface. + ProviderOpenAI Provider = 1 + // ProviderAnthropic identifies the Anthropic Messages API surface. + ProviderAnthropic Provider = 2 + // ProviderBedrock identifies the AWS Bedrock runtime surface. + ProviderBedrock Provider = 3 +) + +// RequestFacts captures the subset of the LLM request body that the +// middleware annotates as metadata (model, streaming flag). Additional +// fields are added as parsers grow. +type RequestFacts struct { + Model string + Stream bool +} + +// Usage is the provider-agnostic token accounting emitted to metrics and +// access logs. Downstream consumers map InputTokens/OutputTokens to the +// plg.llm.* metadata allowlist entries. +// +// CachedInputTokens carries OpenAI's prompt_tokens_details.cached_tokens +// (a SUBSET of InputTokens) when the response is from OpenAI, or +// Anthropic's cache_read_input_tokens (ADDITIVE to InputTokens) when from +// Anthropic. The cost meter switches formula on KeyLLMProvider so the +// two shapes are billed correctly without double-counting. +// +// CacheCreationTokens carries Anthropic's cache_creation_input_tokens +// (ADDITIVE; not present in the OpenAI shape). +type Usage struct { + InputTokens int64 + OutputTokens int64 + TotalTokens int64 + CachedInputTokens int64 + CacheCreationTokens int64 +} + +// Parser is the per-provider interface implemented in this package. The +// dispatcher selects a parser by calling DetectFromURL against the incoming +// request path; ties break by registration order (see Parsers). +type Parser interface { + Provider() Provider + ProviderName() string + DetectFromURL(path string) bool + ParseRequest(body []byte) (RequestFacts, error) + ParseResponse(status int, contentType string, body []byte) (Usage, error) + // ExtractPrompt returns the user-facing prompt text from a request body. + // Different endpoint shapes (chat.completions, responses, messages) are + // handled by the per-provider implementation. Returns "" when no prompt + // can be extracted; never returns an error — extraction is best-effort + // because callers use the result for observability, not authorization. + ExtractPrompt(body []byte) string + // ExtractCompletion returns the assistant-facing completion text from a + // non-streaming response body. status and contentType match the + // ParseResponse arguments so implementations can fast-fail uniformly. + ExtractCompletion(status int, contentType string, body []byte) string + // ExtractSessionID returns a stable identifier that groups requests of + // the same conversation / coding session, read from the per-provider + // location clients populate (e.g. OpenAI Codex's client_metadata.session_id, + // Claude Code's metadata.user_id). Returns "" when the body carries no + // recognised session marker; extraction is best-effort and never errors. + ExtractSessionID(body []byte) string +} + +// Parsers returns the built-in parser set in a stable order. The order is +// deterministic so that DetectFromURL ties produce consistent routing. +func Parsers() []Parser { + return []Parser{ + OpenAIParser{}, + AnthropicParser{}, + BedrockParser{}, + } +} + +// DetectParser returns the first parser whose DetectFromURL matches the given +// request path. ok=false means no parser claimed the path. +func DetectParser(path string) (Parser, bool) { + for _, p := range Parsers() { + if p.DetectFromURL(path) { + return p, true + } + } + return nil, false +} + +// ParserByName returns the parser whose ProviderName matches id. Used by +// callers that already know which provider surface a request will hit +// (e.g. the agent-network middleware chain configured per synthesised +// service) so they can skip URL sniffing. ok=false when no parser is +// registered under that name. +func ParserByName(id string) (Parser, bool) { + if id == "" { + return nil, false + } + for _, p := range Parsers() { + if p.ProviderName() == id { + return p, true + } + } + return nil, false +} diff --git a/proxy/internal/llm/parser_test.go b/proxy/internal/llm/parser_test.go new file mode 100644 index 000000000..b3052ce68 --- /dev/null +++ b/proxy/internal/llm/parser_test.go @@ -0,0 +1,54 @@ +package llm + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParsers_ProviderNames(t *testing.T) { + parsers := Parsers() + require.Len(t, parsers, 3, "three built-in parsers expected") + + names := make([]string, 0, len(parsers)) + for _, p := range parsers { + names = append(names, p.ProviderName()) + } + assert.Contains(t, names, "openai", "OpenAI parser should be registered") + assert.Contains(t, names, "anthropic", "Anthropic parser should be registered") + assert.Contains(t, names, "bedrock", "Bedrock parser should be registered") +} + +func TestDetectParser(t *testing.T) { + cases := []struct { + name string + path string + expectedName string + expectOK bool + }{ + {"openai chat", "/v1/chat/completions", "openai", true}, + {"openai prefixed", "/api/v1/chat/completions", "openai", true}, + {"openai responses", "/v1/responses", "openai", true}, + {"anthropic messages", "/v1/messages", "anthropic", true}, + {"anthropic prefixed", "/proxy/v1/messages?query", "anthropic", true}, + {"unknown path", "/healthz", "", false}, + {"empty path", "", "", false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + p, ok := DetectParser(tc.path) + require.Equal(t, tc.expectOK, ok, "detection success mismatch for %q", tc.path) + if ok { + assert.Equal(t, tc.expectedName, p.ProviderName(), "provider name mismatch") + } + }) + } +} + +func TestProviderValues(t *testing.T) { + assert.Equal(t, Provider(0), ProviderUnknown, "unknown provider is the zero value") + assert.Equal(t, ProviderOpenAI, OpenAIParser{}.Provider(), "OpenAI parser reports its provider enum") + assert.Equal(t, ProviderAnthropic, AnthropicParser{}.Provider(), "Anthropic parser reports its provider enum") +} diff --git a/proxy/internal/llm/pricing/defaults_coverage_test.go b/proxy/internal/llm/pricing/defaults_coverage_test.go new file mode 100644 index 000000000..be23682da --- /dev/null +++ b/proxy/internal/llm/pricing/defaults_coverage_test.go @@ -0,0 +1,65 @@ +package pricing + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestDefaultTable_FirstPartyModelCoverage guards the embedded defaults against +// silent drift/gaps: every metered first-party model the management catalog +// enumerates must resolve to a price, and a few rates that previously drifted +// are pinned to their LiteLLM-validated values. Keep this list in step with the +// catalog (management/server/agentnetwork/catalog) when adding models. +func TestDefaultTable_FirstPartyModelCoverage(t *testing.T) { + tbl := DefaultTable() + require.NotNil(t, tbl, "embedded default pricing table must load") + + mustPrice := map[string][]string{ + // openai parser covers openai_api, azure_openai_api, and mistral_api. + "openai": { + "gpt-5.5", "gpt-5.5-pro", "gpt-5.4", "gpt-5.4-mini", "gpt-5.4-nano", + "gpt-5.3-codex", "gpt-5.3-chat-latest", "o4-mini", + "gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano", "gpt-4o", "gpt-4o-mini", + "gpt-4-turbo", "gpt-3.5-turbo", "gpt-35-turbo", + "text-embedding-3-large", "text-embedding-3-small", + "mistral-large-latest", "mistral-medium-3-5", "codestral-2508", + "ministral-8b-latest", "mistral-embed", + }, + "anthropic": { + "claude-fable-5", "claude-opus-4-8", "claude-opus-4-7", "claude-opus-4-6", + "claude-opus-4-1", "claude-sonnet-4-6", "claude-sonnet-4-5", "claude-haiku-4-5", + }, + // bedrock keys are the normalized ids the request parser emits. + "bedrock": { + "anthropic.claude-opus-4-8", "anthropic.claude-opus-4-7", "anthropic.claude-opus-4-6", + "anthropic.claude-opus-4-1", "anthropic.claude-sonnet-4-6", "anthropic.claude-sonnet-4-5", + "anthropic.claude-haiku-4-5", "meta.llama3-3-70b-instruct", + "amazon.nova-pro", "amazon.nova-lite", "amazon.nova-micro", "amazon.nova-2-lite", + }, + } + for provider, models := range mustPrice { + for _, m := range models { + _, ok := tbl.Cost(provider, m, 1000, 1000, 0, 0) + assert.True(t, ok, "%s/%s must be priced in the embedded defaults", provider, m) + } + } + + // Pin per-direction rates independently (input-only then output-only) so a + // swap or skew of input<->output that preserves the combined total is still + // caught — these are rates that previously drifted or are easy to mis-enter. + in, ok := tbl.Cost("openai", "gpt-5.4", 1000, 0, 0, 0) + require.True(t, ok) + assert.InDelta(t, 0.0025, in, 1e-9, "gpt-5.4 input = 0.0025 per 1k") + out, ok := tbl.Cost("openai", "gpt-5.4", 0, 1000, 0, 0) + require.True(t, ok) + assert.InDelta(t, 0.015, out, 1e-9, "gpt-5.4 output = 0.015 per 1k") + + in, ok = tbl.Cost("bedrock", "anthropic.claude-sonnet-4-5", 1000, 0, 0, 0) + require.True(t, ok) + assert.InDelta(t, 0.003, in, 1e-9, "bedrock sonnet-4-5 input = 0.003 per 1k") + out, ok = tbl.Cost("bedrock", "anthropic.claude-sonnet-4-5", 0, 1000, 0, 0) + require.True(t, ok) + assert.InDelta(t, 0.015, out, 1e-9, "bedrock sonnet-4-5 output = 0.015 per 1k") +} diff --git a/proxy/internal/llm/pricing/defaults_pricing.yaml b/proxy/internal/llm/pricing/defaults_pricing.yaml new file mode 100644 index 000000000..cd5c64fbf --- /dev/null +++ b/proxy/internal/llm/pricing/defaults_pricing.yaml @@ -0,0 +1,264 @@ +# Embedded default pricing for llm_observability. Compiled into the proxy +# binary via go:embed in pricing.go; cost annotation works out of the box +# without any operator action. +# +# Operators override entries by dropping a pricing.yaml into --plugin-data-dir +# (or whichever basename is given via params.pricing_path). The override file +# only needs entries the operator wants to change; missing entries fall +# through to these defaults. +# +# Values are USD per 1_000 tokens. Public list prices drift; ship a fresh +# binary or override individual entries via the override file as needed. +# +# Optional cache fields: +# cached_input_per_1k OpenAI: rate for prompt_tokens_details.cached_tokens +# (a SUBSET of prompt_tokens). Typically 0.5x input. +# Absent → cached portion bills at input_per_1k. +# cache_read_per_1k Anthropic: rate for cache_read_input_tokens +# (ADDITIVE to input_tokens). Typically 0.1x input. +# Absent → cache reads bill at input_per_1k. +# cache_creation_per_1k Anthropic: rate for cache_creation_input_tokens +# (ADDITIVE to input_tokens). Typically 1.25x input. +# Absent → cache writes bill at input_per_1k. + +openai: + # OpenAI + OpenAI-compatible providers (openai_api, azure_openai_api, + # mistral_api, and the openai-parser gateways) all emit llm.provider="openai", + # so their models are priced here. Kept in sync with the management catalog; + # rates cross-checked against LiteLLM model_prices_and_context_window.json. + + # GPT-5.x family — cache reads 10% of input (0.1x). + gpt-5.5: + input_per_1k: 0.005 + output_per_1k: 0.03 + cached_input_per_1k: 0.0005 + gpt-5.5-pro: + input_per_1k: 0.03 + output_per_1k: 0.18 + cached_input_per_1k: 0.003 + gpt-5.4: + input_per_1k: 0.0025 + output_per_1k: 0.015 + cached_input_per_1k: 0.00025 + gpt-5.4-pro: + input_per_1k: 0.03 + output_per_1k: 0.18 + cached_input_per_1k: 0.003 + gpt-5.4-mini: + input_per_1k: 0.00075 + output_per_1k: 0.0045 + cached_input_per_1k: 0.000075 + gpt-5.4-nano: + input_per_1k: 0.0002 + output_per_1k: 0.00125 + cached_input_per_1k: 0.00002 + gpt-5.3-codex: + input_per_1k: 0.00175 + output_per_1k: 0.014 + cached_input_per_1k: 0.000175 + gpt-5.3-chat-latest: + input_per_1k: 0.00175 + output_per_1k: 0.014 + cached_input_per_1k: 0.000175 + # GPT-5 (2025) family — kept for gateway requests using the unsuffixed ids. + gpt-5: + input_per_1k: 0.00125 + output_per_1k: 0.01 + cached_input_per_1k: 0.000125 + gpt-5-mini: + input_per_1k: 0.00025 + output_per_1k: 0.002 + cached_input_per_1k: 0.000025 + gpt-5-nano: + input_per_1k: 0.00005 + output_per_1k: 0.0004 + cached_input_per_1k: 0.000005 + o4-mini: + input_per_1k: 0.0011 + output_per_1k: 0.0044 + cached_input_per_1k: 0.000275 + # GPT-4.1 family — cache reads 25% of input. + gpt-4.1: + input_per_1k: 0.002 + output_per_1k: 0.008 + cached_input_per_1k: 0.0005 + gpt-4.1-mini: + input_per_1k: 0.0004 + output_per_1k: 0.0016 + cached_input_per_1k: 0.0001 + gpt-4.1-nano: + input_per_1k: 0.0001 + output_per_1k: 0.0004 + cached_input_per_1k: 0.000025 + # GPT-4o family — cache reads 50% of input (0.5x). + gpt-4o: + input_per_1k: 0.0025 + output_per_1k: 0.01 + cached_input_per_1k: 0.00125 + gpt-4o-mini: + input_per_1k: 0.00015 + output_per_1k: 0.0006 + cached_input_per_1k: 0.000075 + # Older GPT — no prompt caching. + gpt-4-turbo: + input_per_1k: 0.01 + output_per_1k: 0.03 + gpt-3.5-turbo: + input_per_1k: 0.0005 + output_per_1k: 0.0015 + gpt-35-turbo: # Azure deployment alias of gpt-3.5-turbo + input_per_1k: 0.0005 + output_per_1k: 0.0015 + # Embeddings — no caching, no output tokens. + text-embedding-3-large: + input_per_1k: 0.00013 + output_per_1k: 0 + text-embedding-3-small: + input_per_1k: 0.00002 + output_per_1k: 0 + + # Mistral (mistral_api) — routed via the openai parser; no prompt caching. + mistral-large-latest: + input_per_1k: 0.0005 + output_per_1k: 0.0015 + mistral-medium-latest: + input_per_1k: 0.0004 + output_per_1k: 0.002 + mistral-medium-3-5: + input_per_1k: 0.0015 + output_per_1k: 0.0075 + mistral-small-latest: + input_per_1k: 0.00006 + output_per_1k: 0.00018 + magistral-medium-latest: + input_per_1k: 0.002 + output_per_1k: 0.005 + magistral-small-latest: + input_per_1k: 0.0005 + output_per_1k: 0.0015 + devstral-medium-latest: + input_per_1k: 0.0004 + output_per_1k: 0.002 + devstral-small-latest: + input_per_1k: 0.0001 + output_per_1k: 0.0003 + codestral-2508: + input_per_1k: 0.0003 + output_per_1k: 0.0009 + codestral-latest: + input_per_1k: 0.001 + output_per_1k: 0.003 + ministral-3-14b-2512: + input_per_1k: 0.0002 + output_per_1k: 0.0002 + ministral-8b-latest: + input_per_1k: 0.00015 + output_per_1k: 0.00015 + ministral-3-3b-2512: + input_per_1k: 0.0001 + output_per_1k: 0.0001 + mistral-embed: + input_per_1k: 0.0001 + output_per_1k: 0 + +anthropic: + # Claude 4.x family — cache reads ≈10% of input, cache writes ≈125% of input. + # Pricing source: Anthropic's current published rates per million tokens, + # divided by 1000 for the per-1k figures stored here. + claude-fable-5: + input_per_1k: 0.010 + output_per_1k: 0.050 + cache_read_per_1k: 0.001 + cache_creation_per_1k: 0.0125 + claude-opus-4-8: + input_per_1k: 0.005 + output_per_1k: 0.025 + cache_read_per_1k: 0.0005 + cache_creation_per_1k: 0.00625 + claude-opus-4-7: + input_per_1k: 0.005 + output_per_1k: 0.025 + cache_read_per_1k: 0.0005 + cache_creation_per_1k: 0.00625 + claude-opus-4-6: + input_per_1k: 0.005 + output_per_1k: 0.025 + cache_read_per_1k: 0.0005 + cache_creation_per_1k: 0.00625 + claude-opus-4-1: + input_per_1k: 0.015 + output_per_1k: 0.075 + cache_read_per_1k: 0.0015 + cache_creation_per_1k: 0.01875 + claude-sonnet-4-6: + input_per_1k: 0.003 + output_per_1k: 0.015 + cache_read_per_1k: 0.0003 + cache_creation_per_1k: 0.00375 + claude-sonnet-4-5: + input_per_1k: 0.003 + output_per_1k: 0.015 + cache_read_per_1k: 0.0003 + cache_creation_per_1k: 0.00375 + claude-haiku-4-5: + input_per_1k: 0.001 + output_per_1k: 0.005 + cache_read_per_1k: 0.0001 + cache_creation_per_1k: 0.00125 + +bedrock: + # AWS Bedrock model ids, normalised by the request parser (cross-region + # inference-profile prefix + version/throughput suffix stripped), e.g. + # eu.anthropic.claude-sonnet-4-5-20250929-v1:0 -> anthropic.claude-sonnet-4-5. + # Anthropic-on-Bedrock keeps the additive cache buckets (read ≈0.1x input, + # write ≈1.25x input); Nova / Llama report no cache, so cost is input+output. + anthropic.claude-opus-4-8: + input_per_1k: 0.005 + output_per_1k: 0.025 + cache_read_per_1k: 0.0005 + cache_creation_per_1k: 0.00625 + anthropic.claude-opus-4-7: + input_per_1k: 0.005 + output_per_1k: 0.025 + cache_read_per_1k: 0.0005 + cache_creation_per_1k: 0.00625 + anthropic.claude-opus-4-6: + input_per_1k: 0.005 + output_per_1k: 0.025 + cache_read_per_1k: 0.0005 + cache_creation_per_1k: 0.00625 + anthropic.claude-opus-4-1: + input_per_1k: 0.015 + output_per_1k: 0.075 + cache_read_per_1k: 0.0015 + cache_creation_per_1k: 0.01875 + anthropic.claude-sonnet-4-6: + input_per_1k: 0.003 + output_per_1k: 0.015 + cache_read_per_1k: 0.0003 + cache_creation_per_1k: 0.00375 + anthropic.claude-sonnet-4-5: + input_per_1k: 0.003 + output_per_1k: 0.015 + cache_read_per_1k: 0.0003 + cache_creation_per_1k: 0.00375 + anthropic.claude-haiku-4-5: + input_per_1k: 0.001 + output_per_1k: 0.005 + cache_read_per_1k: 0.0001 + cache_creation_per_1k: 0.00125 + meta.llama3-3-70b-instruct: + input_per_1k: 0.00072 + output_per_1k: 0.00072 + amazon.nova-2-lite: + input_per_1k: 0.0003 + output_per_1k: 0.0025 + amazon.nova-pro: + input_per_1k: 0.0008 + output_per_1k: 0.0032 + amazon.nova-lite: + input_per_1k: 0.00006 + output_per_1k: 0.00024 + amazon.nova-micro: + input_per_1k: 0.000035 + output_per_1k: 0.00014 diff --git a/proxy/internal/llm/pricing/pricing.go b/proxy/internal/llm/pricing/pricing.go new file mode 100644 index 000000000..09afec5ff --- /dev/null +++ b/proxy/internal/llm/pricing/pricing.go @@ -0,0 +1,449 @@ +// Package pricing implements the embedded-default + override pricing table +// shared by middleware that converts LLM token usage into a USD cost +// estimate. The table is hot-reloadable from a basename under the proxy +// data directory; missing override files keep the embedded defaults so +// cost annotation works without operator action. +package pricing + +import ( + "bytes" + "context" + _ "embed" + "errors" + "fmt" + "io" + "io/fs" + "math" + "path/filepath" + "regexp" + "strings" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + "gopkg.in/yaml.v3" +) + +//go:embed defaults_pricing.yaml +var defaultPricingYAML []byte + +var ( + defaultTableOnce sync.Once + defaultTablePtr *Table +) + +// DefaultTable returns the pricing table embedded in the binary. The result +// is parsed once and shared; callers must not mutate the returned value. +// Cost annotation works without any operator action because every loader +// starts with this table. +func DefaultTable() *Table { + defaultTableOnce.Do(func() { + t, err := parsePricingBytes(defaultPricingYAML) + if err != nil { + panic(fmt.Sprintf("llmobs: embedded default pricing failed to parse: %v", err)) + } + defaultTablePtr = t + }) + return defaultTablePtr +} + +// mergeOver returns a new Table containing every entry from base, with any +// matching entry from overlay replacing the base value. Either argument may +// be nil. Result is a fresh allocation so callers can mutate / Store safely. +func mergeOver(base, overlay *Table) *Table { + if overlay == nil || len(overlay.entries) == 0 { + return base + } + if base == nil || len(base.entries) == 0 { + return overlay + } + out := make(map[string]map[string]Entry, len(base.entries)) + for provider, models := range base.entries { + inner := make(map[string]Entry, len(models)) + for model, e := range models { + inner[model] = e + } + out[provider] = inner + } + for provider, models := range overlay.entries { + inner, ok := out[provider] + if !ok { + inner = make(map[string]Entry, len(models)) + out[provider] = inner + } + for model, e := range models { + inner[model] = e + } + } + return &Table{entries: out} +} + +// Entry is a single model's input and output pricing, expressed in USD per +// 1000 tokens. +// +// CachedInputPer1K applies to OpenAI's cached prompt tokens, which are a +// subset of input_tokens — when set, the cached portion is billed at this +// rate and the non-cached remainder at InputPer1K. Zero means "no discount +// configured", and cached tokens are billed at InputPer1K (matches current +// behaviour where cached counts weren't extracted at all). +// +// CacheReadPer1K and CacheCreationPer1K apply to Anthropic's two prompt- +// cache fields, which are additive to input_tokens: cache_read is the +// cheaper read-from-cache rate, cache_creation is the more expensive +// write-to-cache rate. Zero means "no rate configured" and the +// corresponding token bucket is billed at InputPer1K. This is more +// accurate than today's behaviour, where Anthropic's cache tokens are +// ignored and not charged at all. +type Entry struct { + InputPer1K float64 + OutputPer1K float64 + CachedInputPer1K float64 + CacheReadPer1K float64 + CacheCreationPer1K float64 +} + +// Table is a provider-to-model pricing lookup. Instances are immutable once +// built and are swapped atomically by Loader. +type Table struct { + entries map[string]map[string]Entry +} + +// Cost returns the estimated USD cost for the given token counts. ok is +// false when the provider or model is not present in the table; the caller +// can still emit token metrics with a model=unknown label. +// +// Provider-shape semantics for cached / cache-creation counts: +// +// - OpenAI: cachedInput is a SUBSET of inTokens. The cached portion is +// billed at CachedInputPer1K (or InputPer1K when no override), and the +// non-cached remainder of inTokens at InputPer1K. cacheCreation is +// ignored (OpenAI has no analogue). +// - Anthropic: cachedInput (cache_read) and cacheCreation are ADDITIVE to +// inTokens. The three buckets are billed at CacheReadPer1K, +// CacheCreationPer1K, and InputPer1K respectively, each falling back +// to InputPer1K when the corresponding rate is zero. +// - Other providers: cached and cacheCreation are ignored; cost is +// inTokens*InputPer1K + outTokens*OutputPer1K. +func (t *Table) Cost(provider, model string, inTokens, outTokens, cachedInput, cacheCreation int64) (float64, bool) { + // Clamp negatives to zero before any pricing math so a malformed + // upstream count can never produce a negative cost. + if inTokens < 0 { + inTokens = 0 + } + if outTokens < 0 { + outTokens = 0 + } + if cachedInput < 0 { + cachedInput = 0 + } + if cacheCreation < 0 { + cacheCreation = 0 + } + if t == nil { + return 0, false + } + byModel, ok := t.entries[provider] + if !ok { + return 0, false + } + entry, ok := byModel[model] + if !ok { + return 0, false + } + output := (float64(outTokens) / 1000.0) * entry.OutputPer1K + switch provider { + case "openai": + // cachedInput is a subset of inTokens; clamp so a malformed + // upstream (cached > total) can't produce a negative remainder. + clamped := cachedInput + if clamped > inTokens { + clamped = inTokens + } + cachedRate := entry.CachedInputPer1K + if cachedRate <= 0 { + cachedRate = entry.InputPer1K + } + nonCached := float64(inTokens-clamped) / 1000.0 * entry.InputPer1K + cached := float64(clamped) / 1000.0 * cachedRate + return nonCached + cached + output, true + case "anthropic", "bedrock": + // Bedrock-Anthropic returns the same additive cache buckets as + // first-party Anthropic; non-Anthropic Bedrock models simply report + // zero cache tokens, so this formula degrades to input + output. + readRate := entry.CacheReadPer1K + if readRate <= 0 { + readRate = entry.InputPer1K + } + createRate := entry.CacheCreationPer1K + if createRate <= 0 { + createRate = entry.InputPer1K + } + input := float64(inTokens) / 1000.0 * entry.InputPer1K + read := float64(cachedInput) / 1000.0 * readRate + create := float64(cacheCreation) / 1000.0 * createRate + return input + read + create + output, true + default: + input := float64(inTokens) / 1000.0 * entry.InputPer1K + return input + output, true + } +} + +// Has reports whether the provider/model pair is present in the table. +func (t *Table) Has(provider, model string) bool { + if t == nil { + return false + } + byModel, ok := t.entries[provider] + if !ok { + return false + } + _, ok = byModel[model] + return ok +} + +// pricingFile mirrors the on-disk YAML schema. Keys are provider names; the +// nested map keys are model names. +type pricingFile map[string]map[string]struct { + InputPer1K float64 `yaml:"input_per_1k"` + OutputPer1K float64 `yaml:"output_per_1k"` + CachedInputPer1K float64 `yaml:"cached_input_per_1k"` + CacheReadPer1K float64 `yaml:"cache_read_per_1k"` + CacheCreationPer1K float64 `yaml:"cache_creation_per_1k"` +} + +const ( + // ReloadInterval is the mtime-poll cadence for the background reloader. + ReloadInterval = 30 * time.Second + + // errorBackoff bounds how often the loader logs a repeated parse error. + errorBackoff = 5 * time.Minute +) + +var basenameRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) + +// Loader is a confined, hot-reloadable pricing table reader. Construction +// must succeed against the target file; subsequent reload failures keep the +// previously-loaded table so callers never observe a blank price list. +type Loader struct { + baseDir string + fullPath string + pluginID string + table atomic.Pointer[Table] + mtime atomic.Int64 + failures metric.Int64Counter + interval time.Duration +} + +// NewLoader returns a pricing loader that overlays an optional file-based +// table on top of the embedded defaults. Missing override file, baseDir, or +// relPath is not an error: the loader keeps the embedded defaults so cost +// metadata is still emitted for known models. +// +// Errors: +// - bad basename, traversal segment, or absolute relPath are rejected so a +// misconfigured target surfaces immediately. +// - permission errors and YAML parse errors keep the defaults but log a +// warning; cost annotation does not silently break. +// +// failures is optional; pass nil in tests that do not care about +// reload-failure telemetry. +func NewLoader(baseDir, relPath, pluginID string, failures metric.Int64Counter) (*Loader, error) { + defaults := DefaultTable() + l := &Loader{ + baseDir: baseDir, + pluginID: pluginID, + failures: failures, + } + l.table.Store(defaults) + + if strings.TrimSpace(baseDir) == "" || strings.TrimSpace(relPath) == "" { + return l, nil + } + + full, err := resolveMiddlewareDataPath(baseDir, relPath) + if err != nil { + return nil, err + } + l.fullPath = full + + overlay, mtime, err := loadPricing(full) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + // Override file is optional. Defaults already stored. + return l, nil + } + // Symlink rejection, oversize file, parse failure, permission errors + // — surface so a misconfigured operator sees the problem instead of + // silently running with stale defaults. + return nil, fmt.Errorf("load pricing %s: %w", full, err) + } + l.table.Store(mergeOver(defaults, overlay)) + l.mtime.Store(mtime.UnixNano()) + return l, nil +} + +// Get returns the current pricing table. The returned pointer is immutable; +// callers must not mutate its contents. +func (l *Loader) Get() *Table { + if l == nil { + return nil + } + return l.table.Load() +} + +// WatchesFile reports whether this loader is bound to an override file on +// disk. False for defaults-only loaders (no operator override given). +// Callers use this to decide whether to spawn the mtime-poll goroutine. +func (l *Loader) WatchesFile() bool { + if l == nil { + return false + } + return l.fullPath != "" +} + +// SetReloadInterval overrides the mtime-poll cadence used by Reload. Calls +// after Reload has started have no effect on the running loop. Intended for +// tests; production code uses the default ReloadInterval. +func (l *Loader) SetReloadInterval(d time.Duration) { + if l == nil || d <= 0 { + return + } + l.interval = d +} + +// Reload runs a polling loop that checks the pricing file mtime every +// ReloadInterval (or the value passed to SetReloadInterval). Returns when +// ctx is cancelled. +func (l *Loader) Reload(ctx context.Context) { + if l == nil { + return + } + interval := l.interval + if interval <= 0 { + interval = ReloadInterval + } + t := time.NewTicker(interval) + defer t.Stop() + + var lastErrAt time.Time + for { + select { + case <-ctx.Done(): + return + case <-t.C: + if err := l.reload(); err != nil { + if l.failures != nil { + l.failures.Add(ctx, 1, metric.WithAttributes( + attribute.String("plugin", l.pluginID), + )) + } + now := time.Now() + if now.Sub(lastErrAt) >= errorBackoff { + log.Warnf("llmobs: pricing reload failed for %s: %v", l.fullPath, err) + lastErrAt = now + } + } + } + } +} + +// reload performs a single-shot mtime check and reload. The reloaded +// override file is merged on top of the embedded defaults; missing override +// (e.g. operator deleted the file) is not an error and reverts to defaults. +func (l *Loader) reload() error { + if l.fullPath == "" { + // Defaults-only loader; nothing on disk to reload. + return nil + } + mtime, err := statMtime(l.fullPath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + // File was removed since startup. Drop back to defaults and + // reset mtime so a future re-creation triggers a reload. + l.table.Store(DefaultTable()) + l.mtime.Store(0) + return nil + } + return err + } + if mtime.UnixNano() == l.mtime.Load() { + return nil + } + + overlay, newMtime, err := loadPricing(l.fullPath) + if err != nil { + return err + } + l.table.Store(mergeOver(DefaultTable(), overlay)) + l.mtime.Store(newMtime.UnixNano()) + return nil +} + +// resolveMiddlewareDataPath validates relPath is a safe basename and resolves +// it under baseDir. An additional cleaned-prefix check guards against +// CVE-style edge cases where Join is used with trailing path segments. +func resolveMiddlewareDataPath(baseDir, relPath string) (string, error) { + if strings.TrimSpace(baseDir) == "" { + return "", errors.New("middleware-data-dir is not configured") + } + if relPath == "" { + return "", errors.New("pricing path is empty") + } + if !basenameRegex.MatchString(relPath) { + return "", fmt.Errorf("pricing path %q is not a safe basename", relPath) + } + if filepath.IsAbs(relPath) { + return "", fmt.Errorf("pricing path %q must be a basename, not absolute", relPath) + } + + cleanBase, err := filepath.Abs(filepath.Clean(baseDir)) + if err != nil { + return "", fmt.Errorf("resolve middleware-data-dir: %w", err) + } + full := filepath.Join(cleanBase, relPath) + cleanedFull := filepath.Clean(full) + if !strings.HasPrefix(cleanedFull, cleanBase+string(filepath.Separator)) && cleanedFull != cleanBase { + return "", fmt.Errorf("pricing path %q escapes middleware-data-dir", relPath) + } + return cleanedFull, nil +} + +func parsePricingBytes(data []byte) (*Table, error) { + dec := yaml.NewDecoder(bytes.NewReader(data)) + dec.KnownFields(true) + + var raw pricingFile + if err := dec.Decode(&raw); err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("decode pricing yaml: %w", err) + } + + out := make(map[string]map[string]Entry, len(raw)) + for provider, models := range raw { + inner := make(map[string]Entry, len(models)) + for model, entry := range models { + for field, v := range map[string]float64{ + "input_per_1k": entry.InputPer1K, + "output_per_1k": entry.OutputPer1K, + "cached_input_per_1k": entry.CachedInputPer1K, + "cache_read_per_1k": entry.CacheReadPer1K, + "cache_creation_per_1k": entry.CacheCreationPer1K, + } { + if v < 0 || math.IsNaN(v) || math.IsInf(v, 0) { + return nil, fmt.Errorf("pricing %s/%s: %s must be a finite, non-negative rate, got %v", provider, model, field, v) + } + } + inner[model] = Entry{ + InputPer1K: entry.InputPer1K, + OutputPer1K: entry.OutputPer1K, + CachedInputPer1K: entry.CachedInputPer1K, + CacheReadPer1K: entry.CacheReadPer1K, + CacheCreationPer1K: entry.CacheCreationPer1K, + } + } + out[provider] = inner + } + return &Table{entries: out}, nil +} diff --git a/proxy/internal/llm/pricing/pricing_other.go b/proxy/internal/llm/pricing/pricing_other.go new file mode 100644 index 000000000..e65fffff1 --- /dev/null +++ b/proxy/internal/llm/pricing/pricing_other.go @@ -0,0 +1,20 @@ +//go:build !unix + +package pricing + +import ( + "fmt" + "time" +) + +// loadPricing is unavailable on non-Unix platforms because O_NOFOLLOW and +// fstat-from-FD are required to honour the spec's symlink-safety rules. The +// proxy is only deployed on Linux today; a Windows port would need an +// equivalent path-as-handle implementation. +func loadPricing(path string) (*Table, time.Time, error) { + return nil, time.Time{}, fmt.Errorf("llmobs pricing loader is not supported on this platform: %s", path) +} + +func statMtime(path string) (time.Time, error) { + return time.Time{}, fmt.Errorf("llmobs pricing loader is not supported on this platform: %s", path) +} diff --git a/proxy/internal/llm/pricing/pricing_test.go b/proxy/internal/llm/pricing/pricing_test.go new file mode 100644 index 000000000..7ac2a85dc --- /dev/null +++ b/proxy/internal/llm/pricing/pricing_test.go @@ -0,0 +1,432 @@ +//go:build unix + +package pricing + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func copyFixture(t *testing.T, src, dst string) { + t.Helper() + data, err := os.ReadFile(src) + require.NoError(t, err, "read source fixture") + require.NoError(t, os.WriteFile(dst, data, 0o600), "write target fixture") +} + +func TestNewLoader_HappyPath(t *testing.T) { + base := t.TempDir() + copyFixture(t, filepath.Join("..", "fixtures", "pricing.yaml"), filepath.Join(base, "pricing.yaml")) + + l, err := NewLoader(base, "pricing.yaml", "llm_observability", nil) + require.NoError(t, err, "NewLoader must succeed with a valid fixture") + table := l.Get() + require.NotNil(t, table, "table populated after load") + + cost, ok := table.Cost("openai", "gpt-4o-mini", 1000, 1000, 0, 0) + require.True(t, ok, "known provider/model resolves") + assert.InDelta(t, 0.00075, cost, 1e-9, "cost = 0.00015 + 0.0006 per 1k tokens") + + cost, ok = table.Cost("openai", "gpt-4o", 2000, 1000, 0, 0) + require.True(t, ok, "second known model resolves") + assert.InDelta(t, 0.015, cost, 1e-9, "cost for gpt-4o: 2*0.0025 + 1*0.01") + + cost, ok = table.Cost("anthropic", "claude-sonnet-4-5", 1000, 1000, 0, 0) + require.True(t, ok, "anthropic model resolves") + assert.InDelta(t, 0.018, cost, 1e-9, "cost for claude-sonnet-4-5: 0.003 + 0.015") +} + +// TestCost_OpenAICachedSubsetDiscount proves OpenAI's cached input +// tokens are billed at the configured cached_input_per_1k rate while +// the non-cached remainder of input_tokens is billed at the regular +// rate. Critical because OpenAI returns cached_tokens as a SUBSET of +// prompt_tokens — naïvely charging the cached count on top of +// prompt_tokens would double-bill that portion. +func TestCost_OpenAICachedSubsetDiscount(t *testing.T) { + tbl := &Table{entries: map[string]map[string]Entry{ + "openai": {"gpt-4o": { + InputPer1K: 0.0025, // 0.0025 USD per 1k input tokens + OutputPer1K: 0.01, + CachedInputPer1K: 0.00125, // 0.5x discount on cached + }}, + }} + // 1000 prompt tokens, 750 of which were cached. 250 non-cached + // at regular rate, 750 cached at the discount rate, 500 output. + cost, ok := tbl.Cost("openai", "gpt-4o", 1000, 500, 750, 0) + require.True(t, ok, "known model resolves") + want := (250.0/1000.0)*0.0025 + (750.0/1000.0)*0.00125 + (500.0/1000.0)*0.01 + assert.InDelta(t, want, cost, 1e-12, + "cached subset must bill at the discount rate; non-cached remainder at regular rate") +} + +// TestCost_OpenAICachedFallsBackToInputRate covers the operator +// opt-in contract: when CachedInputPer1K is unset (zero), cached +// tokens bill at the regular input rate. This matches today's +// behaviour (cached counts weren't extracted at all so they +// implicitly billed at the input rate via prompt_tokens). +func TestCost_OpenAICachedFallsBackToInputRate(t *testing.T) { + tbl := &Table{entries: map[string]map[string]Entry{ + "openai": {"gpt-4o": {InputPer1K: 0.0025, OutputPer1K: 0.01}}, + }} + cost, ok := tbl.Cost("openai", "gpt-4o", 1000, 500, 750, 0) + require.True(t, ok) + want := 0.0025 + (500.0/1000.0)*0.01 + assert.InDelta(t, want, cost, 1e-12, + "absent cached_input_per_1k rate must fall back to input_per_1k — same as pre-feature behaviour") +} + +// TestCost_OpenAIClampsCachedToInputCount is the defensive guard +// against malformed upstream responses that report cached_tokens > +// prompt_tokens. We clamp so the formula never produces a negative +// "non-cached remainder" multiplied by the input rate. +func TestCost_OpenAIClampsCachedToInputCount(t *testing.T) { + tbl := &Table{entries: map[string]map[string]Entry{ + "openai": {"gpt-4o": {InputPer1K: 0.0025, OutputPer1K: 0.01, CachedInputPer1K: 0.00125}}, + }} + cost, ok := tbl.Cost("openai", "gpt-4o", 100, 0, 9999, 0) + require.True(t, ok) + // All 100 cached, 0 non-cached. Output is 0. + want := (100.0 / 1000.0) * 0.00125 + assert.InDelta(t, want, cost, 1e-12, + "cached count > input count must clamp to input — never bill negative non-cached tokens") +} + +// TestCost_AnthropicCacheReadAndCreationAreAdditive proves the +// Anthropic shape: cache_read and cache_creation tokens are +// ADDITIVE to input_tokens (not subset), each billed at its own +// configured rate. The two rates pull in opposite directions — +// cache_read is the cheaper read-from-cache rate (≈0.1× input), +// cache_creation is the more expensive write-to-cache rate +// (≈1.25× input). +func TestCost_AnthropicCacheReadAndCreationAreAdditive(t *testing.T) { + tbl := &Table{entries: map[string]map[string]Entry{ + "anthropic": {"claude-sonnet": { + InputPer1K: 0.003, + OutputPer1K: 0.015, + CacheReadPer1K: 0.0003, // 0.1x of input + CacheCreationPer1K: 0.00375, // 1.25x of input + }}, + }} + // 256 regular input + 768 cache_read + 512 cache_creation + + // 200 output. Each input bucket bills at its own rate. + cost, ok := tbl.Cost("anthropic", "claude-sonnet", 256, 200, 768, 512) + require.True(t, ok, "known model resolves") + want := (256.0/1000.0)*0.003 + + (768.0/1000.0)*0.0003 + + (512.0/1000.0)*0.00375 + + (200.0/1000.0)*0.015 + assert.InDelta(t, want, cost, 1e-12, + "each Anthropic input bucket must bill at its own configured rate") +} + +// TestCost_AnthropicCacheRatesFallBackToInput covers the no-opt-in +// path: when neither CacheReadPer1K nor CacheCreationPer1K is set, +// cache tokens bill at the regular input rate. This is more +// accurate than today's behaviour (cache tokens ignored entirely) +// without requiring operators to opt in via YAML. +func TestCost_AnthropicCacheRatesFallBackToInput(t *testing.T) { + tbl := &Table{entries: map[string]map[string]Entry{ + "anthropic": {"claude-sonnet": {InputPer1K: 0.003, OutputPer1K: 0.015}}, + }} + cost, ok := tbl.Cost("anthropic", "claude-sonnet", 256, 200, 768, 512) + require.True(t, ok) + // Without overrides: every input bucket at input_per_1k. + want := ((256.0+768.0+512.0)/1000.0)*0.003 + (200.0/1000.0)*0.015 + assert.InDelta(t, want, cost, 1e-12, + "absent cache rates must fall back to input_per_1k — Anthropic cache tokens were ignored before this change, billing at input rate is more accurate as a default") +} + +func TestNewLoader_UnknownModel(t *testing.T) { + base := t.TempDir() + copyFixture(t, filepath.Join("..", "fixtures", "pricing.yaml"), filepath.Join(base, "pricing.yaml")) + + l, err := NewLoader(base, "pricing.yaml", "llm_observability", nil) + require.NoError(t, err) + + _, ok := l.Get().Cost("openai", "fantasy-model", 10, 10, 0, 0) + assert.False(t, ok, "unknown model returns ok=false") + + _, ok = l.Get().Cost("cohere", "anything", 10, 10, 0, 0) + assert.False(t, ok, "unknown provider returns ok=false") +} + +func TestNewLoader_InvalidYAMLRejected(t *testing.T) { + base := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(base, "pricing.yaml"), []byte("\t- this is not: valid: yaml: :["), 0o600)) + + _, err := NewLoader(base, "pricing.yaml", "llm_observability", nil) + require.Error(t, err, "invalid YAML must surface as construction error") +} + +func TestLoader_ReloadKeepsPreviousOnParseError(t *testing.T) { + base := t.TempDir() + target := filepath.Join(base, "pricing.yaml") + copyFixture(t, filepath.Join("..", "fixtures", "pricing.yaml"), target) + + l, err := NewLoader(base, "pricing.yaml", "llm_observability", nil) + require.NoError(t, err) + require.NotNil(t, l.Get(), "initial table populated") + + // Overwrite with content that violates the strict schema (extra field) + // plus a bumped mtime to trigger reload. + require.NoError(t, os.WriteFile(target, []byte("openai:\n gpt-4o:\n input_per_1k: 1.0\n output_per_1k: 2.0\n bogus_field: nope\n"), 0o600)) + future := time.Now().Add(time.Hour) + require.NoError(t, os.Chtimes(target, future, future)) + + err = l.reload() + require.Error(t, err, "parse error surfaced by reload()") + + cost, ok := l.Get().Cost("openai", "gpt-4o-mini", 1000, 1000, 0, 0) + require.True(t, ok, "previous table still available after parse failure") + assert.InDelta(t, 0.00075, cost, 1e-9, "previous cost preserved") +} + +func TestLoader_ReloadNoChangeIsNoOp(t *testing.T) { + base := t.TempDir() + target := filepath.Join(base, "pricing.yaml") + copyFixture(t, filepath.Join("..", "fixtures", "pricing.yaml"), target) + + l, err := NewLoader(base, "pricing.yaml", "llm_observability", nil) + require.NoError(t, err) + ptrBefore := l.Get() + + require.NoError(t, l.reload(), "no-change reload must not error") + ptrAfter := l.Get() + assert.Same(t, ptrBefore, ptrAfter, "table pointer unchanged when mtime unchanged") +} + +func TestLoader_ReloadDetectsChange(t *testing.T) { + base := t.TempDir() + target := filepath.Join(base, "pricing.yaml") + copyFixture(t, filepath.Join("..", "fixtures", "pricing.yaml"), target) + + l, err := NewLoader(base, "pricing.yaml", "llm_observability", nil) + require.NoError(t, err) + + updated := []byte("openai:\n gpt-4o-mini:\n input_per_1k: 1.00\n output_per_1k: 2.00\n") + require.NoError(t, os.WriteFile(target, updated, 0o600)) + future := time.Now().Add(time.Hour) + require.NoError(t, os.Chtimes(target, future, future)) + + require.NoError(t, l.reload(), "reload must succeed on valid new content") + + cost, ok := l.Get().Cost("openai", "gpt-4o-mini", 1000, 1000, 0, 0) + require.True(t, ok, "updated model still present") + assert.InDelta(t, 3.0, cost, 0.0001, "new prices are applied: 1 + 2 per 1k") +} + +// TestLoader_ReloadGoroutinePicksUpChanges proves the background goroutine +// started via Reload actually swaps the pricing table when the file changes +// on disk. Without that goroutine running, pricing edits would never reach +// requests until a proxy restart. +func TestLoader_ReloadGoroutinePicksUpChanges(t *testing.T) { + base := t.TempDir() + target := filepath.Join(base, "pricing.yaml") + copyFixture(t, filepath.Join("..", "fixtures", "pricing.yaml"), target) + + l, err := NewLoader(base, "pricing.yaml", "llm_observability", nil) + require.NoError(t, err) + l.SetReloadInterval(20 * time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan struct{}) + go func() { + l.Reload(ctx) + close(done) + }() + + // Before any rewrite, the loader holds the fixture's prices. + costBefore, ok := l.Get().Cost("openai", "gpt-4o-mini", 1000, 1000, 0, 0) + require.True(t, ok, "fixture model must resolve initially") + assert.InDelta(t, 0.00075, costBefore, 1e-9, "fixture prices apply before rewrite") + + updated := []byte("openai:\n gpt-4o-mini:\n input_per_1k: 1.00\n output_per_1k: 2.00\n") + require.NoError(t, os.WriteFile(target, updated, 0o600)) + future := time.Now().Add(time.Hour) + require.NoError(t, os.Chtimes(target, future, future)) + + deadline := time.Now().Add(2 * time.Second) + for { + cost, ok := l.Get().Cost("openai", "gpt-4o-mini", 1000, 1000, 0, 0) + if ok && cost > 2.5 { + break + } + if time.Now().After(deadline) { + t.Fatalf("background reloader did not pick up rewrite within deadline") + } + time.Sleep(10 * time.Millisecond) + } + + cancel() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("Reload loop did not exit after cancel") + } +} + +func TestLoader_ReloadBackgroundLoopCancellation(t *testing.T) { + base := t.TempDir() + copyFixture(t, filepath.Join("..", "fixtures", "pricing.yaml"), filepath.Join(base, "pricing.yaml")) + l, err := NewLoader(base, "pricing.yaml", "llm_observability", nil) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + l.Reload(ctx) + close(done) + }() + cancel() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("Reload loop did not exit on context cancel") + } +} + +func TestNewLoader_PathValidation(t *testing.T) { + base := t.TempDir() + copyFixture(t, filepath.Join("..", "fixtures", "pricing.yaml"), filepath.Join(base, "pricing.yaml")) + + cases := []struct { + name string + relPath string + }{ + {"traversal", "../../etc/passwd"}, + {"absolute", "/etc/passwd"}, + {"slash in basename", "sub/pricing.yaml"}, + {"control chars", "pricing\x00.yaml"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, err := NewLoader(base, tc.relPath, "llm_observability", nil) + require.Error(t, err, "NewLoader must reject %q", tc.relPath) + }) + } + + // Empty relPath is no longer a validation error: the loader treats it + // as "no override file, defaults only" so cost metadata is still + // emitted for the embedded models out of the box. + t.Run("empty falls back to defaults", func(t *testing.T) { + l, err := NewLoader(base, "", "llm_observability", nil) + require.NoError(t, err, "empty relPath should yield a defaults-only loader") + require.NotNil(t, l, "loader must be returned") + require.False(t, l.WatchesFile(), "no file watching when no override is given") + _, ok := l.Get().Cost("openai", "gpt-4o-mini", 1000, 1000, 0, 0) + assert.True(t, ok, "embedded defaults should still resolve gpt-4o-mini") + }) +} + +// TestNewLoader_PathValidation_Extended covers the remaining attack shapes +// called out in C2: dot references, embedded traversal segments, and a +// newline in the basename. The basename regex must reject each one even +// though filepath.Clean would otherwise collapse them. +func TestNewLoader_PathValidation_Extended(t *testing.T) { + base := t.TempDir() + copyFixture(t, filepath.Join("..", "fixtures", "pricing.yaml"), filepath.Join(base, "pricing.yaml")) + + cases := []struct { + name string + relPath string + }{ + {"dot", "."}, + {"dotdot", ".."}, + {"relative traversal", "../pricing.yaml"}, + {"embedded slash", "pri/cing.yaml"}, + {"newline", "pricing\n.yaml"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, err := NewLoader(base, tc.relPath, "llm_observability", nil) + require.Error(t, err, "NewLoader must reject %q", tc.relPath) + }) + } +} + +// TestNewLoader_ValidBasenameLoads proves the allowlist is exclusive: a +// basename containing only safe characters under baseDir loads. Without this +// a regression that over-tightened the regex would silently break valid +// deployments. +func TestNewLoader_ValidBasenameLoads(t *testing.T) { + base := t.TempDir() + copyFixture(t, filepath.Join("..", "fixtures", "pricing.yaml"), filepath.Join(base, "pricing-v2_prod.yaml")) + + l, err := NewLoader(base, "pricing-v2_prod.yaml", "llm_observability", nil) + require.NoError(t, err, "basename with _, -, . must load") + require.NotNil(t, l.Get(), "table populated") +} + +// TestNewLoader_SymlinkOutsideBaseDirRejected constructs a symlink under +// baseDir that points to a file outside it. O_NOFOLLOW must refuse to open +// the symlink even though the symlink path itself is a valid basename under +// baseDir. +func TestNewLoader_SymlinkOutsideBaseDirRejected(t *testing.T) { + outside := t.TempDir() + target := filepath.Join(outside, "evil.yaml") + copyFixture(t, filepath.Join("..", "fixtures", "pricing.yaml"), target) + + base := t.TempDir() + link := filepath.Join(base, "pricing.yaml") + require.NoError(t, os.Symlink(target, link), "symlink setup") + + _, err := NewLoader(base, "pricing.yaml", "llm_observability", nil) + require.Error(t, err, "O_NOFOLLOW must reject symlink even when it points outside baseDir") +} + +func TestNewLoader_SymlinkRejected(t *testing.T) { + base := t.TempDir() + concrete := filepath.Join(base, "real.yaml") + copyFixture(t, filepath.Join("..", "fixtures", "pricing.yaml"), concrete) + + link := filepath.Join(base, "pricing.yaml") + require.NoError(t, os.Symlink(concrete, link), "symlink setup") + + _, err := NewLoader(base, "pricing.yaml", "llm_observability", nil) + require.Error(t, err, "O_NOFOLLOW must reject symlinked targets") +} + +func TestTableCost_NilSafe(t *testing.T) { + var t1 *Table + cost, ok := t1.Cost("x", "y", 1, 1, 0, 0) + assert.False(t, ok, "nil table reports unknown") + assert.Zero(t, cost, "nil table returns zero cost") + assert.False(t, t1.Has("x", "y"), "nil table has nothing") +} + +func TestLoaderGet_NilSafe(t *testing.T) { + var l *Loader + assert.Nil(t, l.Get(), "nil loader returns nil table") +} + +// TestNewLoader_RejectsOversizedFile_FixesM4 proves the loader bounds reads +// at maxPricingBytes so a hostile file cannot exhaust process memory. +func TestNewLoader_RejectsOversizedFile_FixesM4(t *testing.T) { + base := t.TempDir() + target := filepath.Join(base, "pricing.yaml") + + // Build a YAML payload larger than the cap. We pad with valid YAML + // comments so a partial read would still fail the size check rather + // than the parser. + header := "openai:\n" + bigComment := make([]byte, maxPricingBytes+1024) + for i := range bigComment { + bigComment[i] = ' ' + } + bigComment[0] = '#' + bigComment[len(bigComment)-1] = '\n' + payload := append([]byte(header), bigComment...) + require.NoError(t, os.WriteFile(target, payload, 0o600)) + + _, err := NewLoader(base, "pricing.yaml", "llm_observability", nil) + require.Error(t, err, "oversized pricing file must be rejected") + assert.Contains(t, err.Error(), "exceeds", "rejection must reference the byte cap") +} diff --git a/proxy/internal/llm/pricing/pricing_unix.go b/proxy/internal/llm/pricing/pricing_unix.go new file mode 100644 index 000000000..4f3ea33a2 --- /dev/null +++ b/proxy/internal/llm/pricing/pricing_unix.go @@ -0,0 +1,68 @@ +//go:build unix + +package pricing + +import ( + "fmt" + "io" + "os" + "syscall" + "time" + + log "github.com/sirupsen/logrus" +) + +// maxPricingBytes caps the size of the pricing YAML on read so a hostile or +// runaway file cannot exhaust process memory during reload. 1 MiB is several +// orders of magnitude larger than any reasonable pricing table. +const maxPricingBytes int64 = 1 << 20 + +// loadPricing opens the file with O_NOFOLLOW, fstats the open descriptor, +// and parses from that same descriptor. Never re-opens by path so a +// mid-read rename or symlink swap cannot substitute content. Bytes are +// capped at maxPricingBytes so the loader cannot be coerced into reading an +// unbounded file. +func loadPricing(path string) (*Table, time.Time, error) { + f, err := os.OpenFile(path, os.O_RDONLY|syscall.O_NOFOLLOW, 0) + if err != nil { + return nil, time.Time{}, fmt.Errorf("open %s: %w", path, err) + } + defer func() { + if cerr := f.Close(); cerr != nil { + log.Debugf("close pricing file %s: %v", path, cerr) + } + }() + + info, err := f.Stat() + if err != nil { + return nil, time.Time{}, fmt.Errorf("fstat %s: %w", path, err) + } + if !info.Mode().IsRegular() { + return nil, time.Time{}, fmt.Errorf("pricing file %s is not a regular file", path) + } + + data, err := io.ReadAll(io.LimitReader(f, maxPricingBytes+1)) + if err != nil { + return nil, time.Time{}, fmt.Errorf("read %s: %w", path, err) + } + if int64(len(data)) > maxPricingBytes { + return nil, time.Time{}, fmt.Errorf("pricing file %s exceeds %d bytes", path, maxPricingBytes) + } + + table, err := parsePricingBytes(data) + if err != nil { + return nil, time.Time{}, err + } + return table, info.ModTime(), nil +} + +// statMtime returns the mtime of the file at path. It uses lstat semantics +// via os.Lstat so a symlink swap is detected even though O_NOFOLLOW will +// later reject the open. +func statMtime(path string) (time.Time, error) { + info, err := os.Lstat(path) + if err != nil { + return time.Time{}, fmt.Errorf("lstat %s: %w", path, err) + } + return info.ModTime(), nil +} diff --git a/proxy/internal/llm/sse.go b/proxy/internal/llm/sse.go new file mode 100644 index 000000000..3d33ab577 --- /dev/null +++ b/proxy/internal/llm/sse.go @@ -0,0 +1,117 @@ +package llm + +import ( + "bufio" + "errors" + "fmt" + "io" + "strings" +) + +// Event represents a single server-sent event. Type is the dispatch name +// carried on an "event:" line (empty when the stream uses only "data:" +// lines). Data is the concatenation of every "data:" line that made up the +// event, joined by a single newline. +type Event struct { + Type string + Data string +} + +// Scanner reads SSE events from an underlying byte stream. Events are +// delimited by a blank line ("\n\n"). CRLF line endings are normalized to LF +// transparently so fixtures captured from live servers can be replayed. +// +// Scanner is not safe for concurrent use. +type Scanner struct { + r *bufio.Reader + maxLine int +} + +// NewScanner wraps the given reader. The default underlying buffer size is +// large enough for typical provider events (~64 KiB); callers needing +// larger events can wrap the reader in their own bufio.Reader beforehand. +func NewScanner(r io.Reader) *Scanner { + return &Scanner{ + r: bufio.NewReaderSize(r, 64*1024), + maxLine: 1 << 20, + } +} + +// Next returns the next event. It returns io.EOF after the final event has +// been consumed. A trailing event that is not terminated by a blank line is +// still returned before io.EOF so that servers which close the connection +// without a trailing newline are handled correctly. +func (s *Scanner) Next() (Event, error) { + var ( + event Event + dataBuf strings.Builder + hasData bool + hasAny bool + ) + + for { + line, err := s.readLine() + if err != nil { + if errors.Is(err, io.EOF) && hasAny { + event.Data = dataBuf.String() + return event, nil + } + return Event{}, err + } + + if line == "" { + if !hasAny { + continue + } + event.Data = dataBuf.String() + return event, nil + } + + hasAny = true + if strings.HasPrefix(line, ":") { + continue + } + + field, value := splitField(line) + switch field { + case "event": + event.Type = value + case "data": + if hasData { + dataBuf.WriteByte('\n') + } + dataBuf.WriteString(value) + hasData = true + } + } +} + +func (s *Scanner) readLine() (string, error) { + line, err := s.r.ReadString('\n') + if err != nil { + if errors.Is(err, io.EOF) && line != "" { + return trimEOL(line), nil + } + return "", err + } + if len(line) > s.maxLine { + return "", fmt.Errorf("sse line exceeds %d bytes", s.maxLine) + } + return trimEOL(line), nil +} + +func trimEOL(line string) string { + line = strings.TrimRight(line, "\n") + line = strings.TrimRight(line, "\r") + return line +} + +func splitField(line string) (string, string) { + idx := strings.IndexByte(line, ':') + if idx < 0 { + return line, "" + } + field := line[:idx] + value := strings.TrimPrefix(line[idx+1:], " ") + return field, value +} diff --git a/proxy/internal/llm/sse_test.go b/proxy/internal/llm/sse_test.go new file mode 100644 index 000000000..96cecc111 --- /dev/null +++ b/proxy/internal/llm/sse_test.go @@ -0,0 +1,175 @@ +package llm + +import ( + "errors" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func collectEvents(t *testing.T, r io.Reader) []Event { + t.Helper() + s := NewScanner(r) + var out []Event + for { + ev, err := s.Next() + if errors.Is(err, io.EOF) { + return out + } + require.NoError(t, err, "unexpected error scanning SSE") + out = append(out, ev) + } +} + +func TestSSEScanner_OpenAIFixture(t *testing.T) { + f, err := os.Open(filepath.Join("fixtures", "openai_stream.txt")) + require.NoError(t, err, "fixture must be openable") + defer f.Close() + + events := collectEvents(t, f) + require.Len(t, events, 4, "expected 4 data frames (3 chunks + [DONE])") + + for _, ev := range events { + assert.Empty(t, ev.Type, "OpenAI stream uses data-only frames") + } + assert.Contains(t, events[2].Data, `"usage"`, "third chunk carries usage block") + assert.Equal(t, "[DONE]", events[3].Data, "final frame is the OpenAI DONE sentinel") +} + +func TestSSEScanner_AnthropicFixture(t *testing.T) { + f, err := os.Open(filepath.Join("fixtures", "anthropic_stream.txt")) + require.NoError(t, err, "fixture must be openable") + defer f.Close() + + events := collectEvents(t, f) + require.Len(t, events, 7, "expected 7 Anthropic events") + + types := make([]string, 0, len(events)) + for _, ev := range events { + types = append(types, ev.Type) + } + assert.Equal(t, []string{ + "message_start", + "content_block_start", + "content_block_delta", + "content_block_delta", + "content_block_stop", + "message_delta", + "message_stop", + }, types, "Anthropic event ordering matches fixture") + + var deltaUsage Event + for _, ev := range events { + if ev.Type == "message_delta" { + deltaUsage = ev + break + } + } + assert.Contains(t, deltaUsage.Data, `"output_tokens":45`, "message_delta carries partial usage") +} + +func TestSSEScanner_MultilineData(t *testing.T) { + raw := "event: ping\ndata: line1\ndata: line2\ndata: line3\n\n" + events := collectEvents(t, strings.NewReader(raw)) + + require.Len(t, events, 1, "one logical event from three data lines") + assert.Equal(t, "ping", events[0].Type, "event name honored") + assert.Equal(t, "line1\nline2\nline3", events[0].Data, "data lines joined with newline") +} + +func TestSSEScanner_CRLF(t *testing.T) { + raw := "event: foo\r\ndata: bar\r\n\r\ndata: baz\r\n\r\n" + events := collectEvents(t, strings.NewReader(raw)) + + require.Len(t, events, 2, "CRLF-delimited events recognized") + assert.Equal(t, "foo", events[0].Type, "first event type preserved") + assert.Equal(t, "bar", events[0].Data, "first event data preserved") + assert.Empty(t, events[1].Type, "second event has no event name") + assert.Equal(t, "baz", events[1].Data, "second event data preserved") +} + +func TestSSEScanner_EmptyInput(t *testing.T) { + s := NewScanner(strings.NewReader("")) + _, err := s.Next() + require.ErrorIs(t, err, io.EOF, "empty input yields immediate EOF") +} + +func TestSSEScanner_CommentIgnored(t *testing.T) { + raw := ": this is a comment\ndata: hi\n\n" + events := collectEvents(t, strings.NewReader(raw)) + require.Len(t, events, 1, "comment line does not emit an event") + assert.Equal(t, "hi", events[0].Data, "data line honoured after comment") +} + +func TestSSEScanner_TrailingWithoutBlankLine(t *testing.T) { + raw := "event: foo\ndata: bar\n" + events := collectEvents(t, strings.NewReader(raw)) + require.Len(t, events, 1, "trailing event without blank line still emitted") + assert.Equal(t, "foo", events[0].Type) + assert.Equal(t, "bar", events[0].Data) +} + +// TestSSEScanner_ManyConsecutiveEmptyLines feeds a stream that is nothing +// but empty lines. The scanner must terminate without panic — empty lines +// alone do not constitute an event and must yield io.EOF. +func TestSSEScanner_ManyConsecutiveEmptyLines(t *testing.T) { + raw := strings.Repeat("\n", 100) + s := NewScanner(strings.NewReader(raw)) + _, err := s.Next() + require.ErrorIs(t, err, io.EOF, "100 empty lines must terminate as EOF without panic") +} + +// TestSSEScanner_InterleavedCRLFAndLF mixes \r\n and \n terminators within +// the same event. The scanner normalizes both and must still recover a +// coherent event. +func TestSSEScanner_InterleavedCRLFAndLF(t *testing.T) { + raw := "event: mix\r\ndata: first\ndata: second\r\n\n" + events := collectEvents(t, strings.NewReader(raw)) + require.Len(t, events, 1, "mixed line endings must still produce one event") + assert.Equal(t, "mix", events[0].Type) + assert.Equal(t, "first\nsecond", events[0].Data, "both data lines joined") +} + +// TestSSEScanner_LongSingleDataLine constructs a single data line that +// exceeds the default bufio buffer (64 KiB) but stays under the scanner +// maxLine. The scanner must round-trip the value intact without panicking +// or truncating silently. +func TestSSEScanner_LongSingleDataLine(t *testing.T) { + big := strings.Repeat("x", 80<<10) + raw := "data: " + big + "\n\n" + events := collectEvents(t, strings.NewReader(raw)) + require.Len(t, events, 1, "long single-line event must be emitted") + assert.Equal(t, big, events[0].Data, "long data preserved") +} + +// TestSSEScanner_BinaryGarbageInData validates that non-printable bytes +// inside a data line do not crash the parser. The scanner should either +// round-trip them or return a well-formed error — never panic. +func TestSSEScanner_BinaryGarbageInData(t *testing.T) { + raw := "data: \x00\x01\x02\xff\xfe\n\n" + defer func() { + if r := recover(); r != nil { + t.Fatalf("scanner panicked on binary garbage: %v", r) + } + }() + s := NewScanner(strings.NewReader(raw)) + ev, err := s.Next() + require.NoError(t, err, "binary bytes in data should not surface as error") + assert.Equal(t, "\x00\x01\x02\xff\xfe", ev.Data, "binary payload round-trips") +} + +// TestSSEScanner_UnknownFieldsIgnored stresses the field parser by sending +// unrecognized field names ("id:", "retry:", "custom:"). They must be +// silently ignored per the SSE spec; the scanner must not panic or emit +// spurious events. +func TestSSEScanner_UnknownFieldsIgnored(t *testing.T) { + raw := "id: 1\nretry: 5000\ncustom: value\ndata: payload\n\n" + events := collectEvents(t, strings.NewReader(raw)) + require.Len(t, events, 1, "unknown fields must not spawn extra events") + assert.Equal(t, "payload", events[0].Data, "data field survives amid unknown fields") +} diff --git a/proxy/internal/metrics/metrics.go b/proxy/internal/metrics/metrics.go index 41a6b0dd4..5fd23d934 100644 --- a/proxy/internal/metrics/metrics.go +++ b/proxy/internal/metrics/metrics.go @@ -17,6 +17,7 @@ import ( // Metrics collects OpenTelemetry metrics for the proxy. type Metrics struct { ctx context.Context + meter metric.Meter requestsTotal metric.Int64Counter activeRequests metric.Int64UpDownCounter configuredDomains metric.Int64UpDownCounter @@ -49,10 +50,18 @@ type Metrics struct { mappingPaths map[string]int } +// Meter returns the OpenTelemetry meter the bundle was built with, so other +// subsystems (e.g. the middleware manager) register instruments on the same +// meter. +func (m *Metrics) Meter() metric.Meter { + return m.meter +} + // New creates a Metrics instance using the given OpenTelemetry meter. func New(ctx context.Context, meter metric.Meter) (*Metrics, error) { m := &Metrics{ ctx: ctx, + meter: meter, mappingPaths: make(map[string]int), } diff --git a/proxy/internal/middleware/bodypolicy.go b/proxy/internal/middleware/bodypolicy.go new file mode 100644 index 000000000..f31486fc8 --- /dev/null +++ b/proxy/internal/middleware/bodypolicy.go @@ -0,0 +1,63 @@ +package middleware + +import ( + "bytes" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" +) + +// ErrExpectContinue is returned when a middleware attempts to replace +// the body of a request that advertised Expect: 100-continue. +var ErrExpectContinue = errors.New("body replace rejected: request has Expect: 100-continue") + +// ErrOriginalNotDrained is returned when the original body was not +// fully consumed before replacement. This prevents the backend from +// seeing a mix of original bytes and the replacement. +var ErrOriginalNotDrained = errors.New("body replace rejected: original body not drained") + +// ErrContentLengthMismatch is returned when the client-advertised +// Content-Length disagrees with the number of bytes actually read from +// the body (short-read). +var ErrContentLengthMismatch = errors.New("body replace rejected: content-length mismatch (short read)") + +// ValidateBodyReplace runs the smuggling-prevention rules before a +// body replacement is applied. Callers must pass originalDrained=true +// once they have read r.Body to EOF. +func ValidateBodyReplace(r *http.Request, newBody []byte, originalDrained bool) error { + if r == nil { + return errors.New("body replace rejected: nil request") + } + if strings.EqualFold(r.Header.Get("Expect"), "100-continue") { + return ErrExpectContinue + } + if !originalDrained { + return ErrOriginalNotDrained + } + if cl := r.Header.Get("Content-Length"); cl != "" && r.ContentLength > 0 { + parsed, err := strconv.ParseInt(cl, 10, 64) + if err == nil && parsed != r.ContentLength { + return fmt.Errorf("%w: header=%d actual=%d", ErrContentLengthMismatch, parsed, r.ContentLength) + } + } + return nil +} + +// ApplyBodyReplace swaps r.Body for a reader over newBody, recomputes +// Content-Length, and strips Transfer-Encoding and Trailer so no stale +// framing reaches the backend. +func ApplyBodyReplace(r *http.Request, newBody []byte) { + if r == nil { + return + } + r.Body = io.NopCloser(bytes.NewReader(newBody)) + r.ContentLength = int64(len(newBody)) + r.Header.Set("Content-Length", strconv.Itoa(len(newBody))) + r.Header.Del("Transfer-Encoding") + r.Header.Del("Trailer") + r.TransferEncoding = nil + r.Trailer = nil +} diff --git a/proxy/internal/middleware/bodytap/request.go b/proxy/internal/middleware/bodytap/request.go new file mode 100644 index 000000000..826883a96 --- /dev/null +++ b/proxy/internal/middleware/bodytap/request.go @@ -0,0 +1,344 @@ +// Package bodytap owns the framework-side body capture used by the +// middleware chain. Request capture buffers up to N bytes of the +// request body for middleware inspection while replaying the original +// stream to the upstream. Response capture tees up to N bytes off the +// streaming response while every byte continues to flow to the client +// untouched. +// +// The package is the single owner of body access — middlewares never +// read req.Body or hijack the response writer. All inspection happens +// against the buffer surfaced by the tap, so streaming remains +// transparent to the client even when middlewares need access to the +// payload. +package bodytap + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + "strconv" + "strings" + "sync" +) + +// MaxRoutingScanBytes bounds how far ScanRoutingFields will read into a +// request body to recover routing fields when the normal capture is +// bypassed for size. Sized to comfortably hold a 1M-token context +// request (whose `model` field a client may place after a multi-MB +// `messages` array) while still capping pathological inputs. +const MaxRoutingScanBytes int64 = 32 << 20 + +// Request bypass reasons emitted as the `mw.capture.bypass_reason` +// metadata key by the chain when a request body is not surfaced. +const ( + BypassUpgradeHeader = "upgrade_header" + BypassConnectionUpgrd = "connection_upgrade" + BypassContentType = "content_type_not_allowed" + BypassBudget = "capture_budget_exhausted" + BypassNoConfig = "no_capture_config" + BypassNoMiddlewares = "no_middlewares" + BypassCapZero = "cap_zero" + BypassContentLengthCap = "content_length_over_cap" +) + +// DefaultCaptureBudgetBytes is the default global capture-budget size. +const DefaultCaptureBudgetBytes int64 = 256 << 20 + +// Config holds per-target body capture limits after clamp validation. +// A zero MaxRequestBytes / MaxResponseBytes disables capture in that +// direction. +type Config struct { + MaxRequestBytes int64 + MaxResponseBytes int64 + ContentTypes []string +} + +// Budget is the global token-bucket semaphore shared across all +// in-flight captures so a single misbehaving target cannot exhaust the +// proxy. +type Budget interface { + Acquire(n int64) bool + Release(n int64) +} + +// NewBudget returns a Budget with the given total byte cap. A zero or +// negative total disables the budget check. +func NewBudget(total int64) Budget { + return &budget{total: total} +} + +type budget struct { + mu sync.Mutex + used int64 + total int64 +} + +func (b *budget) Acquire(n int64) bool { + if n <= 0 { + return true + } + b.mu.Lock() + defer b.mu.Unlock() + if b.total <= 0 { + return true + } + if b.used+n > b.total { + return false + } + b.used += n + return true +} + +func (b *budget) Release(n int64) { + if n <= 0 { + return + } + b.mu.Lock() + defer b.mu.Unlock() + if b.total <= 0 { + return + } + b.used -= n + if b.used < 0 { + b.used = 0 + } +} + +// CaptureRequest reads up to cfg.MaxRequestBytes from r.Body into a +// buffer suitable for middleware inspection, replacing r.Body with a +// replay reader so the upstream still sees the original bytes. When +// bypass != "" no body is read and r.Body is left untouched. The +// returned release function must be invoked once the request is fully +// processed; it returns the acquired budget tokens to the shared pool. +// release is always non-nil and is safe to defer immediately after the +// call. +func CaptureRequest(r *http.Request, cfg *Config, b Budget) (body []byte, truncated bool, originalSize int64, bypass string, release func(), err error) { + release = func() {} + if r == nil { + return nil, false, 0, BypassNoConfig, release, nil + } + if cfg == nil { + return nil, false, 0, BypassNoConfig, release, nil + } + if cfg.MaxRequestBytes <= 0 { + return nil, false, 0, BypassCapZero, release, nil + } + if r.Header.Get("Upgrade") != "" { + return nil, false, 0, BypassUpgradeHeader, release, nil + } + if strings.EqualFold(r.Header.Get("Connection"), "upgrade") { + return nil, false, 0, BypassConnectionUpgrd, release, nil + } + if !contentTypeAllowed(r.Header.Get("Content-Type"), cfg.ContentTypes) { + return nil, false, 0, BypassContentType, release, nil + } + + originalSize = parseContentLength(r.Header.Get("Content-Length")) + if originalSize > cfg.MaxRequestBytes { + return nil, true, originalSize, BypassContentLengthCap, release, nil + } + + limit := cfg.MaxRequestBytes + if b != nil && !b.Acquire(limit) { + return nil, false, originalSize, BypassBudget, release, nil + } + if b != nil { + var released sync.Once + release = func() { + released.Do(func() { b.Release(limit) }) + } + } + + if r.Body == nil || r.Body == http.NoBody { + release() + release = func() {} + return nil, false, originalSize, "", release, nil + } + + limited := io.LimitReader(r.Body, limit+1) + buf, readErr := io.ReadAll(limited) + if readErr != nil && !errors.Is(readErr, io.EOF) { + release() + release = func() {} + return nil, false, originalSize, "", release, readErr + } + + truncated = int64(len(buf)) > limit + if truncated { + replay := append([]byte(nil), buf...) + viewable := buf[:limit] + r.Body = &replayReadCloser{replay: bytes.NewReader(replay), tail: r.Body} + return viewable, true, originalSize, "", release, nil + } + _ = r.Body.Close() + r.Body = io.NopCloser(bytes.NewReader(buf)) + if originalSize <= 0 { + originalSize = int64(len(buf)) + } + return buf, false, originalSize, "", release, nil +} + +// replayReadCloser replays the captured prefix and then forwards the +// remaining bytes from the original body so the upstream sees the +// full request stream even when capture truncates. +type replayReadCloser struct { + replay *bytes.Reader + tail io.ReadCloser + drained bool +} + +func (r *replayReadCloser) Read(p []byte) (int, error) { + if !r.drained { + n, err := r.replay.Read(p) + if n > 0 { + return n, nil + } + if errors.Is(err, io.EOF) { + r.drained = true + } else if err != nil { + return 0, err + } + } + return r.tail.Read(p) +} + +func (r *replayReadCloser) Close() error { + return r.tail.Close() +} + +// ScanRoutingFields recovers the LLM routing fields ("model" and +// "stream") from a request whose normal capture was bypassed or +// truncated for size. It reads up to maxScan bytes of r.Body to locate +// the top-level keys — clients (e.g. Claude Code) may place `model` +// after a multi-MB `messages` array — then restores r.Body so the +// upstream still receives the full, untouched stream. Only the small +// routing fields are extracted; the prompt is never buffered for +// capture, keeping memory bounded. Returns ok=false when the body isn't +// a JSON object, the model field isn't found within maxScan, or on a +// read error. +func ScanRoutingFields(r *http.Request, maxScan int64) (model string, stream bool, ok bool) { + if r == nil || r.Body == nil || r.Body == http.NoBody || maxScan <= 0 { + return "", false, false + } + limited := io.LimitReader(r.Body, maxScan+1) + buf, readErr := io.ReadAll(limited) + if readErr != nil && !errors.Is(readErr, io.EOF) { + // Mid-stream read error (e.g. client disconnect): restore the bytes + // read so far plus the untouched tail and abort, rather than + // forwarding only the partial prefix as if it were the whole body. + r.Body = &replayReadCloser{replay: bytes.NewReader(append([]byte(nil), buf...)), tail: r.Body} + return "", false, false + } + if int64(len(buf)) > maxScan { + // Body exceeds the scan ceiling: restore the read prefix plus the + // untouched tail so the upstream still gets every byte. + r.Body = &replayReadCloser{replay: bytes.NewReader(append([]byte(nil), buf...)), tail: r.Body} + } else { + _ = r.Body.Close() + r.Body = io.NopCloser(bytes.NewReader(buf)) + } + return scanTopLevelModelStream(buf) +} + +// scanTopLevelModelStream walks the top level of a JSON object via a +// streaming token reader, extracting the "model" string and "stream" +// bool without materialising large values (each non-target value is +// skipped as a RawMessage). Tolerant of truncation: returns whatever was +// found before a malformed/short tail. +func scanTopLevelModelStream(body []byte) (model string, stream bool, ok bool) { + dec := json.NewDecoder(bytes.NewReader(body)) + tok, err := dec.Token() + if err != nil { + return "", false, false + } + if d, isDelim := tok.(json.Delim); !isDelim || d != '{' { + return "", false, false + } + for dec.More() { + keyTok, err := dec.Token() + if err != nil { + return model, stream, ok + } + key, _ := keyTok.(string) + switch key { + case "model": + var v string + if dec.Decode(&v) == nil { + model, ok = v, true + } + case "stream": + var v bool + if dec.Decode(&v) == nil { + stream = v + } + default: + // Skip the value by walking tokens instead of decoding it into + // a json.RawMessage — a multi-MB messages array would otherwise + // be materialised in full just to be discarded. + if err := skipValue(dec); err != nil { + return model, stream, ok + } + } + } + return model, stream, ok +} + +// skipValue consumes one JSON value from dec without materialising it. +// Scalars are a single token; objects/arrays are walked to their matching +// close delimiter so nested structures are skipped in bounded memory. +func skipValue(dec *json.Decoder) error { + tok, err := dec.Token() + if err != nil { + return err + } + d, isDelim := tok.(json.Delim) + if !isDelim || (d != '{' && d != '[') { + return nil + } + depth := 1 + for depth > 0 { + tok, err := dec.Token() + if err != nil { + return err + } + if d, ok := tok.(json.Delim); ok { + switch d { + case '{', '[': + depth++ + case '}', ']': + depth-- + } + } + } + return nil +} + +func contentTypeAllowed(ct string, allowed []string) bool { + if len(allowed) == 0 { + return false + } + media := ct + if idx := strings.Index(ct, ";"); idx >= 0 { + media = ct[:idx] + } + media = strings.TrimSpace(strings.ToLower(media)) + for _, a := range allowed { + if strings.EqualFold(strings.TrimSpace(a), media) { + return true + } + } + return false +} + +func parseContentLength(v string) int64 { + if v == "" { + return 0 + } + parsed, err := strconv.ParseInt(v, 10, 64) + if err != nil || parsed < 0 { + return 0 + } + return parsed +} diff --git a/proxy/internal/middleware/bodytap/response.go b/proxy/internal/middleware/bodytap/response.go new file mode 100644 index 000000000..c23e35b34 --- /dev/null +++ b/proxy/internal/middleware/bodytap/response.go @@ -0,0 +1,189 @@ +package bodytap + +import ( + "bytes" + "net/http" + "sync" + + "github.com/netbirdio/netbird/proxy/internal/responsewriter" +) + +// CapturingResponseWriter wraps an http.ResponseWriter, forwards bytes +// immediately to the client, and tees a bounded copy into an internal +// buffer for middleware inspection. Streaming-aware in the sense that +// every byte the upstream emits flows to the client without queuing +// — the tee just sees a bounded prefix. SSE-aware parsing happens in +// the response middleware against the buffered prefix; this writer +// makes no attempt to demux event boundaries. +// +// Flusher and Hijacker are preserved via responsewriter.PassthroughWriter. +type CapturingResponseWriter struct { + *responsewriter.PassthroughWriter + mu sync.Mutex + buf bytes.Buffer + cap int64 + status int + statusSet bool + written int64 + truncated bool + stopped bool + releaseBuf func() + released sync.Once + bypassed bool + bypassReas string + acquiredCap int64 +} + +// NewCapturingResponseWriter returns a writer that tees up to maxBytes +// into a capped buffer while forwarding bytes to the underlying writer +// immediately. When budget is non-nil the writer pre-acquires maxBytes +// from it and the returned wrapper must be released by calling +// Release() once the response is fully forwarded. If the budget cannot +// be acquired the writer falls back to forwarding the response +// unmodified, exposes Bypassed()=true with reason BypassBudget, and +// releases nothing. +func NewCapturingResponseWriter(w http.ResponseWriter, maxBytes int64, b Budget) *CapturingResponseWriter { + cw := &CapturingResponseWriter{ + PassthroughWriter: responsewriter.New(w), + cap: maxBytes, + status: http.StatusOK, + releaseBuf: func() {}, + } + if maxBytes <= 0 { + // Capture disabled: mark stopped so Write never tees and never + // flags truncation (a zero cap means "don't capture", not + // "captured nothing"). + cw.stopped = true + return cw + } + if b == nil { + return cw + } + if !b.Acquire(maxBytes) { + cw.bypassed = true + cw.bypassReas = BypassBudget + cw.cap = 0 + cw.stopped = true + return cw + } + cw.acquiredCap = maxBytes + cw.releaseBuf = func() { b.Release(maxBytes) } + return cw +} + +// Release returns the response capture budget acquired at construction +// back to the shared pool. Idempotent. Safe to call from a defer +// immediately after construction even when the writer ended up +// bypassing the budget. +func (c *CapturingResponseWriter) Release() { + if c == nil { + return + } + c.released.Do(func() { + if c.releaseBuf != nil { + c.releaseBuf() + } + }) +} + +// Bypassed reports whether the writer fell through to a no-tee +// passthrough because the response capture budget could not be +// acquired. +func (c *CapturingResponseWriter) Bypassed() bool { + if c == nil { + return false + } + c.mu.Lock() + defer c.mu.Unlock() + return c.bypassed +} + +// BypassReason returns the bypass code recorded by the budget check. +// Empty when capture proceeded normally. +func (c *CapturingResponseWriter) BypassReason() string { + if c == nil { + return "" + } + c.mu.Lock() + defer c.mu.Unlock() + return c.bypassReas +} + +// WriteHeader records the status code and forwards it to the underlying +// writer. Only the first call commits the status — matching HTTP semantics, +// where superfluous WriteHeader calls (and any call after the body has +// started) are ignored — so Status() reflects the code actually sent. +func (c *CapturingResponseWriter) WriteHeader(status int) { + c.mu.Lock() + if c.statusSet { + c.mu.Unlock() + return + } + c.status = status + c.statusSet = true + c.mu.Unlock() + c.PassthroughWriter.WriteHeader(status) +} + +// Write forwards p to the underlying writer unmodified and copies up +// to the remaining buffer capacity into the tee buffer. +func (c *CapturingResponseWriter) Write(p []byte) (int, error) { + n, err := c.PassthroughWriter.Write(p) + if n > 0 { + c.mu.Lock() + // The first byte commits the status (implicit 200 if WriteHeader was + // never called); a later WriteHeader must not change Status(). + c.statusSet = true + c.written += int64(n) + if !c.stopped { + remaining := c.cap - int64(c.buf.Len()) + if remaining <= 0 { + c.truncated = true + c.stopped = true + } else { + take := int64(n) + if take > remaining { + take = remaining + c.truncated = true + c.stopped = true + } + c.buf.Write(p[:take]) + } + } + c.mu.Unlock() + } + return n, err +} + +// Status returns the captured status code (defaults to 200 when +// WriteHeader has not been called). +func (c *CapturingResponseWriter) Status() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.status +} + +// Body returns a copy of the buffered response prefix. +func (c *CapturingResponseWriter) Body() []byte { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]byte, c.buf.Len()) + copy(out, c.buf.Bytes()) + return out +} + +// Truncated reports whether the buffered prefix stopped short of the +// full response stream. +func (c *CapturingResponseWriter) Truncated() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.truncated +} + +// BytesWritten returns the total number of bytes forwarded to the +// underlying writer. +func (c *CapturingResponseWriter) BytesWritten() int64 { + c.mu.Lock() + defer c.mu.Unlock() + return c.written +} diff --git a/proxy/internal/middleware/bodytap/routing_scan_test.go b/proxy/internal/middleware/bodytap/routing_scan_test.go new file mode 100644 index 000000000..1748c989e --- /dev/null +++ b/proxy/internal/middleware/bodytap/routing_scan_test.go @@ -0,0 +1,86 @@ +package bodytap + +import ( + "fmt" + "io" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// makeBigAnthropicBody builds a request body shaped like Claude Code's: +// a multi-MB "messages" array with the routing fields (model, stream) +// placed AFTER it, which is the ordering that defeats a prefix-only +// capture. +func makeBigAnthropicBody(t *testing.T, model string, stream bool, messagesBytes int) string { + t.Helper() + filler := strings.Repeat("x", messagesBytes) + return fmt.Sprintf( + `{"max_tokens":64000,"messages":[{"role":"user","content":%q}],"model":%q,"stream":%t}`, + filler, model, stream, + ) +} + +func TestScanRoutingFields_ModelAfterLargeMessages(t *testing.T) { + body := makeBigAnthropicBody(t, "claude-opus-4-8", true, 3<<20) // 3 MiB messages + req := httptest.NewRequest("POST", "https://x/v1/messages", strings.NewReader(body)) + + model, stream, ok := ScanRoutingFields(req, MaxRoutingScanBytes) + require.True(t, ok, "model must be recovered even when it follows a multi-MB messages array") + assert.Equal(t, "claude-opus-4-8", model, "model field must be extracted") + assert.True(t, stream, "stream field must be extracted") + + // Body must be fully restored for the upstream. + got, err := io.ReadAll(req.Body) + require.NoError(t, err) + assert.Equal(t, body, string(got), "the full request body must be replayed to upstream after scanning") +} + +func TestScanRoutingFields_SmallBody(t *testing.T) { + body := `{"model":"claude-opus-4-8","stream":false,"messages":[]}` + req := httptest.NewRequest("POST", "https://x/v1/messages", strings.NewReader(body)) + + model, stream, ok := ScanRoutingFields(req, MaxRoutingScanBytes) + require.True(t, ok) + assert.Equal(t, "claude-opus-4-8", model) + assert.False(t, stream) + + got, _ := io.ReadAll(req.Body) + assert.Equal(t, body, string(got), "small bodies must also be restored intact") +} + +func TestScanRoutingFields_NoModel(t *testing.T) { + body := `{"stream":true,"messages":[]}` + req := httptest.NewRequest("POST", "https://x/v1/messages", strings.NewReader(body)) + + _, _, ok := ScanRoutingFields(req, MaxRoutingScanBytes) + assert.False(t, ok, "ok must be false when no model field is present") + + got, _ := io.ReadAll(req.Body) + assert.Equal(t, body, string(got), "body must be restored even when model is absent") +} + +func TestScanRoutingFields_NotJSON(t *testing.T) { + body := "this is not json at all" + req := httptest.NewRequest("POST", "https://x/v1/messages", strings.NewReader(body)) + + _, _, ok := ScanRoutingFields(req, MaxRoutingScanBytes) + assert.False(t, ok, "ok must be false for a non-JSON body") +} + +func TestScanRoutingFields_ModelBeyondScanCeiling(t *testing.T) { + // model sits after 4 MiB of messages but the scan ceiling is 1 MiB: + // model can't be recovered, yet the full body must still replay. + body := makeBigAnthropicBody(t, "claude-opus-4-8", true, 4<<20) + req := httptest.NewRequest("POST", "https://x/v1/messages", strings.NewReader(body)) + + _, _, ok := ScanRoutingFields(req, 1<<20) + assert.False(t, ok, "model beyond the scan ceiling is not recoverable") + + got, err := io.ReadAll(req.Body) + require.NoError(t, err) + assert.Equal(t, body, string(got), "the full body must still replay to upstream even when the scan gives up") +} diff --git a/proxy/internal/middleware/builtin/agentnetwork_chain_integration_test.go b/proxy/internal/middleware/builtin/agentnetwork_chain_integration_test.go new file mode 100644 index 000000000..96777025c --- /dev/null +++ b/proxy/internal/middleware/builtin/agentnetwork_chain_integration_test.go @@ -0,0 +1,318 @@ +package builtin_test + +import ( + "context" + "net" + "runtime" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + + mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork" + agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/management/server/store" + nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_limit_check" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_limit_record" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// chainIntegrationFixture wires the BOTH new agent-network +// middlewares against a live in-process management stack: real +// sqlite store + real Manager + real gRPC server. The proxy chain +// framework itself isn't constructed (its dispatcher / accumulator / +// metadata gate are tested separately); we exercise the middleware +// pair as the proxy runtime would, by invoking each with a crafted +// Input and asserting the wire path between them. +// +// This is the regression cover for item 16 in the design review: +// real LLM request → cost stamped → consumption row in the table. +type chainIntegrationFixture struct { + store store.Store + manager agentnetwork.Manager + gatecase *llm_limit_check.Middleware + recorder *llm_limit_record.Middleware +} + +func newChainIntegration(t *testing.T) *chainIntegrationFixture { + t.Helper() + if runtime.GOOS == "windows" { + t.Skip("sqlite store not properly supported on Windows yet") + } + t.Setenv("NETBIRD_STORE_ENGINE", string(nbtypes.SqliteStoreEngine)) + + st, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + require.NoError(t, err) + t.Cleanup(cleanUp) + + manager := agentnetwork.NewManager(st, nil, nil, nil) + + server := &mgmtgrpc.ProxyServiceServer{} + server.SetAgentNetworkLimitsService(manager) + + const bufSize = 1024 * 1024 + lis := bufconn.Listen(bufSize) + srv := grpc.NewServer() + proto.RegisterProxyServiceServer(srv, server) + go func() { _ = srv.Serve(lis) }() + t.Cleanup(srv.Stop) + + conn, err := grpc.NewClient("passthrough:///bufnet", + grpc.WithContextDialer(func(_ context.Context, _ string) (net.Conn, error) { return lis.Dial() }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = conn.Close() }) + + mgmtClient := proto.NewProxyServiceClient(conn) + return &chainIntegrationFixture{ + store: st, + manager: manager, + gatecase: llm_limit_check.New(mgmtClient, nil), + recorder: llm_limit_record.New(mgmtClient, nil), + } +} + +// chainInput builds a middleware Input that mirrors what the proxy +// framework would synthesise for a tunnel-peer LLM request. The +// gate consumes the resolved provider id from upstream metadata +// (set by llm_router); the recorder consumes the attribution +// metadata stamped by the gate plus tokens / cost from +// llm_response_parser + cost_meter. +func chainInput(account, user, group, providerID string, requestMeta []middleware.KV) *middleware.Input { + _ = providerID // packed into requestMeta by the caller as KeyLLMResolvedProviderID + return &middleware.Input{ + AccountID: account, + UserID: user, + UserGroups: []string{group}, + Metadata: requestMeta, + } +} + +// chainCapPolicy builds a tight token-cap policy fixture for the +// chain integration tests. Inlined here (rather than imported) because +// the equivalent helper in the management gRPC package is unexported +// and this is a different package boundary. +func chainCapPolicy(id, account string, sourceGroups []string, providerID string, tokenCap, windowSec int64) *agentNetworkTypes.Policy { + return &agentNetworkTypes.Policy{ + ID: id, + AccountID: account, + Enabled: true, + Name: id, + SourceGroups: sourceGroups, + DestinationProviderIDs: []string{providerID}, + Limits: agentNetworkTypes.PolicyLimits{ + TokenLimit: agentNetworkTypes.PolicyTokenLimit{ + Enabled: true, + GroupCap: tokenCap, + WindowSeconds: windowSec, + }, + }, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } +} + +// TestChain_AllowPath_StampsAttributionAndRecordsCounter walks the +// full happy path: gate calls CheckLLMPolicyLimits → stamps +// attribution metadata → recorder reads metadata + tokens / cost → +// calls RecordLLMUsage → counters land in sqlite. Asserting on the +// store at the end proves every leg of the wire works together, +// not just each leg in isolation (which the unit tests already cover). +func TestChain_AllowPath_StampsAttributionAndRecordsCounter(t *testing.T) { + f := newChainIntegration(t) + + const account = "acc-1" + const user = "user-bob" + const group = "grp-engineers" + const provider = "prov-1" + + // Seed a policy with token + budget caps; both halves carry + // real ceilings so the request stays within headroom. + require.NoError(t, f.store.SaveAgentNetworkPolicy(context.Background(), + chainCapPolicy("pol-1", account, []string{group}, provider, 10_000, 86_400))) + + // ── Stage 1 — gate: pre-flight check ────────────────────── + gateIn := chainInput(account, user, group, provider, []middleware.KV{ + {Key: middleware.KeyLLMResolvedProviderID, Value: provider}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + }) + gateOut, err := f.gatecase.Invoke(context.Background(), gateIn) + require.NoError(t, err) + require.Equal(t, middleware.DecisionAllow, gateOut.Decision, "fresh policy must allow") + + // Verify attribution metadata was stamped — the recorder + // depends on these keys. + metaMap := map[string]string{} + for _, kv := range gateOut.Metadata { + metaMap[kv.Key] = kv.Value + } + assert.Equal(t, "pol-1", metaMap[middleware.KeyLLMSelectedPolicyID]) + assert.Equal(t, group, metaMap[middleware.KeyLLMAttributionGroupID]) + assert.Equal(t, "86400", metaMap[middleware.KeyLLMAttributionWindowS]) + + // ── Stage 2 — recorder: post-flight write ───────────────── + // Build the response-leg Input the framework would synthesise + // for the recorder: gate's emitted attribution metadata + the + // tokens / cost stamped by llm_response_parser + cost_meter. + const tokensIn = int64(123) + const tokensOut = int64(45) + const costUSD = 0.0042 + recordIn := chainInput(account, user, group, provider, append([]middleware.KV{}, + gateOut.Metadata...)) + recordIn.Metadata = append(recordIn.Metadata, + middleware.KV{Key: middleware.KeyLLMInputTokens, Value: strconv.FormatInt(tokensIn, 10)}, + middleware.KV{Key: middleware.KeyLLMOutputTokens, Value: strconv.FormatInt(tokensOut, 10)}, + middleware.KV{Key: middleware.KeyCostUSDTotal, Value: strconv.FormatFloat(costUSD, 'f', 6, 64)}, + ) + recordOut, err := f.recorder.Invoke(context.Background(), recordIn) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionAllow, recordOut.Decision, "recorder always allows; its only side effect is the counter write") + + // ── Stage 3 — assert state in sqlite ────────────────────── + windowStart := agentNetworkTypes.WindowStart(time.Now(), 86_400) + userRow, err := f.store.GetAgentNetworkConsumption( + context.Background(), store.LockingStrengthNone, account, + agentNetworkTypes.DimensionUser, user, int64(86_400), windowStart, + ) + require.NoError(t, err) + assert.Equal(t, tokensIn, userRow.TokensInput, "user counter must hold the input tokens the recorder posted") + assert.Equal(t, tokensOut, userRow.TokensOutput) + assert.InDelta(t, costUSD, userRow.CostUSD, 1e-6) + + groupRow, err := f.store.GetAgentNetworkConsumption( + context.Background(), store.LockingStrengthNone, account, + agentNetworkTypes.DimensionGroup, group, int64(86_400), windowStart, + ) + require.NoError(t, err) + assert.Equal(t, tokensIn, groupRow.TokensInput, "group counter mirrors the user counter — single Record posts both dims") +} + +// TestChain_DenyPath_GateRejectsAndNoConsumptionWritten covers the +// negative side: when the gate denies, the recorder is never +// invoked (the proxy framework short-circuits on Decision=Deny). +// We assert no consumption row materialises after the gate-deny +// path, even though the test technically calls the recorder +// afterwards — the recorder must skip on missing attribution +// metadata so the framework's short-circuit isn't load-bearing for +// data integrity. +func TestChain_DenyPath_GateRejectsAndNoConsumptionWritten(t *testing.T) { + f := newChainIntegration(t) + + const account = "acc-1" + const user = "user-bob" + const group = "grp-tight" + const provider = "prov-1" + + policy := chainCapPolicy("pol-tight", account, []string{group}, provider, 100, 86_400) + require.NoError(t, f.store.SaveAgentNetworkPolicy(context.Background(), policy)) + + // Pre-burn the counter to the cap so the gate denies. + require.NoError(t, f.store.IncrementAgentNetworkConsumption( + context.Background(), account, + agentNetworkTypes.DimensionGroup, group, int64(86_400), + agentNetworkTypes.WindowStart(time.Now(), 86_400), + 100, 0, 0, + )) + + gateOut, err := f.gatecase.Invoke(context.Background(), chainInput(account, user, group, provider, + []middleware.KV{ + {Key: middleware.KeyLLMResolvedProviderID, Value: provider}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + }, + )) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionDeny, gateOut.Decision, "policy at-cap must deny on the gate") + require.NotNil(t, gateOut.DenyReason) + assert.Equal(t, "llm_policy.token_cap_exceeded", gateOut.DenyReason.Code) + + // On deny, the gate emits no attribution metadata. If the + // proxy framework still invokes the recorder (defense in + // depth), the recorder's "no attribution window = skip" guard + // prevents a phantom counter increment. + recordOut, err := f.recorder.Invoke(context.Background(), chainInput(account, user, group, provider, + gateOut.Metadata, // no llm.attribution_window_seconds stamped + )) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionAllow, recordOut.Decision) + + // The pre-burned 100 tokens are the only counter movement — + // the recorder must NOT have added a fresh row for the user + // dimension on this denied request. + windowStart := agentNetworkTypes.WindowStart(time.Now(), 86_400) + userRow, err := f.store.GetAgentNetworkConsumption( + context.Background(), store.LockingStrengthNone, account, + agentNetworkTypes.DimensionUser, user, int64(86_400), windowStart, + ) + require.NoError(t, err) + assert.Zero(t, userRow.TokensInput, "user dimension must not gain tokens from a denied request — recorder skip is the safety net") +} + +// TestChain_CapExhaustTransition exercises the allow→deny boundary +// the operator cares most about: a request just under cap allows +// AND records, the next request post-record at-cap denies. This is +// the same lifecycle 50-grpc-allow-record-deny.sh runs in bash, but +// against the actual middleware pair rather than the smoke binary +// driving the gRPC RPCs directly. +func TestChain_CapExhaustTransition(t *testing.T) { + f := newChainIntegration(t) + + const account = "acc-1" + const user = "user-alice" + const group = "grp-cap-edge" + const provider = "prov-1" + const tightCap = int64(100) + + require.NoError(t, f.store.SaveAgentNetworkPolicy(context.Background(), + chainCapPolicy("pol-edge", account, []string{group}, provider, tightCap, 86_400))) + + // Pre-burn 99 tokens so we're at the very edge. + require.NoError(t, f.store.IncrementAgentNetworkConsumption( + context.Background(), account, + agentNetworkTypes.DimensionGroup, group, int64(86_400), + agentNetworkTypes.WindowStart(time.Now(), 86_400), + 99, 0, 0, + )) + + // Gate at 99/100 — must allow (one token of headroom). + gateOut, err := f.gatecase.Invoke(context.Background(), chainInput(account, user, group, provider, + []middleware.KV{ + {Key: middleware.KeyLLMResolvedProviderID, Value: provider}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + }, + )) + require.NoError(t, err) + require.Equal(t, middleware.DecisionAllow, gateOut.Decision, "99/100 must allow — one token of headroom") + + // Record one more input token — pushes us to 100/100. + recordIn := chainInput(account, user, group, provider, append([]middleware.KV{}, + gateOut.Metadata...)) + recordIn.Metadata = append(recordIn.Metadata, + middleware.KV{Key: middleware.KeyLLMInputTokens, Value: "1"}, + middleware.KV{Key: middleware.KeyLLMOutputTokens, Value: "0"}, + middleware.KV{Key: middleware.KeyCostUSDTotal, Value: "0.000001"}, + ) + _, err = f.recorder.Invoke(context.Background(), recordIn) + require.NoError(t, err) + + // Next gate call must deny — counter is exactly at cap. + gateOut2, err := f.gatecase.Invoke(context.Background(), chainInput(account, user, group, provider, + []middleware.KV{ + {Key: middleware.KeyLLMResolvedProviderID, Value: provider}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + }, + )) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionDeny, gateOut2.Decision, + "once recorder pushed the group counter to 100/100, the next gate call must deny — allow→deny transition is the operator-visible product semantic") + require.NotNil(t, gateOut2.DenyReason) + assert.Equal(t, "llm_policy.token_cap_exceeded", gateOut2.DenyReason.Code) +} diff --git a/proxy/internal/middleware/builtin/all_test.go b/proxy/internal/middleware/builtin/all_test.go new file mode 100644 index 000000000..28576e248 --- /dev/null +++ b/proxy/internal/middleware/builtin/all_test.go @@ -0,0 +1,40 @@ +package builtin_test + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/assert" + + mwbuiltin "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" + + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/cost_meter" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_guardrail" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_identity_inject" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_limit_check" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_limit_record" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_request_parser" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_response_parser" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_router" +) + +// TestDefaultRegistry_BuiltinIDs locks the set of middleware IDs that +// the default builtin registry exposes once every sub-package's init() +// has run. The list is the source of truth wired by the synthesiser +// in management; adding a new built-in middleware should consciously +// extend this list. +func TestDefaultRegistry_BuiltinIDs(t *testing.T) { + got := mwbuiltin.DefaultRegistry().IDs() + sort.Strings(got) + want := []string{ + "cost_meter", + "llm_guardrail", + "llm_identity_inject", + "llm_limit_check", + "llm_limit_record", + "llm_request_parser", + "llm_response_parser", + "llm_router", + } + assert.Equal(t, want, got, "default registry must expose every built-in middleware after anonymous imports") +} diff --git a/proxy/internal/middleware/builtin/builtin.go b/proxy/internal/middleware/builtin/builtin.go new file mode 100644 index 000000000..9ea4cf89d --- /dev/null +++ b/proxy/internal/middleware/builtin/builtin.go @@ -0,0 +1,93 @@ +// Package builtin holds the package-level middleware registry that +// concrete middleware packages register themselves into via init(). +// Server boot anonymous-imports each middleware sub-package; the +// resolver attached to the middleware Manager pulls factories out of +// this registry. +package builtin + +import ( + "context" + "sync" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// MgmtClient is the narrow slice of proto.ProxyServiceClient that +// builtin middlewares may use during request / response handling. +// Only the agent-network limit pair (llm_limit_check + llm_limit_record) +// uses this today; declaring the surface here keeps the dependency +// explicit at boot time. +// +// proto.ProxyServiceClient already satisfies this interface so server +// boot just forwards its existing client. +type MgmtClient interface { + CheckLLMPolicyLimits(ctx context.Context, in *proto.CheckLLMPolicyLimitsRequest, opts ...grpc.CallOption) (*proto.CheckLLMPolicyLimitsResponse, error) + RecordLLMUsage(ctx context.Context, in *proto.RecordLLMUsageRequest, opts ...grpc.CallOption) (*proto.RecordLLMUsageResponse, error) +} + +// defaultRegistry is the package-level registry that concrete builtin +// middlewares register themselves into via init(). +var defaultRegistry = middleware.NewRegistry() + +// FactoryContext is the per-process bag that concrete factories may +// consult during construction. It carries the proxy-lifetime context, +// the data directory used for static config files (pricing tables, +// allowlists), the OTel meter, and the proxy logger. +// +// Configure must be called once at boot before any chain build calls +// Resolve. Calling it twice overwrites the prior value; tests may rely +// on this to reset state. +type FactoryContext struct { + Context context.Context + DataDir string + Meter metric.Meter + Logger *log.Logger + MgmtClient MgmtClient +} + +var ( + ctxStore FactoryContext + ctxMu sync.RWMutex +) + +// Configure stores the per-process FactoryContext. Concrete factories +// reach for it via Context(). mgmt may be nil on tests / standalone +// builds with no management server; consumers must guard. +func Configure(ctx context.Context, dataDir string, meter metric.Meter, logger *log.Logger, mgmt MgmtClient) { + ctxMu.Lock() + defer ctxMu.Unlock() + ctxStore = FactoryContext{ + Context: ctx, + DataDir: dataDir, + Meter: meter, + Logger: logger, + MgmtClient: mgmt, + } +} + +// Context returns the stored FactoryContext. Returns a zero value when +// Configure was never called; consumers must guard against nil +// Context/Meter/Logger if they care. +func Context() FactoryContext { + ctxMu.RLock() + defer ctxMu.RUnlock() + return ctxStore +} + +// Register adds a factory to the default registry. Called from init() +// blocks of concrete middleware packages. Panics on collision so +// duplicate IDs surface at startup. +func Register(f middleware.Factory) { + defaultRegistry.MustRegister(f) +} + +// DefaultRegistry returns the shared registry. The proxy server +// constructs the Resolver from it at boot. +func DefaultRegistry() *middleware.Registry { + return defaultRegistry +} diff --git a/proxy/internal/middleware/builtin/cost_meter/factory.go b/proxy/internal/middleware/builtin/cost_meter/factory.go new file mode 100644 index 000000000..b8a58d10e --- /dev/null +++ b/proxy/internal/middleware/builtin/cost_meter/factory.go @@ -0,0 +1,88 @@ +package cost_meter + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + + "github.com/netbirdio/netbird/proxy/internal/llm/pricing" + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" +) + +// defaultPricingFilename is the basename probed inside the proxy data +// directory when no override is configured. +const defaultPricingFilename = "pricing.yaml" + +// Config is the on-wire configuration for the middleware. +type Config struct { + // PricingPath optionally overrides the basename of the pricing + // file probed inside the proxy data directory. When empty the + // loader falls back to "pricing.yaml". + PricingPath string `json:"pricing_path"` +} + +// Factory builds cost_meter instances from raw config bytes. +type Factory struct{} + +// ID returns the registry identifier. +func (Factory) ID() string { return ID } + +// New constructs a middleware instance. Empty, null, and {} configs +// are accepted; non-empty rawConfig that fails to unmarshal is +// rejected so misconfigurations surface at chain build time. The +// pricing loader is built once per instance and reused across +// invocations. +func (Factory) New(rawConfig []byte) (middleware.Middleware, error) { + cfg, err := decodeConfig(rawConfig) + if err != nil { + return nil, err + } + + fctx := builtin.Context() + pricingPath := cfg.PricingPath + if pricingPath == "" { + pricingPath = defaultPricingFilename + } + + loader, err := pricing.NewLoader(fctx.DataDir, pricingPath, ID, nil) + if err != nil { + return nil, fmt.Errorf("init pricing loader: %w", err) + } + + cancel := startReloader(fctx.Context, loader) + + return newMiddleware(loader, cancel), nil +} + +// startReloader binds the loader's mtime-poll goroutine to a context +// derived from the proxy-lifetime context and returns its cancel func so +// the owning middleware can stop the goroutine on teardown. Returns nil +// when there's nothing to watch (nil context or defaults-only loader), in +// which case the middleware's Close is a no-op. +func startReloader(ctx context.Context, loader *pricing.Loader) context.CancelFunc { + if ctx == nil || !loader.WatchesFile() { + return nil + } + cctx, cancel := context.WithCancel(ctx) + go loader.Reload(cctx) + return cancel +} + +// decodeConfig accepts empty, null, and {} configs, returning a +// zero-value Config. Non-empty payloads must parse cleanly. +func decodeConfig(rawConfig []byte) (Config, error) { + var cfg Config + if len(bytes.TrimSpace(rawConfig)) == 0 { + return cfg, nil + } + if err := json.Unmarshal(rawConfig, &cfg); err != nil { + return cfg, fmt.Errorf("decode config: %w", err) + } + return cfg, nil +} + +func init() { + builtin.Register(Factory{}) +} diff --git a/proxy/internal/middleware/builtin/cost_meter/middleware.go b/proxy/internal/middleware/builtin/cost_meter/middleware.go new file mode 100644 index 000000000..4da620310 --- /dev/null +++ b/proxy/internal/middleware/builtin/cost_meter/middleware.go @@ -0,0 +1,193 @@ +// Package cost_meter implements the SlotOnResponse middleware that +// converts token-usage metadata emitted by llm_response_parser into a +// per-request USD cost estimate. The middleware uses the shared pricing +// loader so operator pricing overrides apply to the chain. +package cost_meter + +import ( + "context" + "fmt" + "strconv" + + "github.com/netbirdio/netbird/proxy/internal/llm/pricing" + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +// ID is the registry identifier for this middleware. +const ID = "cost_meter" + +// Version is the implementation version emitted via the spec merge. +const Version = "1.0.0" + +// Skip reasons emitted under KeyCostSkipped. The set is closed; the +// dashboard surfaces these verbatim. +const ( + skipMissingProvider = "missing_provider" + skipMissingModel = "missing_model" + skipMissingTokens = "missing_tokens" + //nolint:gosec // skip-reason label, not a credential + skipUnparseableTokens = "unparseable_tokens" + skipZeroTokens = "zero_tokens" + skipUnknownModel = "unknown_model" +) + +var metadataKeys = []string{ + middleware.KeyCostUSDTotal, + middleware.KeyCostSkipped, +} + +// Middleware computes a per-response cost estimate from the token +// counts emitted upstream by llm_response_parser. +type Middleware struct { + loader *pricing.Loader + // cancel stops this instance's pricing-reload goroutine. Non-nil only + // when the loader watches an override file; Close calls it so a chain + // rebuild doesn't leak a poll goroutine per retired instance. + cancel context.CancelFunc +} + +// newMiddleware constructs a Middleware bound to the given pricing loader. +// cancel may be nil (defaults-only loader with no reloader to stop). +func newMiddleware(loader *pricing.Loader, cancel context.CancelFunc) *Middleware { + return &Middleware{loader: loader, cancel: cancel} +} + +// ID returns the registry identifier. +func (m *Middleware) ID() string { return ID } + +// Version returns the implementation version. +func (m *Middleware) Version() string { return Version } + +// Slot reports that the middleware runs after the upstream call. +func (m *Middleware) Slot() middleware.Slot { return middleware.SlotOnResponse } + +// AcceptedContentTypes is empty: cost_meter never inspects bodies. +func (m *Middleware) AcceptedContentTypes() []string { return []string{} } + +// MetadataKeys returns the closed allowlist of keys this middleware +// may emit. +func (m *Middleware) MetadataKeys() []string { + return append([]string(nil), metadataKeys...) +} + +// MutationsSupported reports that this middleware never mutates the +// response. +func (m *Middleware) MutationsSupported() bool { return false } + +// Close stops this instance's pricing-reload goroutine, if any. Called by +// the chain when a rebuild retires the instance, so the mtime-poll loop +// doesn't outlive the chain it belonged to. Safe to call on a nil receiver +// and on an instance with no reloader. +func (m *Middleware) Close() error { + if m != nil && m.cancel != nil { + m.cancel() + } + return nil +} + +// Invoke reads provider, model, and token metadata, looks up pricing, +// and emits either KeyCostUSDTotal or KeyCostSkipped. The decision is +// always DecisionAllow; cost metering never denies or mutates. +func (m *Middleware) Invoke(_ context.Context, in *middleware.Input) (*middleware.Output, error) { + out := &middleware.Output{Decision: middleware.DecisionAllow} + if in == nil { + return out, nil + } + + provider := lookupKV(in.Metadata, middleware.KeyLLMProvider) + if provider == "" { + out.Metadata = skip(skipMissingProvider) + return out, nil + } + + model := lookupKV(in.Metadata, middleware.KeyLLMModel) + if model == "" { + out.Metadata = skip(skipMissingModel) + return out, nil + } + + inRaw, hasIn := lookupKVOK(in.Metadata, middleware.KeyLLMInputTokens) + outRaw, hasOut := lookupKVOK(in.Metadata, middleware.KeyLLMOutputTokens) + if !hasIn || !hasOut { + out.Metadata = skip(skipMissingTokens) + return out, nil + } + + inTokens, err := strconv.ParseInt(inRaw, 10, 64) + if err != nil || inTokens < 0 { + // Unparseable or negative tokens are not a runtime error: the + // upstream llm_response_parser emitted a non-numeric / invalid + // value, so we surface that as cost.skipped and continue with + // Allow rather than pricing a negative count. + out.Metadata = skip(skipUnparseableTokens) + return out, nil //nolint:nilerr // structured skip; not a runtime error + } + outTokens, err := strconv.ParseInt(outRaw, 10, 64) + if err != nil || outTokens < 0 { + out.Metadata = skip(skipUnparseableTokens) + return out, nil //nolint:nilerr // structured skip; not a runtime error + } + + // Cache buckets are optional and silently zeroed on a missing / + // malformed value; they're a refinement on top of input cost, + // not a precondition. A buggy value falls back to 0, never aborts. + cachedTokens := parseOptionalInt64(in.Metadata, middleware.KeyLLMCachedInputTokens) + cacheCreationTokens := parseOptionalInt64(in.Metadata, middleware.KeyLLMCacheCreationTokens) + + if inTokens == 0 && outTokens == 0 && cachedTokens == 0 && cacheCreationTokens == 0 { + out.Metadata = skip(skipZeroTokens) + return out, nil + } + + table := m.loader.Get() + cost, ok := table.Cost(provider, model, inTokens, outTokens, cachedTokens, cacheCreationTokens) + if !ok { + out.Metadata = skip(skipUnknownModel) + return out, nil + } + + out.Metadata = []middleware.KV{ + {Key: middleware.KeyCostUSDTotal, Value: fmt.Sprintf("%.6f", cost)}, + } + return out, nil +} + +// skip returns a single-entry metadata slice carrying the given skip +// reason under KeyCostSkipped. +func skip(reason string) []middleware.KV { + return []middleware.KV{{Key: middleware.KeyCostSkipped, Value: reason}} +} + +// lookupKV returns the value associated with key, or the empty string +// when the key is absent. +func lookupKV(kvs []middleware.KV, key string) string { + v, _ := lookupKVOK(kvs, key) + return v +} + +// lookupKVOK returns the value associated with key plus a presence +// flag so callers can distinguish absent from empty. +func lookupKVOK(kvs []middleware.KV, key string) (string, bool) { + for _, kv := range kvs { + if kv.Key == key { + return kv.Value, true + } + } + return "", false +} + +// parseOptionalInt64 reads a metadata value and decodes it as int64. +// Absent or unparseable values yield 0 — the caller treats absence as +// "no cached tokens" rather than an error, since cache buckets are a +// refinement, not a precondition. +func parseOptionalInt64(kvs []middleware.KV, key string) int64 { + raw, ok := lookupKVOK(kvs, key) + if !ok { + return 0 + } + v, err := strconv.ParseInt(raw, 10, 64) + if err != nil || v < 0 { + return 0 + } + return v +} diff --git a/proxy/internal/middleware/builtin/cost_meter/middleware_test.go b/proxy/internal/middleware/builtin/cost_meter/middleware_test.go new file mode 100644 index 000000000..d1c161cab --- /dev/null +++ b/proxy/internal/middleware/builtin/cost_meter/middleware_test.go @@ -0,0 +1,459 @@ +package cost_meter + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" +) + +const fixturePricing = `openai: + gpt-4o: + input_per_1k: 0.0025 + output_per_1k: 0.01 + gpt-4o-mini: + input_per_1k: 0.00015 + output_per_1k: 0.0006 +anthropic: + claude-sonnet-4-5: + input_per_1k: 0.003 + output_per_1k: 0.015 +` + +// configureBuiltin points the package-level FactoryContext at a tmp +// directory containing the test pricing fixture. Returns the path so +// callers can override files later if needed. +func configureBuiltin(t *testing.T) string { + t.Helper() + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "pricing.yaml"), []byte(fixturePricing), 0o600), "write pricing fixture") + builtin.Configure(context.Background(), dir, nil, nil, nil) + return dir +} + +func metaValue(t *testing.T, kvs []middleware.KV, key string) (string, bool) { + t.Helper() + for _, kv := range kvs { + if kv.Key == key { + return kv.Value, true + } + } + return "", false +} + +func buildMiddleware(t *testing.T, raw []byte) middleware.Middleware { + t.Helper() + mw, err := Factory{}.New(raw) + require.NoError(t, err, "factory must accept the supplied config") + return mw +} + +func TestMiddleware_StaticSurface(t *testing.T) { + configureBuiltin(t) + mw := buildMiddleware(t, nil) + + assert.Equal(t, ID, mw.ID(), "ID must match the registered constant") + assert.Equal(t, Version, mw.Version(), "Version must match the constant") + assert.Equal(t, middleware.SlotOnResponse, mw.Slot(), "must run in the response slot") + assert.Empty(t, mw.AcceptedContentTypes(), "cost_meter does not inspect bodies") + assert.False(t, mw.MutationsSupported(), "cost_meter never mutates") + assert.NoError(t, mw.Close(), "Close on stateless middleware is a no-op") + + keys := mw.MetadataKeys() + expected := []string{middleware.KeyCostUSDTotal, middleware.KeyCostSkipped} + assert.Equal(t, expected, keys, "metadata key allowlist must match the spec") +} + +func TestFactory_AcceptsEmptyAndJSONConfig(t *testing.T) { + configureBuiltin(t) + cases := [][]byte{nil, {}, []byte("null"), []byte("{}"), []byte(" ")} + for _, raw := range cases { + mw, err := Factory{}.New(raw) + require.NoError(t, err, "empty/null/object config must be accepted") + require.NotNil(t, mw, "factory must return a middleware instance") + } +} + +func TestFactory_RejectsMalformedConfig(t *testing.T) { + configureBuiltin(t) + mw, err := Factory{}.New([]byte("{not json")) + require.Error(t, err, "malformed config must surface at construction") + assert.Nil(t, mw, "no instance is returned on error") +} + +func TestFactory_DefaultPricingPathLoadsFixture(t *testing.T) { + configureBuiltin(t) + mw := buildMiddleware(t, nil) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o-mini"}, + {Key: middleware.KeyLLMInputTokens, Value: "1000"}, + {Key: middleware.KeyLLMOutputTokens, Value: "1000"}, + }, + }) + require.NoError(t, err) + require.Equal(t, middleware.DecisionAllow, out.Decision, "cost_meter always allows") + + value, ok := metaValue(t, out.Metadata, middleware.KeyCostUSDTotal) + require.True(t, ok, "cost.usd_total must be emitted for known model") + assert.Equal(t, "0.000750", value, "0.00015 + 0.0006 per 1k tokens, 6-decimal format") +} + +func TestFactory_PricingPathOverride(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "custom.yaml"), []byte(fixturePricing), 0o600), "write custom pricing") + builtin.Configure(context.Background(), dir, nil, nil, nil) + + raw, err := json.Marshal(Config{PricingPath: "custom.yaml"}) + require.NoError(t, err) + + mw := buildMiddleware(t, raw) + out, err := mw.Invoke(context.Background(), &middleware.Input{ + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + {Key: middleware.KeyLLMInputTokens, Value: "2000"}, + {Key: middleware.KeyLLMOutputTokens, Value: "1000"}, + }, + }) + require.NoError(t, err) + + value, ok := metaValue(t, out.Metadata, middleware.KeyCostUSDTotal) + require.True(t, ok, "cost.usd_total must be emitted with custom pricing path") + assert.Equal(t, "0.015000", value, "2*0.0025 + 1*0.01 = 0.015 with 6-decimal format") +} + +func TestInvoke_ComputesCostForKnownModel(t *testing.T) { + configureBuiltin(t) + mw := buildMiddleware(t, nil) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "anthropic"}, + {Key: middleware.KeyLLMModel, Value: "claude-sonnet-4-5"}, + {Key: middleware.KeyLLMInputTokens, Value: "1000"}, + {Key: middleware.KeyLLMOutputTokens, Value: "1000"}, + }, + }) + require.NoError(t, err) + + value, ok := metaValue(t, out.Metadata, middleware.KeyCostUSDTotal) + require.True(t, ok, "cost.usd_total must be emitted") + assert.Equal(t, "0.018000", value, "0.003 + 0.015 = 0.018 with 6-decimal format") + _, skipped := metaValue(t, out.Metadata, middleware.KeyCostSkipped) + assert.False(t, skipped, "cost.skipped must not be set when cost is computed") +} + +func TestInvoke_MissingProvider(t *testing.T) { + configureBuiltin(t) + mw := buildMiddleware(t, nil) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + {Key: middleware.KeyLLMInputTokens, Value: "10"}, + {Key: middleware.KeyLLMOutputTokens, Value: "10"}, + }, + }) + require.NoError(t, err) + value, ok := metaValue(t, out.Metadata, middleware.KeyCostSkipped) + require.True(t, ok, "cost.skipped must be set when provider is missing") + assert.Equal(t, skipMissingProvider, value, "skip reason matches missing_provider") +} + +func TestInvoke_MissingModel(t *testing.T) { + configureBuiltin(t) + mw := buildMiddleware(t, nil) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMInputTokens, Value: "10"}, + {Key: middleware.KeyLLMOutputTokens, Value: "10"}, + }, + }) + require.NoError(t, err) + value, ok := metaValue(t, out.Metadata, middleware.KeyCostSkipped) + require.True(t, ok, "cost.skipped must be set when model is missing") + assert.Equal(t, skipMissingModel, value, "skip reason matches missing_model") +} + +func TestInvoke_MissingTokens(t *testing.T) { + configureBuiltin(t) + mw := buildMiddleware(t, nil) + + cases := []struct { + name string + md []middleware.KV + }{ + { + name: "input only", + md: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + {Key: middleware.KeyLLMInputTokens, Value: "10"}, + }, + }, + { + name: "output only", + md: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + {Key: middleware.KeyLLMOutputTokens, Value: "10"}, + }, + }, + { + name: "neither", + md: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + out, err := mw.Invoke(context.Background(), &middleware.Input{Metadata: tc.md}) + require.NoError(t, err) + value, ok := metaValue(t, out.Metadata, middleware.KeyCostSkipped) + require.True(t, ok, "cost.skipped must be set when token keys are missing") + assert.Equal(t, skipMissingTokens, value, "skip reason matches missing_tokens") + }) + } +} + +func TestInvoke_UnparseableTokens(t *testing.T) { + configureBuiltin(t) + mw := buildMiddleware(t, nil) + + cases := []struct { + name string + in string + out string + }{ + {name: "input non-numeric", in: "abc", out: "10"}, + {name: "output non-numeric", in: "10", out: "xyz"}, + {name: "both garbage", in: "??", out: "??"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + out, err := mw.Invoke(context.Background(), &middleware.Input{ + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + {Key: middleware.KeyLLMInputTokens, Value: tc.in}, + {Key: middleware.KeyLLMOutputTokens, Value: tc.out}, + }, + }) + require.NoError(t, err) + value, ok := metaValue(t, out.Metadata, middleware.KeyCostSkipped) + require.True(t, ok, "cost.skipped must be set on unparseable tokens") + assert.Equal(t, skipUnparseableTokens, value, "skip reason matches unparseable_tokens") + }) + } +} + +func TestInvoke_ZeroTokens(t *testing.T) { + configureBuiltin(t) + mw := buildMiddleware(t, nil) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + {Key: middleware.KeyLLMInputTokens, Value: "0"}, + {Key: middleware.KeyLLMOutputTokens, Value: "0"}, + }, + }) + require.NoError(t, err) + value, ok := metaValue(t, out.Metadata, middleware.KeyCostSkipped) + require.True(t, ok, "cost.skipped must be set when both token counts are zero") + assert.Equal(t, skipZeroTokens, value, "skip reason matches zero_tokens") + _, hasCost := metaValue(t, out.Metadata, middleware.KeyCostUSDTotal) + assert.False(t, hasCost, "cost.usd_total must not be emitted for zero tokens") +} + +func TestInvoke_UnknownModel(t *testing.T) { + configureBuiltin(t) + mw := buildMiddleware(t, nil) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "fantasy-model-9000"}, + {Key: middleware.KeyLLMInputTokens, Value: "10"}, + {Key: middleware.KeyLLMOutputTokens, Value: "10"}, + }, + }) + require.NoError(t, err) + value, ok := metaValue(t, out.Metadata, middleware.KeyCostSkipped) + require.True(t, ok, "cost.skipped must be set when pricing entry is absent") + assert.Equal(t, skipUnknownModel, value, "skip reason matches unknown_model") +} + +func TestInvoke_NilInput(t *testing.T) { + configureBuiltin(t) + mw := buildMiddleware(t, nil) + + out, err := mw.Invoke(context.Background(), nil) + require.NoError(t, err) + require.NotNil(t, out, "output must be returned even on nil input") + assert.Equal(t, middleware.DecisionAllow, out.Decision, "decision must be allow on nil input") + assert.Empty(t, out.Metadata, "no metadata must be emitted on nil input") +} + +const fixturePricingWithCache = `openai: + gpt-4o: + input_per_1k: 0.0025 + output_per_1k: 0.01 + cached_input_per_1k: 0.00125 +anthropic: + claude-sonnet-4-5: + input_per_1k: 0.003 + output_per_1k: 0.015 + cache_read_per_1k: 0.0003 + cache_creation_per_1k: 0.00375 +` + +// configureBuiltinWithCacheRates points the package-level +// FactoryContext at a tmp directory containing pricing entries that +// include the cache rate fields. +func configureBuiltinWithCacheRates(t *testing.T) { + t.Helper() + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "pricing.yaml"), []byte(fixturePricingWithCache), 0o600), "write cache-aware pricing fixture") + builtin.Configure(context.Background(), dir, nil, nil, nil) +} + +// TestInvoke_OpenAICachedSubsetDiscount proves the OpenAI shape end +// to end through the middleware: cached_input_tokens is treated as a +// SUBSET of input_tokens and discounted at the configured rate, not +// added on top. +func TestInvoke_OpenAICachedSubsetDiscount(t *testing.T) { + configureBuiltinWithCacheRates(t) + mw := buildMiddleware(t, nil) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + {Key: middleware.KeyLLMInputTokens, Value: "1000"}, + {Key: middleware.KeyLLMOutputTokens, Value: "500"}, + {Key: middleware.KeyLLMCachedInputTokens, Value: "750"}, + }, + }) + require.NoError(t, err) + require.Equal(t, middleware.DecisionAllow, out.Decision) + + value, ok := metaValue(t, out.Metadata, middleware.KeyCostUSDTotal) + require.True(t, ok, "cached subset path must produce a cost — never a skip") + // 250 non-cached at 0.0025/1k + 750 cached at 0.00125/1k + 500 output at 0.01/1k. + assert.Equal(t, "0.006563", value, + "cached subset must be billed at the discount rate, non-cached at the full rate; never double-billed") +} + +// TestInvoke_AnthropicCacheBucketsAdditive proves the Anthropic +// shape: cache_read and cache_creation are additive to input_tokens +// and each carries its own rate. +func TestInvoke_AnthropicCacheBucketsAdditive(t *testing.T) { + configureBuiltinWithCacheRates(t) + mw := buildMiddleware(t, nil) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "anthropic"}, + {Key: middleware.KeyLLMModel, Value: "claude-sonnet-4-5"}, + {Key: middleware.KeyLLMInputTokens, Value: "256"}, + {Key: middleware.KeyLLMOutputTokens, Value: "200"}, + {Key: middleware.KeyLLMCachedInputTokens, Value: "768"}, + {Key: middleware.KeyLLMCacheCreationTokens, Value: "512"}, + }, + }) + require.NoError(t, err) + require.Equal(t, middleware.DecisionAllow, out.Decision) + + value, ok := metaValue(t, out.Metadata, middleware.KeyCostUSDTotal) + require.True(t, ok) + // 256 input * 0.003 + 768 cache_read * 0.0003 + 512 cache_creation * 0.00375 + 200 output * 0.015 + // = 0.000768 + 0.0002304 + 0.00192 + 0.003 = 0.0059184 → "0.005918" with 6-decimal format. + assert.Equal(t, "0.005918", value, + "each Anthropic input bucket must bill at its own rate — cache_read cheap, cache_creation expensive, regular input mid") +} + +// TestInvoke_CachedTokensAbsentFallsBackToBaseFormula covers the +// "operator hasn't opted in" path: with no cached metadata keys +// emitted, the meter must produce exactly the same cost as before +// the feature landed. Critical so operators with the new binary but +// no YAML changes see no behavioural drift on OpenAI requests. +func TestInvoke_CachedTokensAbsentFallsBackToBaseFormula(t *testing.T) { + configureBuiltinWithCacheRates(t) + mw := buildMiddleware(t, nil) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + {Key: middleware.KeyLLMInputTokens, Value: "1000"}, + {Key: middleware.KeyLLMOutputTokens, Value: "500"}, + // No KeyLLMCachedInputTokens — the parser didn't see one. + }, + }) + require.NoError(t, err) + value, ok := metaValue(t, out.Metadata, middleware.KeyCostUSDTotal) + require.True(t, ok) + // 1000 input * 0.0025 + 500 output * 0.01 = 0.0025 + 0.005 = 0.0075 + assert.Equal(t, "0.007500", value, "no cached metadata = same cost as before the feature landed") +} + +// TestInvoke_UnparseableCachedTokensSkippedSilently proves the +// optional-bucket contract: a malformed cached_input_tokens metadata +// value falls back to 0 (= no cached count) and continues with the +// regular formula. Cache buckets are a refinement, never a reason to +// abort cost computation. +func TestInvoke_UnparseableCachedTokensSkippedSilently(t *testing.T) { + configureBuiltinWithCacheRates(t) + mw := buildMiddleware(t, nil) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + {Key: middleware.KeyLLMInputTokens, Value: "1000"}, + {Key: middleware.KeyLLMOutputTokens, Value: "500"}, + {Key: middleware.KeyLLMCachedInputTokens, Value: "not-a-number"}, + }, + }) + require.NoError(t, err) + value, ok := metaValue(t, out.Metadata, middleware.KeyCostUSDTotal) + require.True(t, ok, "garbage cache metadata must NOT switch the response from a cost to a skip — fall back to 0 cached") + assert.Equal(t, "0.007500", value, "same as the no-cached-metadata path") +} + +// TestMiddleware_CloseCancelsReloader proves Close stops the per-instance +// pricing-reload goroutine: a chain rebuild retires the old instance and +// calls Close, which must invoke the cancel func startReloader handed it so +// the mtime-poll loop doesn't outlive the chain. +func TestMiddleware_CloseCancelsReloader(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + m := newMiddleware(nil, cancel) + + require.NoError(t, m.Close(), "Close must not error") + require.Error(t, ctx.Err(), "Close must cancel the reloader context so the poll goroutine exits") +} + +// TestMiddleware_CloseNilSafe confirms Close is a no-op (no panic) for an +// instance with no reloader and for a nil receiver. +func TestMiddleware_CloseNilSafe(t *testing.T) { + require.NoError(t, newMiddleware(nil, nil).Close(), "no-reloader Close must be a no-op") + var m *Middleware + require.NoError(t, m.Close(), "nil-receiver Close must be safe") +} diff --git a/proxy/internal/middleware/builtin/llm_guardrail/factory.go b/proxy/internal/middleware/builtin/llm_guardrail/factory.go new file mode 100644 index 000000000..6dd2a8e8d --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_guardrail/factory.go @@ -0,0 +1,82 @@ +package llm_guardrail + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" +) + +// Config is the JSON-decoded shape accepted by the factory. The +// runtime path consumes the normalised allowlist; raw config is not +// retained beyond construction. +type Config struct { + ModelAllowlist []string `json:"model_allowlist"` + PromptCapture PromptCapture `json:"prompt_capture"` +} + +// PromptCapture toggles the optional prompt capture + redaction step +// that emits llm.request_prompt onto the metadata bag. +type PromptCapture struct { + Enabled bool `json:"enabled"` + RedactPii bool `json:"redact_pii"` +} + +// Factory builds a configured llm_guardrail middleware instance. +type Factory struct{} + +// ID returns the registry identifier matching the middleware ID. +func (Factory) ID() string { return ID } + +// New decodes the raw JSON config and returns a ready Middleware. An +// empty / null / empty-object payload yields a zero-value Config. +func (Factory) New(rawConfig []byte) (middleware.Middleware, error) { + cfg := Config{} + if len(rawConfig) > 0 && !isEmptyJSON(rawConfig) { + if err := json.Unmarshal(rawConfig, &cfg); err != nil { + return nil, fmt.Errorf("decode config: %w", err) + } + } + return New(cfg), nil +} + +// isEmptyJSON reports whether the payload is whitespace, null, or an +// empty object/array. The caller skips Unmarshal in that case so the +// zero-value Config flows through unchanged. +func isEmptyJSON(raw []byte) bool { + trimmed := strings.TrimSpace(string(raw)) + switch trimmed { + case "", "null", "{}", "[]": + return true + } + return false +} + +// normaliseConfig lowercases and trims allowlist entries so the runtime +// match is case-insensitive. Empty entries are dropped. +func normaliseConfig(cfg Config) Config { + if len(cfg.ModelAllowlist) == 0 { + return cfg + } + cleaned := make([]string, 0, len(cfg.ModelAllowlist)) + for _, entry := range cfg.ModelAllowlist { + n := normaliseModel(entry) + if n == "" { + continue + } + cleaned = append(cleaned, n) + } + cfg.ModelAllowlist = cleaned + return cfg +} + +// normaliseModel lowercases and trims a single model identifier. +func normaliseModel(model string) string { + return strings.ToLower(strings.TrimSpace(model)) +} + +func init() { + builtin.Register(Factory{}) +} diff --git a/proxy/internal/middleware/builtin/llm_guardrail/middleware.go b/proxy/internal/middleware/builtin/llm_guardrail/middleware.go new file mode 100644 index 000000000..e6259f06f --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_guardrail/middleware.go @@ -0,0 +1,183 @@ +// Package llm_guardrail implements the SlotOnRequest middleware that +// enforces the per-target LLM guardrail policy: a model allowlist +// check and an opt-in prompt-capture step that may run a PII redactor +// before emitting the prompt into the metadata bag. +// +// The middleware runs after llm_request_parser, which is responsible +// for extracting the model and raw prompt onto the metadata side +// channel. llm_guardrail consumes those keys, decides allow/deny, and +// emits its own decision metadata plus the optional redacted prompt. +package llm_guardrail + +import ( + "context" + "unicode/utf8" + + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +// ID is the registry key for this middleware. +const ID = "llm_guardrail" + +const ( + version = "1.0.0" + maxPromptBytes = 3500 + denyCodeModel = "llm_policy.model_blocked" + denyReasonModel = "model_blocked" + denyMessageModel = "model is not in the policy allowlist" +) + +// Middleware enforces the model allowlist and optionally captures the +// request prompt with PII redaction. +type Middleware struct { + cfg Config +} + +// New constructs a Middleware with the supplied configuration. Model +// allowlist entries are normalised so the runtime check is +// case-insensitive and trim-tolerant. +func New(cfg Config) *Middleware { + return &Middleware{cfg: normaliseConfig(cfg)} +} + +// ID returns the registry identifier. +func (m *Middleware) ID() string { return ID } + +// Version returns the implementation version. +func (m *Middleware) Version() string { return version } + +// Slot reports the chain slot the middleware lives in. +func (m *Middleware) Slot() middleware.Slot { return middleware.SlotOnRequest } + +// AcceptedContentTypes lists the request body content types the +// middleware needs. Guardrail consumes metadata produced upstream and +// does not touch the body itself, but we keep application/json so the +// body policy retains the parsed payload upstream when required. +func (m *Middleware) AcceptedContentTypes() []string { + return []string{"application/json"} +} + +// MetadataKeys is the closed set of metadata keys this middleware may +// emit. The accumulator drops anything outside this allowlist. +func (m *Middleware) MetadataKeys() []string { + return []string{ + middleware.KeyLLMPolicyDecision, + middleware.KeyLLMPolicyReason, + middleware.KeyLLMRequestPrompt, + } +} + +// MutationsSupported reports whether the middleware emits header / body +// mutations. Guardrail never mutates the request. +func (m *Middleware) MutationsSupported() bool { return false } + +// Invoke runs the policy. The model allowlist is the only deny path; +// prompt capture only affects the metadata emitted alongside an allow. +func (m *Middleware) Invoke(_ context.Context, in *middleware.Input) (*middleware.Output, error) { + model, modelPresent := lookupMetadata(in.Metadata, middleware.KeyLLMModel) + + if denial := m.evaluateAllowlist(model, modelPresent); denial != nil { + return denial, nil + } + + out := &middleware.Output{ + Decision: middleware.DecisionAllow, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMPolicyDecision, Value: "allow"}, + {Key: middleware.KeyLLMPolicyReason, Value: ""}, + }, + } + + if prompt, ok := m.capturePrompt(in.Metadata); ok { + out.Metadata = append(out.Metadata, middleware.KV{ + Key: middleware.KeyLLMRequestPrompt, + Value: prompt, + }) + } + + return out, nil +} + +// Close releases resources owned by the middleware. Stateless, so this +// is a no-op. +func (m *Middleware) Close() error { return nil } + +// evaluateAllowlist returns a deny Output when the configured allowlist +// rejects the model. A nil return means the request should proceed. +func (m *Middleware) evaluateAllowlist(model string, modelPresent bool) *middleware.Output { + if len(m.cfg.ModelAllowlist) == 0 { + return nil + } + if !modelPresent { + return nil + } + if m.modelInAllowlist(model) { + return nil + } + return &middleware.Output{ + Decision: middleware.DecisionDeny, + DenyStatus: 403, + DenyReason: &middleware.DenyReason{ + Code: denyCodeModel, + Message: denyMessageModel, + Details: map[string]string{"model": model}, + }, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMPolicyDecision, Value: "deny"}, + {Key: middleware.KeyLLMPolicyReason, Value: denyReasonModel}, + }, + } +} + +// modelInAllowlist reports whether the model matches any allowlist +// entry under the case-insensitive, trim-tolerant comparison rule. +func (m *Middleware) modelInAllowlist(model string) bool { + normalised := normaliseModel(model) + if normalised == "" { + return false + } + for _, allowed := range m.cfg.ModelAllowlist { + if allowed == normalised { + return true + } + } + return false +} + +// capturePrompt returns the prompt to emit and whether it should be +// emitted at all. The truncation guarantee is upheld here regardless of +// whether redaction grew the string. +func (m *Middleware) capturePrompt(meta []middleware.KV) (string, bool) { + if !m.cfg.PromptCapture.Enabled { + return "", false + } + raw, ok := lookupMetadata(meta, middleware.KeyLLMRequestPromptRaw) + if !ok { + return "", false + } + prompt := raw + if m.cfg.PromptCapture.RedactPii { + prompt = redactPII(prompt) + } + if len(prompt) > maxPromptBytes { + // Back off to a UTF-8 rune boundary so we never emit a string + // split mid-rune. + cut := maxPromptBytes + for cut > 0 && !utf8.RuneStart(prompt[cut]) { + cut-- + } + prompt = prompt[:cut] + } + return prompt, true +} + +// lookupMetadata finds the first KV with the given key. Returns the +// value and true when present; the empty string and false otherwise. +func lookupMetadata(meta []middleware.KV, key string) (string, bool) { + for _, kv := range meta { + if kv.Key == key { + return kv.Value, true + } + } + return "", false +} diff --git a/proxy/internal/middleware/builtin/llm_guardrail/middleware_test.go b/proxy/internal/middleware/builtin/llm_guardrail/middleware_test.go new file mode 100644 index 000000000..865dc07af --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_guardrail/middleware_test.go @@ -0,0 +1,219 @@ +package llm_guardrail + +import ( + "context" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +func metaValue(t *testing.T, kvs []middleware.KV, key string) (string, bool) { + t.Helper() + for _, kv := range kvs { + if kv.Key == key { + return kv.Value, true + } + } + return "", false +} + +func newInput(meta ...middleware.KV) *middleware.Input { + return &middleware.Input{Slot: middleware.SlotOnRequest, Metadata: meta} +} + +func TestMiddlewareIdentity(t *testing.T) { + mw := New(Config{}) + assert.Equal(t, ID, mw.ID(), "middleware ID must be llm_guardrail") + assert.Equal(t, "1.0.0", mw.Version(), "version must be 1.0.0") + assert.Equal(t, middleware.SlotOnRequest, mw.Slot(), "guardrail must run in SlotOnRequest") + assert.False(t, mw.MutationsSupported(), "guardrail must not mutate requests") + assert.Equal(t, []string{"application/json"}, mw.AcceptedContentTypes(), "guardrail accepts application/json bodies") + assert.Equal(t, + []string{ + middleware.KeyLLMPolicyDecision, + middleware.KeyLLMPolicyReason, + middleware.KeyLLMRequestPrompt, + }, + mw.MetadataKeys(), + "metadata key allowlist must match the spec", + ) + require.NoError(t, mw.Close()) +} + +func TestAllowlistEmptyAllowsAnyModel(t *testing.T) { + mw := New(Config{}) + out, err := mw.Invoke(context.Background(), newInput( + middleware.KV{Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + )) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "empty allowlist must allow any model") + v, ok := metaValue(t, out.Metadata, middleware.KeyLLMPolicyDecision) + require.True(t, ok, "decision metadata must be emitted") + assert.Equal(t, "allow", v, "decision must be allow") + r, ok := metaValue(t, out.Metadata, middleware.KeyLLMPolicyReason) + require.True(t, ok, "reason metadata must be emitted") + assert.Equal(t, "", r, "reason must be empty on allow") +} + +func TestAllowlistMatchAllows(t *testing.T) { + mw := New(Config{ModelAllowlist: []string{"gpt-4o", "claude-opus-4"}}) + out, err := mw.Invoke(context.Background(), newInput( + middleware.KV{Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + )) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "model in allowlist must be allowed") +} + +func TestAllowlistMissDenies(t *testing.T) { + mw := New(Config{ModelAllowlist: []string{"gpt-4o"}}) + out, err := mw.Invoke(context.Background(), newInput( + middleware.KV{Key: middleware.KeyLLMModel, Value: "claude-opus-4"}, + )) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionDeny, out.Decision, "non-allowlisted model must be denied") + assert.Equal(t, 403, out.DenyStatus, "deny status must be 403") + require.NotNil(t, out.DenyReason, "deny reason must be populated") + assert.Equal(t, "llm_policy.model_blocked", out.DenyReason.Code, "deny code must match spec") + assert.Equal(t, "model is not in the policy allowlist", out.DenyReason.Message, "deny message must match spec") + assert.Equal(t, "claude-opus-4", out.DenyReason.Details["model"], "deny details must include the offending model") + + dec, _ := metaValue(t, out.Metadata, middleware.KeyLLMPolicyDecision) + assert.Equal(t, "deny", dec, "decision metadata must be deny") + reason, _ := metaValue(t, out.Metadata, middleware.KeyLLMPolicyReason) + assert.Equal(t, "model_blocked", reason, "reason metadata must be model_blocked") +} + +func TestAllowlistCaseInsensitive(t *testing.T) { + mw := New(Config{ModelAllowlist: []string{" GPT-4o ", "Claude-OPUS-4"}}) + cases := []string{"gpt-4o", "GPT-4O", " claude-opus-4 "} + for _, model := range cases { + out, err := mw.Invoke(context.Background(), newInput( + middleware.KV{Key: middleware.KeyLLMModel, Value: model}, + )) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "case/whitespace variants must match: %q", model) + } +} + +func TestAllowlistMissingModelKeyAllows(t *testing.T) { + mw := New(Config{ModelAllowlist: []string{"gpt-4o"}}) + out, err := mw.Invoke(context.Background(), newInput()) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "missing model key must allow even with non-empty allowlist") + dec, _ := metaValue(t, out.Metadata, middleware.KeyLLMPolicyDecision) + assert.Equal(t, "allow", dec, "decision must be allow when model key is absent") +} + +func TestPromptCaptureDisabledEmitsNoPrompt(t *testing.T) { + mw := New(Config{}) + out, err := mw.Invoke(context.Background(), newInput( + middleware.KV{Key: middleware.KeyLLMRequestPromptRaw, Value: "hello world"}, + )) + require.NoError(t, err) + _, ok := metaValue(t, out.Metadata, middleware.KeyLLMRequestPrompt) + assert.False(t, ok, "prompt must not be emitted when capture is disabled") +} + +func TestPromptCaptureNoRedactionEmitsRaw(t *testing.T) { + mw := New(Config{PromptCapture: PromptCapture{Enabled: true}}) + raw := "hello world from user@example.com" + out, err := mw.Invoke(context.Background(), newInput( + middleware.KV{Key: middleware.KeyLLMRequestPromptRaw, Value: raw}, + )) + require.NoError(t, err) + prompt, ok := metaValue(t, out.Metadata, middleware.KeyLLMRequestPrompt) + require.True(t, ok, "prompt must be emitted when capture is enabled") + assert.Equal(t, raw, prompt, "prompt must pass through unchanged when redaction is off") +} + +func TestPromptCaptureWithRedactionRedacts(t *testing.T) { + mw := New(Config{PromptCapture: PromptCapture{Enabled: true, RedactPii: true}}) + raw := "contact me at user@example.com or +14155551234" + out, err := mw.Invoke(context.Background(), newInput( + middleware.KV{Key: middleware.KeyLLMRequestPromptRaw, Value: raw}, + )) + require.NoError(t, err) + prompt, ok := metaValue(t, out.Metadata, middleware.KeyLLMRequestPrompt) + require.True(t, ok, "prompt must be emitted when capture is enabled") + assert.Contains(t, prompt, "[REDACTED:email]", "email must be redacted") + assert.Contains(t, prompt, "[REDACTED:phone]", "phone must be redacted") + assert.NotContains(t, prompt, "user@example.com", "raw email must not leak") +} + +func TestPromptCaptureRedactionTruncatesIfGrows(t *testing.T) { + mw := New(Config{PromptCapture: PromptCapture{Enabled: true, RedactPii: true}}) + body := strings.Repeat("a", maxPromptBytes-10) + " user@example.com" + out, err := mw.Invoke(context.Background(), newInput( + middleware.KV{Key: middleware.KeyLLMRequestPromptRaw, Value: body}, + )) + require.NoError(t, err) + prompt, ok := metaValue(t, out.Metadata, middleware.KeyLLMRequestPrompt) + require.True(t, ok, "prompt must be emitted when capture is enabled") + assert.LessOrEqual(t, len(prompt), maxPromptBytes, "prompt must be truncated to maxPromptBytes") +} + +func TestPromptCaptureMissingRawNoEmit(t *testing.T) { + mw := New(Config{PromptCapture: PromptCapture{Enabled: true, RedactPii: true}}) + out, err := mw.Invoke(context.Background(), newInput()) + require.NoError(t, err) + _, ok := metaValue(t, out.Metadata, middleware.KeyLLMRequestPrompt) + assert.False(t, ok, "prompt must not be emitted when raw key is missing") +} + +func TestFactoryAcceptsZeroConfigs(t *testing.T) { + cases := map[string][]byte{ + "nil": nil, + "empty": []byte(""), + "whitespace": []byte(" \n "), + "null": []byte("null"), + "emptyObject": []byte("{}"), + } + f := Factory{} + for name, raw := range cases { + mw, err := f.New(raw) + require.NoError(t, err, "case %s must yield a zero-value config", name) + require.NotNil(t, mw) + assert.Equal(t, ID, mw.ID(), "case %s must build a guardrail middleware", name) + } +} + +func TestFactoryDecodesValidConfig(t *testing.T) { + cfg := Config{ + ModelAllowlist: []string{"gpt-4o"}, + PromptCapture: PromptCapture{Enabled: true, RedactPii: true}, + } + raw, err := json.Marshal(cfg) + require.NoError(t, err, "marshalling test config must succeed") + mw, err := Factory{}.New(raw) + require.NoError(t, err) + require.NotNil(t, mw) +} + +func TestFactoryRejectsMalformedJSON(t *testing.T) { + mw, err := Factory{}.New([]byte("{not-json")) + assert.Error(t, err, "malformed JSON must surface as a factory error") + assert.Nil(t, mw, "no middleware must be returned on malformed config") +} + +func TestFactoryNormalisesAllowlist(t *testing.T) { + raw := []byte(`{"model_allowlist":[" GPT-4o ","",""," Claude-3 "]}`) + mw, err := Factory{}.New(raw) + require.NoError(t, err) + out, err := mw.Invoke(context.Background(), newInput( + middleware.KV{Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + )) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "factory must lowercase + trim allowlist entries") + out2, err := mw.Invoke(context.Background(), newInput( + middleware.KV{Key: middleware.KeyLLMModel, Value: "claude-3"}, + )) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionAllow, out2.Decision, "trimmed entry must still match") +} diff --git a/proxy/internal/middleware/builtin/llm_guardrail/redact.go b/proxy/internal/middleware/builtin/llm_guardrail/redact.go new file mode 100644 index 000000000..c6cb270df --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_guardrail/redact.go @@ -0,0 +1,75 @@ +package llm_guardrail + +import ( + "regexp" + "strings" + + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +// PII redactor scope: redact prompt content BEFORE it lands in the metadata +// bag. The bearer-with-keyword pass runs first so the keyword is preserved. +// We then chain the package-level middleware.Scan to pick up PEM, JWT, AWS +// access keys, generic bearer tokens (40+ chars), and Luhn-validated credit +// cards — keeping prompt redaction in sync with metadata-value scanning. Email, +// SSN (dashed form), phone (E.164 + NA), and IPv4 are prompt-shaped patterns +// the metadata scanner intentionally leaves alone. +var ( + emailRegex = regexp.MustCompile(`[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}`) + ssnRegex = regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`) + phoneE164 = regexp.MustCompile(`\+\d{8,15}\b`) + // phoneNARgx accepts the 3-3-4 North-American shape with any of the common + // separators (space, dot, dash, slash) or none at all between the area code + // and the body. The optional `\(?...\)?` wraps the area code; the separator + // classes use `*` (not `?`) so multi-char separators ("(202) " followed by + // space-and-something) and zero-separator runs ("2025550134") both match. + // False-positive tradeoff: 10 consecutive digits in a prompt will be + // treated as a phone number. For PII redaction that is the correct way to + // err — under-redaction leaks; over-redaction is annoying. + phoneNARgx = regexp.MustCompile(`\(?\b\d{3}\)?[\s.\-/]*\d{3}[\s.\-/]*\d{4}\b`) + bearerRegex = regexp.MustCompile(`(?i)\b(bearer|token|api[_-]?key|authorization)([\s:=]+)(\S{20,})`) + ipv4Regex = regexp.MustCompile(`\b(?:(?:25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)\.){3}(?:25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)\b`) +) + +// redactPII is the package-private alias kept for internal callers; new code +// outside the guardrail middleware should call RedactPII. +func redactPII(value string) string { return RedactPII(value) } + +// RedactPII replaces high-signal PII patterns in value with +// `[REDACTED:]`. Non-matching input is returned unchanged. Exported so +// the request / response parsers can reuse the same coverage on raw prompts +// and completions when the account's redact_pii toggle is on. +func RedactPII(value string) string { + if value == "" { + return value + } + result := value + // Keyword-preserving bearer first so the "bearer "/"token=" prefix survives + // before the generic scanner gets at the same content. + result = bearerRegex.ReplaceAllStringFunc(result, redactBearer) + // Structured secrets shared with metadata-value scanning: PEM, JWT, AWS + // keys, generic bearer (40+), and Luhn-validated credit cards. + result = middleware.Scan(result) + // Prompt-shaped PII the metadata scanner doesn't cover. + result = emailRegex.ReplaceAllString(result, "[REDACTED:email]") + result = ssnRegex.ReplaceAllString(result, "[REDACTED:ssn]") + result = phoneE164.ReplaceAllString(result, "[REDACTED:phone]") + result = phoneNARgx.ReplaceAllString(result, "[REDACTED:phone]") + result = ipv4Regex.ReplaceAllString(result, "[REDACTED:ip]") + return result +} + +// redactBearer keeps the leading keyword and its separator, replacing +// only the secret payload so the surrounding context is preserved. +func redactBearer(match string) string { + sub := bearerRegex.FindStringSubmatch(match) + if len(sub) < 4 { + return "[REDACTED:bearer]" + } + var b strings.Builder + b.Grow(len(sub[1]) + len(sub[2]) + len("[REDACTED:bearer]")) + b.WriteString(sub[1]) + b.WriteString(sub[2]) + b.WriteString("[REDACTED:bearer]") + return b.String() +} diff --git a/proxy/internal/middleware/builtin/llm_guardrail/redact_test.go b/proxy/internal/middleware/builtin/llm_guardrail/redact_test.go new file mode 100644 index 000000000..ef17f1d0a --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_guardrail/redact_test.go @@ -0,0 +1,217 @@ +package llm_guardrail + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRedactPIIEmptyInput(t *testing.T) { + assert.Equal(t, "", redactPII(""), "empty input must round-trip unchanged") +} + +func TestRedactPIIPlainTextUntouched(t *testing.T) { + in := "the quick brown fox jumps over the lazy dog" + assert.Equal(t, in, redactPII(in), "non-PII text must pass through unchanged") +} + +func TestRedactPIIEmail(t *testing.T) { + cases := []string{ + "contact user@example.com today", + "first.last+tag@sub.example.co", + "USER_42@EXAMPLE.COM", + } + for _, in := range cases { + out := redactPII(in) + assert.Contains(t, out, "[REDACTED:email]", "email must be redacted in %q", in) + assert.NotContains(t, strings.ToLower(out), "@example", "raw email host must not survive in %q", in) + } +} + +func TestRedactPIISSN(t *testing.T) { + in := "ssn 123-45-6789 should be hidden" + out := redactPII(in) + assert.Contains(t, out, "[REDACTED:ssn]", "SSN must be redacted") + assert.NotContains(t, out, "123-45-6789", "raw SSN must not survive") +} + +func TestRedactPIIPhoneE164(t *testing.T) { + in := "call me at +14155551234 anytime" + out := redactPII(in) + assert.Contains(t, out, "[REDACTED:phone]", "E.164 phone must be redacted") + assert.NotContains(t, out, "+14155551234", "raw E.164 phone must not survive") +} + +func TestRedactPIIPhoneNorthAmerican(t *testing.T) { + cases := []string{ + "call (415) 555-1234 now", + "call 415-555-1234 now", + "call 415.555.1234 now", + "call 415 555 1234 now", + } + for _, in := range cases { + out := redactPII(in) + assert.Contains(t, out, "[REDACTED:phone]", "NA phone must be redacted in %q", in) + assert.NotContains(t, out, "555-1234", "raw NA phone must not survive in %q", in) + } +} + +func TestRedactPIIBearerKeepsKeyword(t *testing.T) { + cases := []struct { + in string + keyword string + }{ + {"Authorization: Bearer abcdefghijklmnopqrstuvwxyz0123", "Bearer"}, + {"token = abcdefghijklmnopqrstuvwxyz", "token"}, + {"api_key=abcdefghijklmnopqrstuvwxyz0123", "api_key"}, + {"API-KEY: abcdefghijklmnopqrstuvwxyz0123", "API-KEY"}, + {"authorization: abcdefghijklmnopqrstuvwxyz0123", "authorization"}, + } + for _, tc := range cases { + out := redactPII(tc.in) + assert.Contains(t, out, "[REDACTED:bearer]", "bearer-style secret must be redacted in %q", tc.in) + assert.Contains(t, out, tc.keyword, "leading keyword %q must be preserved in %q", tc.keyword, tc.in) + assert.NotContains(t, out, "abcdefghijklmnopqrstuvwxyz0123", "raw bearer payload must not survive in %q", tc.in) + } +} + +func TestRedactPIIBearerShortValueUntouched(t *testing.T) { + in := "token=short" + out := redactPII(in) + assert.Equal(t, in, out, "short bearer-style values must not be redacted") +} + +func TestRedactPIICombined(t *testing.T) { + in := "email user@example.com phone +14155551234 ssn 123-45-6789 token abcdefghijklmnopqrstuvwxyz0123" + out := redactPII(in) + assert.Contains(t, out, "[REDACTED:email]", "email must be redacted in combined input") + assert.Contains(t, out, "[REDACTED:phone]", "phone must be redacted in combined input") + assert.Contains(t, out, "[REDACTED:ssn]", "SSN must be redacted in combined input") + assert.Contains(t, out, "[REDACTED:bearer]", "bearer must be redacted in combined input") + assert.NotContains(t, out, "user@example.com", "raw email must not survive combined input") + assert.NotContains(t, out, "+14155551234", "raw phone must not survive combined input") + assert.NotContains(t, out, "123-45-6789", "raw SSN must not survive combined input") +} + +func TestRedactPIICreditCard(t *testing.T) { + // 4242424242424242 is a well-known Stripe test number (Visa, Luhn-valid). + cases := []string{ + "please charge 4242424242424242 now", + "card: 4242-4242-4242-4242", + "4242 4242 4242 4242 expires 12/30", + } + for _, in := range cases { + out := redactPII(in) + assert.Contains(t, out, "[REDACTED:cc]", "Luhn-valid credit card must be redacted in %q", in) + assert.NotContains(t, out, "4242424242424242", "raw card digits must not survive in %q", in) + assert.NotContains(t, out, "4242-4242-4242-4242", "raw dashed card must not survive in %q", in) + } +} + +func TestRedactPIIIPv4(t *testing.T) { + cases := []string{ + "connect to 10.0.42.7 over the tunnel", + "server 192.168.1.100 down", + "public address 203.0.113.42 was hit", + } + for _, in := range cases { + out := redactPII(in) + assert.Contains(t, out, "[REDACTED:ip]", "IPv4 must be redacted in %q", in) + } +} + +func TestRedactPIIJWT(t *testing.T) { + // No "token "/"bearer " prefix here, so the bearer-with-keyword pass leaves + // it alone and the JWT pattern from middleware.Scan must catch it. + in := "session eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ1c2VyXzQyIn0.signaturepart expires soon" + out := redactPII(in) + assert.Contains(t, out, "[REDACTED:jwt]", "JWT must be redacted when no bearer keyword precedes it") + assert.NotContains(t, out, "eyJhbGciOiJIUzI1NiJ9", "raw JWT header must not survive") +} + +func TestRedactPIIAWSAccessKey(t *testing.T) { + in := "the key AKIAIOSFODNN7EXAMPLE belongs to test user" + out := redactPII(in) + assert.Contains(t, out, "[REDACTED:aws_key]", "AWS access key must be redacted") + assert.NotContains(t, out, "AKIAIOSFODNN7EXAMPLE", "raw AWS key must not survive") +} + +func TestRedactPIIPlainNumbersUntouched(t *testing.T) { + // 1234567890123 is 13 digits but fails Luhn; must NOT trip the CC redactor. + // We use a 13-digit value (the CC-candidate range starts at 13) so the only + // risk is the CC pattern firing. Phone redaction is 10-digit by design and + // would catch 1234567890123 as a phone — that's expected and not what this + // test guards against. + in := "order number 1234567890123 is queued" + out := redactPII(in) + assert.NotContains(t, out, "[REDACTED:cc]", "non-Luhn digit sequences must not be redacted as credit cards") +} + +// piiFixture mirrors the user-supplied test fixture: each record carries one +// email, one SSN, and one phone in a representative format. The test asserts +// that EVERY raw token disappears after redaction and the right [REDACTED:*] +// markers show up. Names are kept in the input and must survive — names are +// not a pattern the redactor tries to catch. +type piiFixture struct { + name string // person name (must survive redaction) + email string + ssn string + phone string +} + +var fixtureRecords = []piiFixture{ + {"Alice Johnson", "alice.johnson@example.com", "123-45-6789", "(202) 555-0147"}, + {"Brian Smith", "brian.smith@example.org", "987-65-4321", "202-555-0163"}, + {"Carla Nguyen", "c.nguyen@test.local", "111-22-3333", "+1-202-555-0188"}, + {"David Martinez", "david.martinez@example.com", "222-33-4444", "202.555.0199"}, + {"Evelyn Parker", "evelyn.parker@example.org", "333-44-5555", "1-202-555-0112"}, + {"Frank O'Connor", "frank.oconnor@test.local", "444-55-6666", "2025550134"}, + {"Grace Lee", "grace.lee@example.com", "555-66-7777", "(202)555-0156"}, + {"Hassan Ali", "hassan.ali@example.org", "666-77-8888", "+1 (202) 555-0175"}, + {"Isabella Rossi", "i.rossi@test.local", "777-88-9999", "202 555 0121"}, + {"Jamal Thompson", "jamal.thompson@example.com", "888-99-0001", "202/555/0108"}, +} + +// TestRedactPII_FixtureRecord drives every record through redactPII and +// asserts the email, SSN, and phone are all redacted, the name survives, and +// the appropriate REDACTED markers are present. This is the spec the redactor +// must meet for the kind of prompts operators throw at it. +func TestRedactPII_FixtureRecord(t *testing.T) { + for _, rec := range fixtureRecords { + t.Run(rec.name, func(t *testing.T) { + in := "Name: " + rec.name + "\n Email: " + rec.email + "\n SSN: " + rec.ssn + "\n Phone: " + rec.phone + out := redactPII(in) + + assert.Contains(t, out, rec.name, "name must survive (not a PII pattern the redactor catches)") + assert.Contains(t, out, "[REDACTED:email]", "email marker must appear for %q", rec.email) + assert.Contains(t, out, "[REDACTED:ssn]", "ssn marker must appear for %q", rec.ssn) + assert.Contains(t, out, "[REDACTED:phone]", "phone marker must appear for %q", rec.phone) + + assert.NotContains(t, out, rec.email, "raw email must not survive: %q", rec.email) + assert.NotContains(t, out, rec.ssn, "raw SSN must not survive: %q", rec.ssn) + // Phone: assert the local digits (last 7) are gone. Country-code + // remnants like "+1 " or "1-" may remain in front of the redaction + // because the E.164 pattern needs digits-only after '+' — that's + // acceptable, the personally-identifying portion is removed. + localDigits := lastSevenDigits(rec.phone) + assert.NotContains(t, out, localDigits, "raw phone local digits %q must not survive in redacted output of %q", localDigits, rec.phone) + }) + } +} + +// lastSevenDigits returns the last 7 digits of a phone number, ignoring +// formatting. It's the unique "subscriber" portion that absolutely must be +// scrubbed regardless of which prefix the redactor leaves behind. +func lastSevenDigits(phone string) string { + digits := make([]byte, 0, len(phone)) + for i := 0; i < len(phone); i++ { + if phone[i] >= '0' && phone[i] <= '9' { + digits = append(digits, phone[i]) + } + } + if len(digits) <= 7 { + return string(digits) + } + return string(digits[len(digits)-7:]) +} diff --git a/proxy/internal/middleware/builtin/llm_identity_inject/factory.go b/proxy/internal/middleware/builtin/llm_identity_inject/factory.go new file mode 100644 index 000000000..8594c392d --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_identity_inject/factory.go @@ -0,0 +1,108 @@ +package llm_identity_inject + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" +) + +// ProviderInjection describes one resolved provider's injection rule. +// Identity stamping uses one of HeaderPair / JSONMetadata; ExtraHeaders +// is independent — each entry is a static (operator-configured) header +// stamped on every matching request with anti-spoof. A rule with no +// shape AND no extras is dropped at New() time as a no-op. +type ProviderInjection struct { + // ProviderID is the resolved provider id — matches the value + // llm_router stamps under KeyLLMResolvedProviderID. + ProviderID string `json:"provider_id"` + // HeaderPair is the LiteLLM-style wire convention: separate + // headers for end-user id and tags CSV. + HeaderPair *HeaderPairRule `json:"header_pair,omitempty"` + // JSONMetadata is the Portkey-style wire convention: a single + // header carrying a JSON object keyed by reserved field names. + JSONMetadata *JSONMetadataRule `json:"json_metadata,omitempty"` + // ExtraHeaders is an operator-configured list of static headers + // (e.g. "x-portkey-config: pc-...") that the middleware stamps + // on every matching request. The synth pre-resolves the values + // from the provider record's ExtraValues map; the middleware + // just emits them. Each name is also added to HeadersRemove for + // anti-spoof so a client can't smuggle their own value. + ExtraHeaders []ExtraHeaderKV `json:"extra_headers,omitempty"` +} + +// ExtraHeaderKV is one static header entry the middleware stamps as-is. +type ExtraHeaderKV struct { + Name string `json:"name"` + Value string `json:"value"` +} + +// HeaderPairRule emits identity through dedicated per-dimension +// headers. The two *InBody flags layer body-level identity on top: when +// TagsInBody is set the middleware also writes the tag list into the +// request body's metadata.tags array (required for LiteLLM tag-budget +// enforcement, which only inspects the body); when EndUserIDInBody is +// set the display identity is also written into the body's top-level +// "user" field (the OpenAI-standard end-user identifier — defense-in- +// depth and anti-spoof on top of the header path). +type HeaderPairRule struct { + EndUserIDHeader string `json:"end_user_id_header,omitempty"` + TagsHeader string `json:"tags_header,omitempty"` + TagsInBody bool `json:"tags_in_body,omitempty"` + EndUserIDInBody bool `json:"end_user_id_in_body,omitempty"` +} + +// JSONMetadataRule emits identity through a single JSON-object header. +// Empty UserKey/GroupsKey skip that dimension at emit time. When +// MaxValueLength > 0 each emitted JSON value is truncated to that many +// bytes — Portkey enforces 128 chars per value. +type JSONMetadataRule struct { + Header string `json:"header"` + UserKey string `json:"user_key,omitempty"` + GroupsKey string `json:"groups_key,omitempty"` + MaxValueLength int `json:"max_value_length,omitempty"` +} + +// Config is the on-wire configuration accepted by the factory. An +// empty Providers slice yields a no-op middleware (every resolved +// provider passes through unchanged). +type Config struct { + Providers []ProviderInjection `json:"providers"` +} + +// Factory builds llm_identity_inject instances from raw config bytes. +type Factory struct{} + +// ID returns the registry identifier. +func (Factory) ID() string { return ID } + +// New constructs a middleware instance. Empty, null, and {} configs +// yield a no-op middleware. Non-empty payloads must parse cleanly so +// misconfigurations surface at chain build time. +func (Factory) New(rawConfig []byte) (middleware.Middleware, error) { + cfg := Config{} + if !isEmptyJSON(rawConfig) { + if err := json.Unmarshal(rawConfig, &cfg); err != nil { + return nil, fmt.Errorf("decode config: %w", err) + } + } + return New(cfg), nil +} + +// isEmptyJSON reports whether the payload is whitespace, null, or an +// empty object/array. +func isEmptyJSON(raw []byte) bool { + trimmed := strings.TrimSpace(string(bytes.TrimSpace(raw))) + switch trimmed { + case "", "null", "{}", "[]": + return true + } + return false +} + +func init() { + builtin.Register(Factory{}) +} diff --git a/proxy/internal/middleware/builtin/llm_identity_inject/middleware.go b/proxy/internal/middleware/builtin/llm_identity_inject/middleware.go new file mode 100644 index 000000000..ee3f1c20d --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_identity_inject/middleware.go @@ -0,0 +1,439 @@ +// Package llm_identity_inject implements the SlotOnRequest middleware +// that stamps the caller's NetBird identity onto upstream LLM-gateway +// requests. It runs after llm_router (which resolves the provider) and +// looks up the resolved provider id against a per-account injection +// table built by the synthesiser from the catalog's IdentityInjection +// metadata. +// +// Two wire shapes are supported, dispatched per-rule: +// +// - HeaderPair (LiteLLM-style): separate end-user-id and tags +// headers; tags emitted as a CSV value. +// - JSONMetadata (Portkey-style): one header carrying a JSON +// object with reserved keys for user / groups; per-value byte +// length capped when the rule sets MaxValueLength. +// +// In both cases, identity comes from Input.UserEmail (peer-attached +// user's email or peer.Name fallback) and groups come from the +// authorising-groups intersection llm_router emitted (with +// id→display-name translation via Input.UserGroups / UserGroupNames +// positional pairing). HeadersRemove runs before HeadersAdd in the +// framework, so a client can never spoof identity by stamping these +// headers themselves. +package llm_identity_inject + +import ( + "context" + "encoding/json" + "sort" + "strings" + + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +// ID is the registry identifier for this middleware. +const ID = "llm_identity_inject" + +// Version is reported via Middleware.Version(). +const Version = "1.0.0" + +// Middleware stamps NetBird identity onto upstream requests for the +// configured set of resolved providers. +type Middleware struct { + cfg Config + byID map[string]ProviderInjection +} + +// New constructs a Middleware from the supplied configuration. A nil +// or empty Providers slice yields a no-op middleware. +func New(cfg Config) *Middleware { + byID := make(map[string]ProviderInjection, len(cfg.Providers)) + for _, p := range cfg.Providers { + if p.ProviderID == "" || !injectionEmitsAnything(p) { + continue + } + byID[p.ProviderID] = p + } + return &Middleware{cfg: cfg, byID: byID} +} + +// injectionEmitsAnything reports whether a provider injection rule would +// stamp anything at runtime. Rules that set both identity shapes are a +// configuration error (we refuse to guess which wins), and rules that +// resolve to no headers are dropped to keep the runtime check tight. +// Non-empty extras alone keep a rule alive even when neither identity +// shape is set. +func injectionEmitsAnything(p ProviderInjection) bool { + hasExtras := false + for _, e := range p.ExtraHeaders { + if e.Name != "" && e.Value != "" { + hasExtras = true + break + } + } + switch { + case p.HeaderPair != nil && p.JSONMetadata != nil: + return false + case p.HeaderPair != nil: + return p.HeaderPair.EndUserIDHeader != "" || p.HeaderPair.TagsHeader != "" || + p.HeaderPair.TagsInBody || p.HeaderPair.EndUserIDInBody || hasExtras + case p.JSONMetadata != nil: + if p.JSONMetadata.Header == "" { + return false + } + return p.JSONMetadata.UserKey != "" || p.JSONMetadata.GroupsKey != "" || hasExtras + default: + return hasExtras + } +} + +// ID returns the registry identifier. +func (m *Middleware) ID() string { return ID } + +// Version returns the implementation version. +func (m *Middleware) Version() string { return Version } + +// Slot reports the chain slot the middleware lives in. +func (m *Middleware) Slot() middleware.Slot { return middleware.SlotOnRequest } + +// AcceptedContentTypes returns nil — this middleware reads only +// metadata and identity fields on the Input envelope. +func (m *Middleware) AcceptedContentTypes() []string { return nil } + +// MetadataKeys is empty: the middleware emits no metadata. Identity +// stamping is a header-only operation. +func (m *Middleware) MetadataKeys() []string { return nil } + +// MutationsSupported reports that the middleware emits header +// mutations on the Output envelope. +func (m *Middleware) MutationsSupported() bool { return true } + +// Close releases resources owned by the middleware. Stateless, so +// this is a no-op. +func (m *Middleware) Close() error { return nil } + +// Invoke stamps identity headers when the resolved provider has an +// injection rule. Always Allow. +func (m *Middleware) Invoke(_ context.Context, in *middleware.Input) (*middleware.Output, error) { + out := &middleware.Output{Decision: middleware.DecisionAllow} + if len(m.byID) == 0 || in == nil { + return out, nil + } + resolved, ok := lookupMetadata(in.Metadata, middleware.KeyLLMResolvedProviderID) + if !ok || resolved == "" { + return out, nil + } + rule, ok := m.byID[resolved] + if !ok { + return out, nil + } + + var mutations *middleware.Mutations + switch { + case rule.HeaderPair != nil: + mutations = applyHeaderPair(rule.HeaderPair, in) + case rule.JSONMetadata != nil: + mutations = applyJSONMetadata(rule.JSONMetadata, in) + } + + // ExtraHeaders are independent of the identity shape. Stamp each + // non-empty entry with anti-spoof: Remove first (frame strips it + // before our Add lands) so a client can't smuggle a value, then + // Add our trusted one. + if len(rule.ExtraHeaders) > 0 { + if mutations == nil { + mutations = &middleware.Mutations{} + } + for _, h := range rule.ExtraHeaders { + if h.Name == "" || h.Value == "" { + continue + } + mutations.HeadersRemove = append(mutations.HeadersRemove, h.Name) + mutations.HeadersAdd = append(mutations.HeadersAdd, middleware.KV{ + Key: h.Name, + Value: h.Value, + }) + } + } + + if mutations == nil || (len(mutations.HeadersAdd) == 0 && len(mutations.HeadersRemove) == 0 && len(mutations.BodyReplace) == 0) { + return out, nil + } + out.Mutations = mutations + return out, nil +} + +// applyHeaderPair builds the LiteLLM-style mutations: separate per- +// dimension headers, with anti-spoof Removes paired with trusted Adds. +func applyHeaderPair(rule *HeaderPairRule, in *middleware.Input) *middleware.Mutations { + mutations := &middleware.Mutations{} + + if rule.EndUserIDHeader != "" { + mutations.HeadersRemove = append(mutations.HeadersRemove, rule.EndUserIDHeader) + // Prefer the email when the auth path carried it: gateways + // like LiteLLM key per-user budgets and dashboards on a + // human-readable identifier; the user_id is an opaque + // management-server primary key. Fall back to user_id when + // no email is available (non-OIDC schemes, legacy JWTs). + if identity := identityFor(in); identity != "" { + mutations.HeadersAdd = append(mutations.HeadersAdd, middleware.KV{ + Key: rule.EndUserIDHeader, + Value: identity, + }) + } + } + + if rule.TagsHeader != "" { + mutations.HeadersRemove = append(mutations.HeadersRemove, rule.TagsHeader) + if csv := authorisingTagsCSV(in); csv != "" { + mutations.HeadersAdd = append(mutations.HeadersAdd, middleware.KV{ + Key: rule.TagsHeader, + Value: csv, + }) + } + } + + if rule.TagsInBody || rule.EndUserIDInBody { + // Body-level identity unlocks gateway behaviour the header + // path can't reach (LiteLLM's _tag_max_budget_check only + // inspects the body; OpenAI direct only reads the body's + // "user" field for attribution). The header path stays + // intact, so we still get attribution + per-end-user budget + // gating when body inject can't run (truncated body, + // non-JSON, hostile metadata shape). + var bodyTags []string + if rule.TagsInBody { + bodyTags = authorisingTagsSlice(in) + } + var bodyUser string + if rule.EndUserIDInBody { + bodyUser = identityFor(in) + } + if newBody, ok := injectIntoBody(in, bodyTags, bodyUser); ok { + mutations.BodyReplace = newBody + } + } + + return mutations +} + +// injectIntoBody parses the request body and writes the supplied +// identity dimensions into it. Tags land at metadata.tags (creating +// the metadata object when absent); the user identity lands at the +// top-level "user" field (OpenAI-standard end-user identifier). +// Returns the re-marshaled body and ok=true when at least one field +// was written. Returns ok=false (no mutation) when: +// +// - both inputs are empty (nothing to write); +// - the body is empty or truncated (we don't have the full document +// to safely round-trip); +// - the body isn't a JSON object (skip silently — this middleware +// only knows how to inject into OpenAI-compatible JSON payloads). +// +// A non-object existing `metadata` field skips the tag write but +// still allows the user write to land — we don't clobber the client's +// non-object metadata, but the orthogonal user field is fair game. +// The header path emission still runs in skip cases, so spend tracking +// + header-resolved end-user budgets continue to work without body- +// level enforcement. +func injectIntoBody(in *middleware.Input, tags []string, userID string) ([]byte, bool) { + wantTags := len(tags) > 0 + wantUser := userID != "" + if !wantTags && !wantUser { + return nil, false + } + if in == nil || len(in.Body) == 0 || in.BodyTruncated { + return nil, false + } + var doc map[string]any + if err := json.Unmarshal(in.Body, &doc); err != nil { + return nil, false + } + injected := false + if wantTags { + var meta map[string]any + if existing, ok := doc["metadata"]; ok { + if typed, isObject := existing.(map[string]any); isObject { + meta = typed + } + // non-object metadata: leave it; tags go unwritten so we + // don't clobber the client's value. Header fallback covers + // spend tracking. + } else { + meta = map[string]any{} + } + if meta != nil { + meta["tags"] = tags + doc["metadata"] = meta + injected = true + } + } + if wantUser { + // Anti-spoof: overwrite any client-supplied "user" so the + // gateway only sees our trusted identity. + doc["user"] = userID + injected = true + } + if !injected { + return nil, false + } + out, err := json.Marshal(doc) + if err != nil { + return nil, false + } + return out, true +} + +// applyJSONMetadata builds the Portkey-style mutations: a single header +// carrying a JSON object keyed by the rule's reserved field names. Per- +// value byte length is capped at MaxValueLength when set (Portkey +// enforces 128 chars). +func applyJSONMetadata(rule *JSONMetadataRule, in *middleware.Input) *middleware.Mutations { + mutations := &middleware.Mutations{} + mutations.HeadersRemove = append(mutations.HeadersRemove, rule.Header) + + payload := map[string]string{} + if rule.UserKey != "" { + if identity := identityFor(in); identity != "" { + payload[rule.UserKey] = truncate(identity, rule.MaxValueLength) + } + } + if rule.GroupsKey != "" { + if csv := authorisingTagsCSV(in); csv != "" { + payload[rule.GroupsKey] = truncate(csv, rule.MaxValueLength) + } + } + if len(payload) == 0 { + return mutations + } + raw, err := json.Marshal(payload) + if err != nil { + return mutations + } + mutations.HeadersAdd = append(mutations.HeadersAdd, middleware.KV{ + Key: rule.Header, + Value: string(raw), + }) + return mutations +} + +// identityFor returns the caller's display identity. UserEmail wins +// (carries the user email when peer-attached, peer.Name otherwise); +// UserID falls in only as a defensive last resort. +func identityFor(in *middleware.Input) string { + if in.UserEmail != "" { + return in.UserEmail + } + return in.UserID +} + +// authorisingTagsSlice returns the sorted, deduplicated slice of group +// display names the request was authorised under. Prefers the per- +// request authorising groups emitted by llm_router (intersection of the +// caller's UserGroups with the resolved route's AllowedGroupIDs) so the +// tags carry only the groups that actually authorise THIS request, not +// every group the peer happens to be in. Falls back to the full +// UserGroups when the router metadata key is absent. +func authorisingTagsSlice(in *middleware.Input) []string { + ids := tagsIDsFromAuthorising(in.Metadata) + if len(ids) == 0 { + ids = in.UserGroups + } + return tagsNamedSlice(ids, in.UserGroups, in.UserGroupNames) +} + +// authorisingTagsCSV is a convenience wrapper that joins +// authorisingTagsSlice with commas for HeaderPair-style emission. +func authorisingTagsCSV(in *middleware.Input) string { + return strings.Join(authorisingTagsSlice(in), ",") +} + +// truncate caps s to maxBytes bytes when maxBytes > 0. No-op when +// maxBytes <= 0 or s already fits. Truncation is byte-wise — sufficient +// for Portkey's 128-char ASCII limit. UTF-8 sequences could in theory +// be split, but the gateway treats the value as opaque bytes. +func truncate(s string, maxBytes int) string { + if maxBytes <= 0 || len(s) <= maxBytes { + return s + } + return s[:maxBytes] +} + +// tagsIDsFromAuthorising reads llm_router's authorising-groups metadata +// (a CSV of group ids) and returns the parsed slice. Returns nil when +// the key is absent or empty so the caller can fall back to the full +// UserGroups. +func tagsIDsFromAuthorising(meta []middleware.KV) []string { + v, ok := lookupMetadata(meta, middleware.KeyLLMAuthorisingGroups) + if !ok { + return nil + } + v = strings.TrimSpace(v) + if v == "" { + return nil + } + parts := strings.Split(v, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + out = append(out, p) + } + } + if len(out) == 0 { + return nil + } + return out +} + +// tagsNamedSlice returns the sorted, deduplicated list of group display +// names. ids carries the canonical group identifiers to emit; +// userGroups + userGroupNames provide the positional id→name +// translation table from the Input envelope. When a name is missing +// for a given id (slice shorter than userGroups, or id absent from the +// table), the id is used verbatim so the tag still attributes +// correctly. Sorted so the same caller produces the same header value +// across requests (helps gateway-side cache hits and log correlation). +func tagsNamedSlice(ids, userGroups, userGroupNames []string) []string { + if len(ids) == 0 { + return nil + } + idToName := make(map[string]string, len(userGroups)) + for i, id := range userGroups { + if i < len(userGroupNames) { + idToName[id] = userGroupNames[i] + } + } + seen := make(map[string]struct{}, len(ids)) + out := make([]string, 0, len(ids)) + for _, id := range ids { + id = strings.TrimSpace(id) + if id == "" { + continue + } + tag := idToName[id] + if tag == "" { + tag = id + } + if _, dup := seen[tag]; dup { + continue + } + seen[tag] = struct{}{} + out = append(out, tag) + } + if len(out) == 0 { + return nil + } + sort.Strings(out) + return out +} + +// lookupMetadata returns the value for key plus a presence flag. +func lookupMetadata(meta []middleware.KV, key string) (string, bool) { + for _, kv := range meta { + if kv.Key == key { + return kv.Value, true + } + } + return "", false +} diff --git a/proxy/internal/middleware/builtin/llm_identity_inject/middleware_test.go b/proxy/internal/middleware/builtin/llm_identity_inject/middleware_test.go new file mode 100644 index 000000000..aab1271d8 --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_identity_inject/middleware_test.go @@ -0,0 +1,666 @@ +package llm_identity_inject + +import ( + "context" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +const ( + litellmProvider = "ainp_litellm-test" + portkeyProvider = "ainp_portkey-test" +) + +func newInput(resolvedProvider, userID string, groups []string) *middleware.Input { + return &middleware.Input{ + Slot: middleware.SlotOnRequest, + AccountID: "acct-test", + UserID: userID, + UserGroups: groups, + SourceIP: "100.64.0.5", + RequestID: "req-1", + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMResolvedProviderID, Value: resolvedProvider}, + }, + } +} + +func liteLLMRule() ProviderInjection { + return ProviderInjection{ + ProviderID: litellmProvider, + HeaderPair: &HeaderPairRule{ + EndUserIDHeader: "x-litellm-end-user-id", + TagsHeader: "x-litellm-tags", + }, + } +} + +func TestMiddlewareIdentity(t *testing.T) { + mw := New(Config{}) + assert.Equal(t, ID, mw.ID()) + assert.Equal(t, Version, mw.Version()) + assert.Equal(t, middleware.SlotOnRequest, mw.Slot()) + assert.True(t, mw.MutationsSupported()) + assert.Empty(t, mw.MetadataKeys(), "middleware emits no metadata") + assert.Nil(t, mw.AcceptedContentTypes()) + require.NoError(t, mw.Close()) +} + +func TestInject_MatchedProvider_StampsHeaders(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRule()}}) + in := newInput(litellmProvider, "alice", []string{"grp-eng", "grp-it"}) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision) + require.NotNil(t, out.Mutations) + + // Strips the same headers we're about to add (anti-spoof). + assert.ElementsMatch(t, + []string{"x-litellm-end-user-id", "x-litellm-tags"}, + out.Mutations.HeadersRemove, + "every injected header must also appear in HeadersRemove so client-supplied values are wiped before our trusted values land") + + added := map[string]string{} + for _, kv := range out.Mutations.HeadersAdd { + added[kv.Key] = kv.Value + } + assert.Equal(t, "alice", added["x-litellm-end-user-id"]) + assert.Equal(t, "grp-eng,grp-it", added["x-litellm-tags"], "tags CSV must be sorted") +} + +func TestInject_UnmatchedProvider_NoMutations(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRule()}}) + in := newInput("ainp_some-other-provider", "alice", []string{"grp-eng"}) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision) + assert.Nil(t, out.Mutations, "non-LiteLLM resolved provider must produce no mutations") +} + +func TestInject_NoResolvedProvider_NoMutations(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRule()}}) + in := &middleware.Input{Slot: middleware.SlotOnRequest, UserID: "alice"} + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out) + assert.Nil(t, out.Mutations, + "missing llm.resolved_provider_id metadata means the router didn't run; never stamp identity blindly") +} + +func TestInject_PartialRule_StampsOnlyConfiguredHeaders(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{{ + ProviderID: litellmProvider, + HeaderPair: &HeaderPairRule{ + EndUserIDHeader: "x-litellm-end-user-id", + // TagsHeader intentionally empty. + }, + }}}) + in := newInput(litellmProvider, "alice", []string{"grp-eng"}) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out) + require.NotNil(t, out.Mutations) + + assert.Equal(t, []string{"x-litellm-end-user-id"}, out.Mutations.HeadersRemove, + "only configured header should be stripped") + require.Len(t, out.Mutations.HeadersAdd, 1) + assert.Equal(t, "x-litellm-end-user-id", out.Mutations.HeadersAdd[0].Key) + assert.Equal(t, "alice", out.Mutations.HeadersAdd[0].Value) +} + +func TestInject_EmptyIdentity_StripsButDoesNotAdd(t *testing.T) { + // Caller has no UserID and no groups. We still strip the headers + // (so the client can't inject identity) but we don't add empty + // values that would mislead the gateway. + mw := New(Config{Providers: []ProviderInjection{liteLLMRule()}}) + in := newInput(litellmProvider, "", nil) + in.AccountID = "" + in.SourceIP = "" + in.RequestID = "" + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out) + require.NotNil(t, out.Mutations) + + assert.ElementsMatch(t, + []string{"x-litellm-end-user-id", "x-litellm-tags"}, + out.Mutations.HeadersRemove, + "identity headers must be stripped even when we don't have values to add — anti-spoof") + assert.Empty(t, out.Mutations.HeadersAdd, + "no NetBird identity available; do not stamp empty / misleading values") +} + +func TestInject_TagsCSV_DedupesAndSorts(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRule()}}) + in := newInput(litellmProvider, "alice", []string{"grp-zzz", "grp-aaa", "grp-zzz", "", " "}) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out) + require.NotNil(t, out.Mutations) + + for _, kv := range out.Mutations.HeadersAdd { + if kv.Key == "x-litellm-tags" { + assert.Equal(t, "grp-aaa,grp-zzz", kv.Value, + "tags CSV must dedupe, drop empty, and sort") + return + } + } + t.Fatalf("expected x-litellm-tags in HeadersAdd; got %v", out.Mutations.HeadersAdd) +} + +func TestFactory_RejectsBadJSON(t *testing.T) { + _, err := Factory{}.New([]byte("{not json")) + require.Error(t, err) +} + +func TestFactory_AcceptsEmptyShapes(t *testing.T) { + for _, raw := range [][]byte{nil, []byte(""), []byte(" "), []byte("null"), []byte("{}"), []byte("[]")} { + mw, err := Factory{}.New(raw) + require.NoError(t, err) + require.NotNil(t, mw) + + out, ierr := mw.Invoke(context.Background(), + newInput(litellmProvider, "alice", []string{"grp-eng"})) + require.NoError(t, ierr) + assert.Equal(t, middleware.DecisionAllow, out.Decision) + assert.Nil(t, out.Mutations, + "empty config means no providers to inject for; every resolved provider passes through") + } +} + +func TestFactory_DropsInjectionRuleWithEmptyHeaders(t *testing.T) { + mw, err := Factory{}.New([]byte(`{"providers":[{"provider_id":"x"}]}`)) + require.NoError(t, err) + out, ierr := mw.Invoke(context.Background(), newInput("x", "alice", []string{"grp-eng"})) + require.NoError(t, ierr) + assert.Nil(t, out.Mutations, + "a rule with no header names is functionally a no-op and must be dropped at New() time") +} + +// TestInject_TagsFromAuthorisingMetadata pins that when llm_router has +// emitted llm.authorising_groups, the inject middleware uses THAT +// (the per-request authorising intersection) for the tags header — not +// the full UserGroups, which can include groups unrelated to this +// request's routing. +func TestInject_TagsFromAuthorisingMetadata(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRule()}}) + in := newInput(litellmProvider, "alice", []string{"grp-eng", "grp-it", "grp-oncall"}) + in.Metadata = append(in.Metadata, middleware.KV{ + Key: middleware.KeyLLMAuthorisingGroups, + Value: "grp-eng", + }) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out) + require.NotNil(t, out.Mutations) + + for _, kv := range out.Mutations.HeadersAdd { + if kv.Key == "x-litellm-tags" { + assert.Equal(t, "grp-eng", kv.Value, + "tags must come from llm.authorising_groups, not the full UserGroups; unrelated peer groups must not leak") + return + } + } + t.Fatalf("expected x-litellm-tags in HeadersAdd; got %v", out.Mutations.HeadersAdd) +} + +// TestInject_TagsFallsBackToUserGroups pins the defensive fallback: if +// llm_router didn't emit authorising-groups metadata (chain +// misconfiguration) the middleware uses UserGroups so identity is +// still stamped, just over-broad. +func TestInject_TagsFallsBackToUserGroups(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRule()}}) + in := newInput(litellmProvider, "alice", []string{"grp-eng", "grp-it"}) + // No llm.authorising_groups metadata. + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out) + require.NotNil(t, out.Mutations) + + for _, kv := range out.Mutations.HeadersAdd { + if kv.Key == "x-litellm-tags" { + assert.Equal(t, "grp-eng,grp-it", kv.Value, + "absent metadata must fall back to the full UserGroups CSV") + return + } + } + t.Fatalf("expected x-litellm-tags in HeadersAdd; got %v", out.Mutations.HeadersAdd) +} + +// portkeyRule is the JSONMetadata-shape analogue of liteLLMRule: a +// single x-portkey-metadata header carrying _user and groups, with +// Portkey's 128-byte per-value cap. +func portkeyRule() ProviderInjection { + return ProviderInjection{ + ProviderID: portkeyProvider, + JSONMetadata: &JSONMetadataRule{ + Header: "x-portkey-metadata", + UserKey: "_user", + GroupsKey: "groups", + MaxValueLength: 128, + }, + } +} + +// TestInject_JSONMetadata_StampsHeader pins the Portkey-style emission: +// one header carrying a JSON envelope with reserved keys for user +// identity and groups CSV. +func TestInject_JSONMetadata_StampsHeader(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{portkeyRule()}}) + in := newInput(portkeyProvider, "alice", []string{"grp-eng", "grp-it"}) + in.UserEmail = "alice@example.com" + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out) + require.NotNil(t, out.Mutations) + + assert.Equal(t, []string{"x-portkey-metadata"}, out.Mutations.HeadersRemove, + "the JSON header must be stripped before we add our trusted value") + require.Len(t, out.Mutations.HeadersAdd, 1) + added := out.Mutations.HeadersAdd[0] + assert.Equal(t, "x-portkey-metadata", added.Key) + + var payload map[string]string + require.NoError(t, json.Unmarshal([]byte(added.Value), &payload)) + assert.Equal(t, "alice@example.com", payload["_user"], + "_user reserved key carries the display identity (UserEmail)") + assert.Equal(t, "grp-eng,grp-it", payload["groups"], + "groups key carries the sorted CSV of group display names") +} + +// TestInject_JSONMetadata_TruncatesValues pins the per-value byte cap. +// Portkey rejects metadata values longer than 128 chars; oversized +// values are truncated rather than failing the request. +func TestInject_JSONMetadata_TruncatesValues(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{portkeyRule()}}) + in := newInput(portkeyProvider, "alice", []string{"grp-eng"}) + in.UserEmail = strings.Repeat("a", 200) + "@example.com" + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out.Mutations) + require.Len(t, out.Mutations.HeadersAdd, 1) + + var payload map[string]string + require.NoError(t, json.Unmarshal([]byte(out.Mutations.HeadersAdd[0].Value), &payload)) + assert.Len(t, payload["_user"], 128, + "per-value byte length must be capped at MaxValueLength") +} + +// TestInject_JSONMetadata_EmptyIdentity_StripsButDoesNotAdd verifies the +// anti-spoof Remove still fires when there's nothing to stamp. +func TestInject_JSONMetadata_EmptyIdentity_StripsButDoesNotAdd(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{portkeyRule()}}) + in := newInput(portkeyProvider, "", nil) + in.UserEmail = "" + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out.Mutations) + + assert.Equal(t, []string{"x-portkey-metadata"}, out.Mutations.HeadersRemove, + "strip even with no payload — client can't smuggle identity headers") + assert.Empty(t, out.Mutations.HeadersAdd, + "no NetBird identity available; do not stamp empty / misleading values") +} + +// TestFactory_RejectsRuleWithBothShapes pins the configuration-error +// guard: a rule that sets both HeaderPair and JSONMetadata is dropped +// at New() time rather than guessing which wins. +func TestFactory_RejectsRuleWithBothShapes(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{{ + ProviderID: litellmProvider, + HeaderPair: &HeaderPairRule{ + EndUserIDHeader: "x-litellm-end-user-id", + }, + JSONMetadata: &JSONMetadataRule{ + Header: "x-portkey-metadata", + UserKey: "_user", + }, + }}}) + in := newInput(litellmProvider, "alice", []string{"grp-eng"}) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + assert.Nil(t, out.Mutations, + "a rule that sets both shapes is ambiguous and must be dropped at New() time") +} + +// liteLLMRuleWithBody is the LiteLLM-style rule with body tag injection +// enabled (matches the catalog default). +func liteLLMRuleWithBody() ProviderInjection { + return ProviderInjection{ + ProviderID: litellmProvider, + HeaderPair: &HeaderPairRule{ + EndUserIDHeader: "x-litellm-end-user-id", + TagsHeader: "x-litellm-tags", + TagsInBody: true, + }, + } +} + +// TestInject_BodyTags_AddsMetadataTags pins the body-inject path that +// LiteLLM's _tag_max_budget_check requires. With TagsInBody set, the +// middleware writes the authorising-groups slice into +// request.metadata.tags (in addition to the header). +func TestInject_BodyTags_AddsMetadataTags(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRuleWithBody()}}) + in := newInput(litellmProvider, "alice", []string{"grp-eng", "grp-sre"}) + in.Body = []byte(`{"model":"gpt-4o-mini","messages":[]}`) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out.Mutations) + require.NotEmpty(t, out.Mutations.BodyReplace, "body must be rewritten when TagsInBody is set") + + var doc map[string]any + require.NoError(t, json.Unmarshal(out.Mutations.BodyReplace, &doc)) + meta, ok := doc["metadata"].(map[string]any) + require.True(t, ok, "metadata must be an object") + tags, ok := meta["tags"].([]any) + require.True(t, ok, "metadata.tags must be a JSON array") + got := make([]string, 0, len(tags)) + for _, t := range tags { + s, _ := t.(string) + got = append(got, s) + } + assert.Equal(t, []string{"grp-eng", "grp-sre"}, got, + "metadata.tags must carry the sorted authorising-groups slice") + assert.Equal(t, "gpt-4o-mini", doc["model"], + "the rest of the body must be preserved verbatim") +} + +// TestInject_BodyTags_PreservesExistingMetadata pins that an existing +// metadata object on the request is merged with our tags rather than +// clobbered — clients sometimes set metadata fields the proxy +// shouldn't blow away (jobID, taskName, etc.). +func TestInject_BodyTags_PreservesExistingMetadata(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRuleWithBody()}}) + in := newInput(litellmProvider, "alice", []string{"grp-eng"}) + in.Body = []byte(`{"model":"gpt-4o-mini","metadata":{"jobID":"j-42","tags":["should-be-replaced"]}}`) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotEmpty(t, out.Mutations.BodyReplace) + + var doc map[string]any + require.NoError(t, json.Unmarshal(out.Mutations.BodyReplace, &doc)) + meta := doc["metadata"].(map[string]any) + assert.Equal(t, "j-42", meta["jobID"], + "client-supplied metadata fields outside `tags` must survive") + tags := meta["tags"].([]any) + require.Len(t, tags, 1) + assert.Equal(t, "grp-eng", tags[0], + "our tags overwrite any client-supplied metadata.tags so spoofing is impossible") +} + +// TestInject_BodyTags_SkipsHostileMetadataShape pins the defensive +// refusal: when the request body has a non-object metadata field +// (string/number/array), we don't inject — header path still emits. +func TestInject_BodyTags_SkipsHostileMetadataShape(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRuleWithBody()}}) + in := newInput(litellmProvider, "alice", []string{"grp-eng"}) + in.Body = []byte(`{"model":"gpt-4o-mini","metadata":"not-an-object"}`) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out.Mutations) + assert.Empty(t, out.Mutations.BodyReplace, + "non-object metadata must skip body inject (don't clobber)") + + for _, kv := range out.Mutations.HeadersAdd { + if kv.Key == "x-litellm-tags" { + assert.Equal(t, "grp-eng", kv.Value, + "header path must still emit so spend tracking keeps working") + return + } + } + t.Fatalf("expected x-litellm-tags header even when body inject was skipped") +} + +// TestInject_BodyTags_SkipsTruncatedBody pins that we don't blindly +// rewrite a body we don't have in full. The header path still runs. +func TestInject_BodyTags_SkipsTruncatedBody(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRuleWithBody()}}) + in := newInput(litellmProvider, "alice", []string{"grp-eng"}) + in.Body = []byte(`{"model":"gpt-4o-mini","messages":[]}`) + in.BodyTruncated = true + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + assert.Empty(t, out.Mutations.BodyReplace, + "truncated body must skip body inject — re-marshaling would corrupt the request") +} + +// TestInject_BodyTags_SkipsNonJSONBody pins graceful behavior when the +// body isn't JSON (e.g. a streaming binary or form upload sneaking +// through the LLM chain). Header path still runs. +func TestInject_BodyTags_SkipsNonJSONBody(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRuleWithBody()}}) + in := newInput(litellmProvider, "alice", []string{"grp-eng"}) + in.Body = []byte(`not even close to json`) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + assert.Empty(t, out.Mutations.BodyReplace, + "non-JSON body must skip body inject silently") +} + +// liteLLMRuleFull mirrors the catalog default: header path + body +// metadata.tags (groups) + body user (end-user id). +func liteLLMRuleFull() ProviderInjection { + return ProviderInjection{ + ProviderID: litellmProvider, + HeaderPair: &HeaderPairRule{ + EndUserIDHeader: "x-litellm-end-user-id", + TagsHeader: "x-litellm-tags", + TagsInBody: true, + EndUserIDInBody: true, + }, + } +} + +// TestInject_BodyUser_WritesTopLevelUser pins the EndUserIDInBody path +// alone: body's top-level "user" field carries the display identity. +// Tags-in-body is OFF here so we isolate the user write. +func TestInject_BodyUser_WritesTopLevelUser(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{{ + ProviderID: litellmProvider, + HeaderPair: &HeaderPairRule{ + EndUserIDHeader: "x-litellm-end-user-id", + EndUserIDInBody: true, + }, + }}}) + in := newInput(litellmProvider, "alice", nil) + in.UserEmail = "alice@example.com" + in.Body = []byte(`{"model":"gpt-4o-mini","messages":[]}`) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out.Mutations) + require.NotEmpty(t, out.Mutations.BodyReplace) + + var doc map[string]any + require.NoError(t, json.Unmarshal(out.Mutations.BodyReplace, &doc)) + assert.Equal(t, "alice@example.com", doc["user"], + "body's top-level user field must carry the display identity") + _, hasMeta := doc["metadata"] + assert.False(t, hasMeta, "TagsInBody is off; metadata must not be added") +} + +// TestInject_BodyUser_OverwritesClientSupplied pins anti-spoof: a +// client-supplied "user" in the body is overwritten so the gateway +// only sees our trusted identity. +func TestInject_BodyUser_OverwritesClientSupplied(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRuleFull()}}) + in := newInput(litellmProvider, "alice", []string{"grp-eng"}) + in.UserEmail = "alice@example.com" + in.Body = []byte(`{"model":"gpt-4o-mini","user":"ceo@company.com"}`) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotEmpty(t, out.Mutations.BodyReplace) + + var doc map[string]any + require.NoError(t, json.Unmarshal(out.Mutations.BodyReplace, &doc)) + assert.Equal(t, "alice@example.com", doc["user"], + "client-supplied user must be overwritten with the trusted identity") +} + +// TestInject_BodyCombined_TagsAndUser pins that with both flags on, +// the body carries both metadata.tags AND top-level user, and the +// header path still emits. +func TestInject_BodyCombined_TagsAndUser(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRuleFull()}}) + in := newInput(litellmProvider, "alice", []string{"grp-eng", "grp-sre"}) + in.UserEmail = "alice@example.com" + in.Body = []byte(`{"model":"gpt-4o-mini"}`) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotEmpty(t, out.Mutations.BodyReplace) + + var doc map[string]any + require.NoError(t, json.Unmarshal(out.Mutations.BodyReplace, &doc)) + assert.Equal(t, "alice@example.com", doc["user"]) + meta := doc["metadata"].(map[string]any) + tags := meta["tags"].([]any) + require.Len(t, tags, 2) + assert.Equal(t, "grp-eng", tags[0]) + assert.Equal(t, "grp-sre", tags[1]) + + // Header path still emits — header end-user-id is the primary + // path for LiteLLM's resolver, body is defense-in-depth. + added := map[string]string{} + for _, kv := range out.Mutations.HeadersAdd { + added[kv.Key] = kv.Value + } + assert.Equal(t, "alice@example.com", added["x-litellm-end-user-id"]) + assert.Equal(t, "grp-eng,grp-sre", added["x-litellm-tags"]) +} + +// TestInject_BodyCombined_HostileMetadataKeepsUser pins the partial- +// success path: a hostile (non-object) metadata field skips the tag +// write but still allows the orthogonal user write to land. +func TestInject_BodyCombined_HostileMetadataKeepsUser(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{liteLLMRuleFull()}}) + in := newInput(litellmProvider, "alice", []string{"grp-eng"}) + in.UserEmail = "alice@example.com" + in.Body = []byte(`{"model":"gpt-4o-mini","metadata":"not-an-object"}`) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotEmpty(t, out.Mutations.BodyReplace, + "user write must still go through even when metadata is hostile") + + var doc map[string]any + require.NoError(t, json.Unmarshal(out.Mutations.BodyReplace, &doc)) + assert.Equal(t, "alice@example.com", doc["user"]) + assert.Equal(t, "not-an-object", doc["metadata"], + "hostile metadata must be left untouched, not clobbered") +} + +// TestInject_ExtraHeaders_Stamped pins the extras path: with a +// per-provider ExtraHeader configured (e.g. Portkey config id), the +// middleware stamps it on every matching request and adds the same +// name to HeadersRemove for anti-spoof. +func TestInject_ExtraHeaders_Stamped(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{{ + ProviderID: portkeyProvider, + JSONMetadata: &JSONMetadataRule{ + Header: "x-portkey-metadata", + UserKey: "_user", + GroupsKey: "groups", + }, + ExtraHeaders: []ExtraHeaderKV{ + {Name: "x-portkey-config", Value: "pc-prod-3f2a"}, + }, + }}}) + in := newInput(portkeyProvider, "alice", []string{"grp-eng"}) + in.UserEmail = "alice@example.com" + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out.Mutations) + + assert.Contains(t, out.Mutations.HeadersRemove, "x-portkey-config", + "extras must be stripped before stamping for anti-spoof") + added := map[string]string{} + for _, kv := range out.Mutations.HeadersAdd { + added[kv.Key] = kv.Value + } + assert.Equal(t, "pc-prod-3f2a", added["x-portkey-config"], + "extras must carry the operator-configured value verbatim") + // Identity-stamping shape (JSONMetadata header) still emitted. + assert.Contains(t, added, "x-portkey-metadata", + "extras and identity stamping are independent — both must land") +} + +// TestInject_ExtraHeaders_OnlyRule pins that an extras-only rule +// (no HeaderPair, no JSONMetadata) survives New() and stamps the +// extras anyway. Useful for hypothetical gateways that need a static +// routing header but no NetBird identity stamping. +func TestInject_ExtraHeaders_OnlyRule(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{{ + ProviderID: "ainp_extras-only", + ExtraHeaders: []ExtraHeaderKV{ + {Name: "x-routing-key", Value: "rk-1"}, + }, + }}}) + in := newInput("ainp_extras-only", "alice", nil) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out.Mutations, + "extras alone keep the rule alive — middleware must emit them") + added := map[string]string{} + for _, kv := range out.Mutations.HeadersAdd { + added[kv.Key] = kv.Value + } + assert.Equal(t, "rk-1", added["x-routing-key"]) +} + +// TestInject_ExtraHeaders_EmptyValueSkipped pins that empty values are +// dropped silently (the synth would normally not send them, but the +// middleware is defensive). +func TestInject_ExtraHeaders_EmptyValueSkipped(t *testing.T) { + mw := New(Config{Providers: []ProviderInjection{{ + ProviderID: portkeyProvider, + JSONMetadata: &JSONMetadataRule{ + Header: "x-portkey-metadata", + UserKey: "_user", + }, + ExtraHeaders: []ExtraHeaderKV{ + {Name: "x-portkey-config", Value: ""}, + }, + }}}) + in := newInput(portkeyProvider, "alice", []string{"grp-eng"}) + in.UserEmail = "alice@example.com" + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out.Mutations) + assert.NotContains(t, out.Mutations.HeadersRemove, "x-portkey-config", + "empty extra value must not even strip the header") + for _, kv := range out.Mutations.HeadersAdd { + assert.NotEqual(t, "x-portkey-config", kv.Key, + "empty extra value must not be stamped") + } +} diff --git a/proxy/internal/middleware/builtin/llm_limit_check/factory.go b/proxy/internal/middleware/builtin/llm_limit_check/factory.go new file mode 100644 index 000000000..1068a6867 --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_limit_check/factory.go @@ -0,0 +1,38 @@ +// Package llm_limit_check is the SlotOnRequest middleware that asks +// management which agent-network policy "pays" for the current LLM +// request. On allow, it stamps the selected policy id, attribution +// group id, and effective window length onto the metadata bag so the +// post-flight llm_limit_record middleware can tick the right counters. +// On deny, it returns a 403 carrying the canonical llm_policy.* deny +// code surfaced by management. +package llm_limit_check + +import ( + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" +) + +// ID is the registry identifier for this middleware. +const ID = "llm_limit_check" + +// Factory builds a configured llm_limit_check instance. The factory +// has no per-target config — it pulls the management gRPC client from +// the package-level FactoryContext at construction time. A nil +// MgmtClient on the context is allowed; the middleware then becomes +// a no-op pass-through (allow without attribution) so a partially +// wired environment doesn't break the chain. +type Factory struct{} + +// ID returns the registry identifier matching the middleware ID. +func (Factory) ID() string { return ID } + +// New ignores the rawConfig payload (no per-target config today) and +// returns a Middleware bound to the FactoryContext's MgmtClient. +func (Factory) New(_ []byte) (middleware.Middleware, error) { + ctx := builtin.Context() + return New(ctx.MgmtClient, ctx.Logger), nil +} + +func init() { + builtin.Register(Factory{}) +} diff --git a/proxy/internal/middleware/builtin/llm_limit_check/middleware.go b/proxy/internal/middleware/builtin/llm_limit_check/middleware.go new file mode 100644 index 000000000..bebe4dca4 --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_limit_check/middleware.go @@ -0,0 +1,196 @@ +package llm_limit_check + +import ( + "context" + "strconv" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// Version is reported via Middleware.Version(). +const Version = "1.0.0" + +// callTimeout caps the wall-clock budget for the pre-flight RPC. The +// middleware sits on the request leg, so a slow management call +// translates directly to user-visible latency. 2s is loose enough for +// a healthy management cluster but tight enough that a stalled call +// fails open via the same path nil-MgmtClient does — an enforcement +// gate that adds 30s of latency is worse than a stale gate. +const callTimeout = 2 * time.Second + +// Middleware is the per-target instance that runs the pre-flight check. +type Middleware struct { + mgmt builtin.MgmtClient + logger *log.Logger +} + +// New constructs a Middleware. mgmt may be nil — that's the +// no-management-wired case where the middleware is a pass-through +// (allow without attribution); useful for unit tests and for +// progressive rollout of the management RPC. +func New(mgmt builtin.MgmtClient, logger *log.Logger) *Middleware { + if logger == nil { + logger = log.StandardLogger() + } + return &Middleware{mgmt: mgmt, logger: logger} +} + +// ID returns the registry identifier. +func (m *Middleware) ID() string { return ID } + +// Version returns the implementation version. +func (m *Middleware) Version() string { return Version } + +// Slot reports the chain slot the middleware lives in. +func (m *Middleware) Slot() middleware.Slot { return middleware.SlotOnRequest } + +// AcceptedContentTypes returns nil because the gate consults metadata +// emitted upstream (KeyLLMResolvedProviderID) and never inspects bodies. +func (m *Middleware) AcceptedContentTypes() []string { return nil } + +// MetadataKeys is the closed allowlist of keys this middleware emits. +func (m *Middleware) MetadataKeys() []string { + return []string{ + middleware.KeyLLMSelectedPolicyID, + middleware.KeyLLMAttributionGroupID, + middleware.KeyLLMAttributionWindowS, + middleware.KeyLLMPolicyDecision, + middleware.KeyLLMPolicyReason, + } +} + +// MutationsSupported reports that the middleware never mutates the +// request body or headers; the only outcome is allow + metadata or +// deny. +func (m *Middleware) MutationsSupported() bool { return false } + +// Close releases resources owned by the middleware. Stateless, so +// this is a no-op. +func (m *Middleware) Close() error { return nil } + +// Invoke runs the pre-flight policy check. +func (m *Middleware) Invoke(ctx context.Context, in *middleware.Input) (*middleware.Output, error) { + if m.mgmt == nil { + // No management client wired — fall through to allow with + // no attribution. RecordLLMUsage on the response leg will + // also be a no-op so counters stay at zero. This matches + // the PR1 behaviour exactly so a partial wiring is + // indistinguishable from "no enforcement". + return allowNoAttribution(), nil + } + + providerID := lookupKV(in.Metadata, middleware.KeyLLMResolvedProviderID) + if providerID == "" { + // llm_router didn't emit a resolved provider id — usually + // because the request didn't carry an llm.model. The + // router itself denied; we won't reach here in production, + // but defensively pass through so we never deny on top of + // an upstream allow. + return allowNoAttribution(), nil + } + + rpcCtx, cancel := context.WithTimeout(ctx, callTimeout) + defer cancel() + + resp, err := m.mgmt.CheckLLMPolicyLimits(rpcCtx, &proto.CheckLLMPolicyLimitsRequest{ + AccountId: in.AccountID, + UserId: in.UserID, + GroupIds: append([]string(nil), in.UserGroups...), + ProviderId: providerID, + Model: lookupKV(in.Metadata, middleware.KeyLLMModel), + }) + if err != nil { + // Fail-open on transport / management errors. The + // alternative — denying every request when management is + // unreachable — is worse for v1 (operational outage = + // total LLM outage). Operators can audit via the + // access-log; PR3 can switch to fail-closed under a flag. + m.logger.WithError(err). + WithField("middleware", ID). + Debugf("management pre-flight failed; failing open") + return allowNoAttribution(), nil + } + + if resp.GetDecision() == "deny" { + return denyFromManagement(resp), nil + } + return allowFromManagement(resp), nil +} + +// allowNoAttribution returns the no-op allow envelope used when no +// management client is wired or no provider was resolved. Stamps +// decision=allow but no policy / attribution metadata so +// llm_limit_record skips its post-flight write. +func allowNoAttribution() *middleware.Output { + return &middleware.Output{ + Decision: middleware.DecisionAllow, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMPolicyDecision, Value: "allow"}, + }, + } +} + +// allowFromManagement converts a successful CheckLLMPolicyLimits +// response into the chain's allow envelope, stamping the attribution +// metadata the response leg consumes. +func allowFromManagement(resp *proto.CheckLLMPolicyLimitsResponse) *middleware.Output { + out := &middleware.Output{ + Decision: middleware.DecisionAllow, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMPolicyDecision, Value: "allow"}, + }, + } + if id := resp.GetSelectedPolicyId(); id != "" { + out.Metadata = append(out.Metadata, middleware.KV{Key: middleware.KeyLLMSelectedPolicyID, Value: id}) + } + if g := resp.GetAttributionGroupId(); g != "" { + out.Metadata = append(out.Metadata, middleware.KV{Key: middleware.KeyLLMAttributionGroupID, Value: g}) + } + if w := resp.GetWindowSeconds(); w > 0 { + out.Metadata = append(out.Metadata, middleware.KV{Key: middleware.KeyLLMAttributionWindowS, Value: strconv.FormatInt(w, 10)}) + } + return out +} + +// denyFromManagement converts a deny response into the chain's deny +// envelope. The deny code surfaces verbatim through the framework's +// fixed JSON template; arbitrary middleware bytes can't reach the +// wire. +func denyFromManagement(resp *proto.CheckLLMPolicyLimitsResponse) *middleware.Output { + code := resp.GetDenyCode() + if code == "" { + code = "llm_policy.cap_exceeded" + } + // The canonical code is safe to surface; the management-supplied + // reason can name internal quota details (used amounts, caps, rule + // ids), so keep the public message generic and leave the detail to + // server-side logs. + return &middleware.Output{ + Decision: middleware.DecisionDeny, + DenyStatus: 403, + DenyReason: &middleware.DenyReason{ + Code: code, + Message: "LLM policy limit exceeded", + }, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMPolicyDecision, Value: "deny"}, + {Key: middleware.KeyLLMPolicyReason, Value: code}, + }, + } +} + +// lookupKV returns the value associated with key, or the empty +// string when absent. +func lookupKV(kvs []middleware.KV, key string) string { + for _, kv := range kvs { + if kv.Key == key { + return kv.Value + } + } + return "" +} diff --git a/proxy/internal/middleware/builtin/llm_limit_check/middleware_test.go b/proxy/internal/middleware/builtin/llm_limit_check/middleware_test.go new file mode 100644 index 000000000..2c26c2abe --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_limit_check/middleware_test.go @@ -0,0 +1,186 @@ +package llm_limit_check + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// fakeMgmt is a minimal builtin.MgmtClient stub that lets the test +// drive CheckLLMPolicyLimits responses without a real gRPC dial. +type fakeMgmt struct { + checkResp *proto.CheckLLMPolicyLimitsResponse + checkErr error + checkReq *proto.CheckLLMPolicyLimitsRequest +} + +func (f *fakeMgmt) CheckLLMPolicyLimits(_ context.Context, in *proto.CheckLLMPolicyLimitsRequest, _ ...grpc.CallOption) (*proto.CheckLLMPolicyLimitsResponse, error) { + f.checkReq = in + return f.checkResp, f.checkErr +} + +func (f *fakeMgmt) RecordLLMUsage(_ context.Context, _ *proto.RecordLLMUsageRequest, _ ...grpc.CallOption) (*proto.RecordLLMUsageResponse, error) { + return &proto.RecordLLMUsageResponse{}, nil +} + +func runInvoke(t *testing.T, m *Middleware, in *middleware.Input) *middleware.Output { + t.Helper() + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "Invoke must not propagate transport errors") + require.NotNil(t, out, "Invoke must always return an Output") + return out +} + +// TestInvoke_AllowStampsAttributionMetadata covers the happy path: +// management returns an allow decision with selected_policy_id + +// attribution_group_id + window_seconds, the middleware emits all three +// onto the metadata bag so the post-flight llm_limit_record +// middleware has everything it needs to tick the right counter. +func TestInvoke_AllowStampsAttributionMetadata(t *testing.T) { + mgmt := &fakeMgmt{ + checkResp: &proto.CheckLLMPolicyLimitsResponse{ + Decision: "allow", + SelectedPolicyId: "pol-X", + AttributionGroupId: "grp-engineers", + WindowSeconds: 86_400, + }, + } + m := New(mgmt, nil) + + out := runInvoke(t, m, &middleware.Input{ + AccountID: "acc-1", + UserID: "user-bob", + UserGroups: []string{"grp-engineers"}, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMResolvedProviderID, Value: "prov-1"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + }, + }) + + assert.Equal(t, middleware.DecisionAllow, out.Decision) + assert.Equal(t, "acc-1", mgmt.checkReq.GetAccountId(), "account_id must round-trip onto the RPC") + assert.Equal(t, "user-bob", mgmt.checkReq.GetUserId()) + assert.Equal(t, []string{"grp-engineers"}, mgmt.checkReq.GetGroupIds()) + assert.Equal(t, "prov-1", mgmt.checkReq.GetProviderId(), "resolved provider id must come from metadata") + assert.Equal(t, "gpt-4o", mgmt.checkReq.GetModel(), "model must come from metadata") + + want := map[string]string{ + middleware.KeyLLMPolicyDecision: "allow", + middleware.KeyLLMSelectedPolicyID: "pol-X", + middleware.KeyLLMAttributionGroupID: "grp-engineers", + middleware.KeyLLMAttributionWindowS: "86400", + } + got := map[string]string{} + for _, kv := range out.Metadata { + got[kv.Key] = kv.Value + } + assert.Equal(t, want, got, "attribution metadata must land on the bag for the response leg to consume") +} + +// TestInvoke_DenyConvertsToProxyDeny proves the deny envelope round- +// trips: management's deny code becomes the proxy framework's deny +// payload at status 403, and the deny reason text is preserved so +// operators can debug from the access log. +func TestInvoke_DenyConvertsToProxyDeny(t *testing.T) { + mgmt := &fakeMgmt{ + checkResp: &proto.CheckLLMPolicyLimitsResponse{ + Decision: "deny", + DenyCode: "llm_policy.token_cap_exceeded", + DenyReason: "group token cap exhausted on policy pol-X (used 1000 of 1000)", + }, + } + m := New(mgmt, nil) + + out := runInvoke(t, m, &middleware.Input{ + AccountID: "acc-1", + UserGroups: []string{"grp-engineers"}, + Metadata: []middleware.KV{{Key: middleware.KeyLLMResolvedProviderID, Value: "prov-1"}}, + }) + + assert.Equal(t, middleware.DecisionDeny, out.Decision) + assert.Equal(t, 403, out.DenyStatus, "policy denials are 403 — same as llm_router's") + require.NotNil(t, out.DenyReason, "deny envelope must carry a reason payload") + assert.Equal(t, "llm_policy.token_cap_exceeded", out.DenyReason.Code, "canonical deny code surfaces to the caller") + // The public message must stay generic: the management reason names + // internal quota detail (used/cap, rule id) that must not leak. + assert.Equal(t, "LLM policy limit exceeded", out.DenyReason.Message, "public deny message must be generic") + assert.NotContains(t, out.DenyReason.Message, "exhausted", "internal quota detail must not reach the caller") + assert.NotContains(t, out.DenyReason.Message, "1000", "internal cap numbers must not reach the caller") +} + +// TestInvoke_NoMgmtClientPassesThrough proves the partial-wiring +// safety: a middleware constructed without a management client +// allows every request without attribution. This makes a half-set-up +// environment indistinguishable from "no enforcement" rather than +// breaking the chain. +func TestInvoke_NoMgmtClientPassesThrough(t *testing.T) { + m := New(nil, nil) + + out := runInvoke(t, m, &middleware.Input{ + AccountID: "acc-1", + UserGroups: []string{"grp-engineers"}, + Metadata: []middleware.KV{{Key: middleware.KeyLLMResolvedProviderID, Value: "prov-1"}}, + }) + + assert.Equal(t, middleware.DecisionAllow, out.Decision) + for _, kv := range out.Metadata { + assert.NotEqual(t, middleware.KeyLLMSelectedPolicyID, kv.Key, + "no mgmt client = no attribution metadata; record middleware then skips its write") + } +} + +// TestInvoke_NoResolvedProviderPassesThrough covers the defensive +// path: when llm_router didn't set llm.resolved_provider_id (which +// only happens on the deny side of llm_router), the gate must NOT +// stack a second deny on top — pass through and let the upstream +// deny stand. +func TestInvoke_NoResolvedProviderPassesThrough(t *testing.T) { + m := New(&fakeMgmt{}, nil) + + out := runInvoke(t, m, &middleware.Input{ + AccountID: "acc-1", + Metadata: []middleware.KV{}, + }) + + assert.Equal(t, middleware.DecisionAllow, out.Decision, + "no resolved provider = the gate has nothing to check; never deny on top of an upstream allow") +} + +// TestInvoke_RPCErrorFailsOpen proves the fail-open contract: a +// transport error from management does NOT deny the request. v1 +// trades enforcement strictness for availability — an unreachable +// management server otherwise turns into a total LLM outage. +func TestInvoke_RPCErrorFailsOpen(t *testing.T) { + m := New(&fakeMgmt{checkErr: errors.New("connection refused")}, nil) + + out := runInvoke(t, m, &middleware.Input{ + AccountID: "acc-1", + UserGroups: []string{"grp-engineers"}, + Metadata: []middleware.KV{{Key: middleware.KeyLLMResolvedProviderID, Value: "prov-1"}}, + }) + + assert.Equal(t, middleware.DecisionAllow, out.Decision, + "transport errors must not cascade into total LLM outages — operators audit via access log") +} + +// TestMetadataKeys_Allowlist locks the closed set this middleware can +// emit. The accumulator drops anything outside this list; adding a +// new emission means updating both the slice and this test. +func TestMetadataKeys_Allowlist(t *testing.T) { + keys := New(nil, nil).MetadataKeys() + want := []string{ + middleware.KeyLLMSelectedPolicyID, + middleware.KeyLLMAttributionGroupID, + middleware.KeyLLMAttributionWindowS, + middleware.KeyLLMPolicyDecision, + middleware.KeyLLMPolicyReason, + } + assert.ElementsMatch(t, want, keys) +} diff --git a/proxy/internal/middleware/builtin/llm_limit_record/factory.go b/proxy/internal/middleware/builtin/llm_limit_record/factory.go new file mode 100644 index 000000000..b42931c0c --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_limit_record/factory.go @@ -0,0 +1,35 @@ +// Package llm_limit_record is the SlotOnResponse middleware that +// posts the served request's token + cost deltas back to management +// so the per-(user, group, window) consumption counters tick. Reads +// the attribution metadata stamped by llm_limit_check on the request +// leg + the token / cost metadata stamped by llm_response_parser and +// cost_meter; skips the write entirely when no attribution metadata +// is present (e.g. catch-all-allow policy with no caps configured). +package llm_limit_record + +import ( + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" +) + +// ID is the registry identifier for this middleware. +const ID = "llm_limit_record" + +// Factory builds a configured llm_limit_record instance bound to the +// FactoryContext's MgmtClient. nil-MgmtClient disables the post-flight +// write entirely (no-op pass-through), matching the request-leg gate's +// behaviour so a partially wired environment is consistent. +type Factory struct{} + +// ID returns the registry identifier matching the middleware ID. +func (Factory) ID() string { return ID } + +// New ignores the rawConfig payload (no per-target config today). +func (Factory) New(_ []byte) (middleware.Middleware, error) { + ctx := builtin.Context() + return New(ctx.MgmtClient, ctx.Logger), nil +} + +func init() { + builtin.Register(Factory{}) +} diff --git a/proxy/internal/middleware/builtin/llm_limit_record/middleware.go b/proxy/internal/middleware/builtin/llm_limit_record/middleware.go new file mode 100644 index 000000000..52dfc73f0 --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_limit_record/middleware.go @@ -0,0 +1,144 @@ +package llm_limit_record + +import ( + "context" + "strconv" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// Version is reported via Middleware.Version(). +const Version = "1.0.0" + +// callTimeout caps the wall-clock budget for the post-flight RPC. +// Longer than the pre-flight gate because this runs after the +// upstream returned and is not on the user-facing latency path — +// a slow record is just a delayed counter increment, not a delayed +// response to the caller. +const callTimeout = 5 * time.Second + +// Middleware posts token + cost deltas to management after a served +// request. Stateless; per-call values come entirely from metadata +// emitted upstream. +type Middleware struct { + mgmt builtin.MgmtClient + logger *log.Logger +} + +// New constructs a Middleware bound to the supplied management +// client. mgmt may be nil — that disables the write entirely so a +// partially wired environment doesn't attempt to dial nothing. +func New(mgmt builtin.MgmtClient, logger *log.Logger) *Middleware { + if logger == nil { + logger = log.StandardLogger() + } + return &Middleware{mgmt: mgmt, logger: logger} +} + +// ID returns the registry identifier. +func (m *Middleware) ID() string { return ID } + +// Version returns the implementation version. +func (m *Middleware) Version() string { return Version } + +// Slot reports that the middleware runs after the upstream call. +func (m *Middleware) Slot() middleware.Slot { return middleware.SlotOnResponse } + +// AcceptedContentTypes is empty: this middleware never inspects +// bodies. It only reads metadata emitted upstream. +func (m *Middleware) AcceptedContentTypes() []string { return []string{} } + +// MetadataKeys is empty — the record middleware never emits its own +// metadata. Its only side effect is the gRPC write to management. +func (m *Middleware) MetadataKeys() []string { return []string{} } + +// MutationsSupported reports that the middleware never mutates the +// response. Its outcome is always Allow. +func (m *Middleware) MutationsSupported() bool { return false } + +// Close releases resources owned by the middleware. Stateless. +func (m *Middleware) Close() error { return nil } + +// Invoke reads the attribution + tokens + cost metadata, calls +// management's RecordLLMUsage, and always returns Allow. RPC errors +// are logged at debug level — the response has already been served +// to the client by the time we get here, so a record failure must +// not surface back through the proxy. +func (m *Middleware) Invoke(ctx context.Context, in *middleware.Input) (*middleware.Output, error) { + out := &middleware.Output{Decision: middleware.DecisionAllow} + if m.mgmt == nil { + return out, nil + } + + tokensIn, _ := strconv.ParseInt(lookupKV(in.Metadata, middleware.KeyLLMInputTokens), 10, 64) + tokensOut, _ := strconv.ParseInt(lookupKV(in.Metadata, middleware.KeyLLMOutputTokens), 10, 64) + costUSD, _ := strconv.ParseFloat(lookupKV(in.Metadata, middleware.KeyCostUSDTotal), 64) + if tokensIn == 0 && tokensOut == 0 && costUSD == 0 { + // llm_response_parser couldn't read usage off the upstream + // response (streaming-not-yet-supported, malformed body, …). + // Skipping the write keeps phantom rows out of the + // consumption table. + return out, nil + } + + windowStr := lookupKV(in.Metadata, middleware.KeyLLMAttributionWindowS) + windowSeconds, _ := strconv.ParseInt(windowStr, 10, 64) + groupID := lookupKV(in.Metadata, middleware.KeyLLMAttributionGroupID) + + // A zero attribution window means no policy cap bound this request (deny at + // the gate, or a catch-all-allow policy). We still record so account-level + // budget rules — which live in their own windows and bind independently of + // policies — accumulate. The management side books the policy dimensions + // only when window_seconds > 0 and fans out to account rules regardless. + if in.UserID == "" && groupID == "" && len(in.UserGroups) == 0 { + m.logger.WithField("middleware", ID). + WithField("account_id", in.AccountID). + Debugf("post-flight skipped: no user/group/groups to attribute (tokens=%d/%d cost=%g window=%d)", tokensIn, tokensOut, costUSD, windowSeconds) + return out, nil + } + + rpcCtx, cancel := context.WithTimeout(ctx, callTimeout) + defer cancel() + + m.logger.WithField("middleware", ID). + WithField("account_id", in.AccountID). + WithField("user_id", in.UserID). + WithField("group_id", groupID). + WithField("group_ids_len", len(in.UserGroups)). + Debugf("post-flight sending RecordLLMUsage (tokens=%d/%d cost=%g window=%d)", tokensIn, tokensOut, costUSD, windowSeconds) + + if _, err := m.mgmt.RecordLLMUsage(rpcCtx, &proto.RecordLLMUsageRequest{ + AccountId: in.AccountID, + UserId: in.UserID, + GroupId: groupID, + WindowSeconds: windowSeconds, + TokensInput: tokensIn, + TokensOutput: tokensOut, + CostUsd: costUSD, + GroupIds: append([]string(nil), in.UserGroups...), + }); err != nil { + m.logger.WithError(err). + WithField("middleware", ID). + WithField("account_id", in.AccountID). + WithField("user_id", in.UserID). + WithField("group_id", groupID). + Debugf("post-flight record failed; counter will lag this request") + } + return out, nil +} + +// lookupKV returns the value associated with key, or the empty +// string when absent. +func lookupKV(kvs []middleware.KV, key string) string { + for _, kv := range kvs { + if kv.Key == key { + return kv.Value + } + } + return "" +} diff --git a/proxy/internal/middleware/builtin/llm_limit_record/middleware_test.go b/proxy/internal/middleware/builtin/llm_limit_record/middleware_test.go new file mode 100644 index 000000000..a98ce9b9f --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_limit_record/middleware_test.go @@ -0,0 +1,191 @@ +package llm_limit_record + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/shared/management/proto" +) + +type fakeMgmt struct { + recordReq *proto.RecordLLMUsageRequest + recordCalled bool + recordErr error +} + +func (f *fakeMgmt) CheckLLMPolicyLimits(_ context.Context, _ *proto.CheckLLMPolicyLimitsRequest, _ ...grpc.CallOption) (*proto.CheckLLMPolicyLimitsResponse, error) { + return &proto.CheckLLMPolicyLimitsResponse{Decision: "allow"}, nil +} + +func (f *fakeMgmt) RecordLLMUsage(_ context.Context, in *proto.RecordLLMUsageRequest, _ ...grpc.CallOption) (*proto.RecordLLMUsageResponse, error) { + f.recordCalled = true + f.recordReq = in + return &proto.RecordLLMUsageResponse{}, f.recordErr +} + +func runInvoke(t *testing.T, m *Middleware, in *middleware.Input) *middleware.Output { + t.Helper() + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out) + return out +} + +// TestInvoke_PostsAttributionWithTokensAndCost covers the happy path: +// when the request leg stamped attribution + the upstream parsers +// stamped tokens + cost, the post-flight call carries every field +// through to RecordLLMUsage. +func TestInvoke_PostsAttributionWithTokensAndCost(t *testing.T) { + mgmt := &fakeMgmt{} + m := New(mgmt, nil) + + out := runInvoke(t, m, &middleware.Input{ + AccountID: "acc-1", + UserID: "user-bob", + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMAttributionGroupID, Value: "grp-engineers"}, + {Key: middleware.KeyLLMAttributionWindowS, Value: "86400"}, + {Key: middleware.KeyLLMInputTokens, Value: "150"}, + {Key: middleware.KeyLLMOutputTokens, Value: "75"}, + {Key: middleware.KeyCostUSDTotal, Value: "0.0125"}, + }, + }) + + assert.Equal(t, middleware.DecisionAllow, out.Decision) + require.True(t, mgmt.recordCalled, "record must be invoked when attribution + usage are both present") + assert.Equal(t, "acc-1", mgmt.recordReq.GetAccountId()) + assert.Equal(t, "user-bob", mgmt.recordReq.GetUserId()) + assert.Equal(t, "grp-engineers", mgmt.recordReq.GetGroupId()) + assert.Equal(t, int64(86_400), mgmt.recordReq.GetWindowSeconds()) + assert.Equal(t, int64(150), mgmt.recordReq.GetTokensInput()) + assert.Equal(t, int64(75), mgmt.recordReq.GetTokensOutput()) + assert.InDelta(t, 0.0125, mgmt.recordReq.GetCostUsd(), 1e-9) +} + +// TestInvoke_NoAttributionWindowStillRecordsForAccountFanOut proves the +// catch-all-allow path now STILL records (window 0): account-level budget +// rules live in their own windows and bind independently of policies, so the +// management side needs the post-flight call even when no policy cap applied. +// The full group set is forwarded so the account fan-out can attribute. +func TestInvoke_NoAttributionWindowStillRecordsForAccountFanOut(t *testing.T) { + mgmt := &fakeMgmt{} + m := New(mgmt, nil) + + runInvoke(t, m, &middleware.Input{ + AccountID: "acc-1", + UserID: "user-bob", + UserGroups: []string{"grp-eng", "grp-oncall"}, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMInputTokens, Value: "150"}, + {Key: middleware.KeyLLMOutputTokens, Value: "75"}, + {Key: middleware.KeyCostUSDTotal, Value: "0.0125"}, + }, + }) + + require.True(t, mgmt.recordCalled, "must record even without a policy window so account budgets accumulate") + assert.Equal(t, int64(0), mgmt.recordReq.GetWindowSeconds(), "no policy window is forwarded as 0") + assert.Empty(t, mgmt.recordReq.GetGroupId(), "no attribution group without a policy") + assert.Equal(t, []string{"grp-eng", "grp-oncall"}, mgmt.recordReq.GetGroupIds(), "full group set must be forwarded for the account fan-out") +} + +// TestInvoke_NoPrincipalSkipsRecord proves that with neither a user nor any +// groups there is nothing to attribute, so the write is skipped. +func TestInvoke_NoPrincipalSkipsRecord(t *testing.T) { + mgmt := &fakeMgmt{} + m := New(mgmt, nil) + + runInvoke(t, m, &middleware.Input{ + AccountID: "acc-1", + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMInputTokens, Value: "150"}, + {Key: middleware.KeyCostUSDTotal, Value: "0.0125"}, + }, + }) + + assert.False(t, mgmt.recordCalled, "no user and no groups = nothing to attribute") +} + +// TestInvoke_ZeroUsageSkipsRecord proves the no-usage-no-write path: +// when the upstream parser couldn't extract token counts (streaming, +// malformed body, …), skipping the write keeps phantom rows out of +// the consumption table. +func TestInvoke_ZeroUsageSkipsRecord(t *testing.T) { + mgmt := &fakeMgmt{} + m := New(mgmt, nil) + + runInvoke(t, m, &middleware.Input{ + AccountID: "acc-1", + UserID: "user-bob", + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMAttributionGroupID, Value: "grp-engineers"}, + {Key: middleware.KeyLLMAttributionWindowS, Value: "86400"}, + }, + }) + + assert.False(t, mgmt.recordCalled, "zero tokens AND zero cost = nothing to record; an upstream parse miss must not surface as a row") +} + +// TestInvoke_RPCErrorIsSwallowed proves the post-flight isolation +// contract: management errors must NOT cascade back to the proxy +// because the upstream response has already been served — failing +// the chain at this point would corrupt the response. Errors are +// logged at debug level and swallowed. +func TestInvoke_RPCErrorIsSwallowed(t *testing.T) { + mgmt := &fakeMgmt{recordErr: errors.New("management down")} + m := New(mgmt, nil) + + out := runInvoke(t, m, &middleware.Input{ + AccountID: "acc-1", + UserID: "user-bob", + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMAttributionGroupID, Value: "grp-engineers"}, + {Key: middleware.KeyLLMAttributionWindowS, Value: "86400"}, + {Key: middleware.KeyLLMInputTokens, Value: "100"}, + }, + }) + + assert.Equal(t, middleware.DecisionAllow, out.Decision, + "a record failure must not surface — the upstream response is already on the wire") +} + +// TestInvoke_NoMgmtClientPassesThrough mirrors the gate's safety +// contract: a partial wiring is consistent. No mgmt client = silent +// skip rather than an unhandled nil-deref. +func TestInvoke_NoMgmtClientPassesThrough(t *testing.T) { + m := New(nil, nil) + out := runInvoke(t, m, &middleware.Input{ + AccountID: "acc-1", + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMAttributionGroupID, Value: "grp-engineers"}, + {Key: middleware.KeyLLMAttributionWindowS, Value: "86400"}, + {Key: middleware.KeyLLMInputTokens, Value: "100"}, + }, + }) + assert.Equal(t, middleware.DecisionAllow, out.Decision) +} + +// TestInvoke_NoIdentitySkipsRecord covers a defensive guard: stamped +// attribution but no user_id AND no group_id (shouldn't happen, but +// possible if the gate ever changes shape) must not write a row keyed +// on empty dimension ids. +func TestInvoke_NoIdentitySkipsRecord(t *testing.T) { + mgmt := &fakeMgmt{} + m := New(mgmt, nil) + + runInvoke(t, m, &middleware.Input{ + AccountID: "acc-1", + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMAttributionWindowS, Value: "86400"}, + {Key: middleware.KeyLLMInputTokens, Value: "100"}, + }, + }) + + assert.False(t, mgmt.recordCalled, + "empty user + group identity must skip the write — never key on empty dimension ids") +} diff --git a/proxy/internal/middleware/builtin/llm_request_parser/bedrock_test.go b/proxy/internal/middleware/builtin/llm_request_parser/bedrock_test.go new file mode 100644 index 000000000..827b81d07 --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_request_parser/bedrock_test.go @@ -0,0 +1,55 @@ +package llm_request_parser + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNormalizeBedrockModel(t *testing.T) { + cases := map[string]string{ + "eu.anthropic.claude-sonnet-4-5-20250929-v1:0": "anthropic.claude-sonnet-4-5", + "us.anthropic.claude-opus-4-8-20250101-v1:0": "anthropic.claude-opus-4-8", + "apac.anthropic.claude-haiku-4-5-v1:0": "anthropic.claude-haiku-4-5", + "anthropic.claude-sonnet-4-5-20250929-v1:0": "anthropic.claude-sonnet-4-5", + "meta.llama3-3-70b-instruct-v1:0": "meta.llama3-3-70b-instruct", + "amazon.nova-pro-v1:0": "amazon.nova-pro", + "amazon.nova-2-lite-v1:0": "amazon.nova-2-lite", + // Inference-profile ARN — model id lives in the last path segment. + "arn:aws:bedrock:eu-central-1:123456789012:inference-profile/eu.anthropic.claude-sonnet-4-5-20250929-v1:0": "anthropic.claude-sonnet-4-5", + } + for in, want := range cases { + require.Equal(t, want, normalizeBedrockModel(in), "normalize %q", in) + } +} + +func TestParseBedrockPath(t *testing.T) { + tests := []struct { + path string + model string + stream bool + ok bool + }{ + {"/model/eu.anthropic.claude-sonnet-4-5-20250929-v1:0/invoke", "anthropic.claude-sonnet-4-5", false, true}, + {"/model/eu.anthropic.claude-sonnet-4-5-20250929-v1:0/invoke-with-response-stream", "anthropic.claude-sonnet-4-5", true, true}, + {"/model/eu.anthropic.claude-sonnet-4-5-20250929-v1:0/converse", "anthropic.claude-sonnet-4-5", false, true}, + {"/model/eu.anthropic.claude-sonnet-4-5-20250929-v1:0/converse-stream", "anthropic.claude-sonnet-4-5", true, true}, + // URL-encoded colon in the version suffix. + {"/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", "anthropic.claude-sonnet-4-5", false, true}, + // Optional "/bedrock" gateway-namespace prefix. + {"/bedrock/model/eu.anthropic.claude-sonnet-4-5-20250929-v1:0/invoke-with-response-stream", "anthropic.claude-sonnet-4-5", true, true}, + {"/bedrock/model/anthropic.claude-sonnet-4-5-20250929-v1:0/converse", "anthropic.claude-sonnet-4-5", false, true}, + {"/v1/chat/completions", "", false, false}, + {"/model/foo", "", false, false}, + {"/model//invoke", "", false, false}, + {"/model/x/unknown-action", "", false, false}, + } + for _, tt := range tests { + br, ok := parseBedrockPath(tt.path) + require.Equal(t, tt.ok, ok, "ok for %q", tt.path) + if tt.ok { + require.Equal(t, tt.model, br.model, "model for %q", tt.path) + require.Equal(t, tt.stream, br.stream, "stream for %q", tt.path) + } + } +} diff --git a/proxy/internal/middleware/builtin/llm_request_parser/factory.go b/proxy/internal/middleware/builtin/llm_request_parser/factory.go new file mode 100644 index 000000000..8b3776877 --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_request_parser/factory.go @@ -0,0 +1,71 @@ +package llm_request_parser + +import ( + "bytes" + "encoding/json" + "fmt" + + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" +) + +// config is the on-wire config envelope for the middleware. +// +// ProviderID, when set, names the parser to use directly (matched +// against llm.ParserByName, e.g. "openai", "anthropic"). The +// agent-network synthesiser stamps this so requests routed through a +// synthesised provider service don't depend on URL-shape sniffing, +// which is the only signal the middleware otherwise has. +type config struct { + ProviderID string `json:"provider_id,omitempty"` + // RedactPii, when true, runs PII redaction over the captured raw prompt + // before it is emitted as llm.request_prompt_raw — so the + // agent-network access-log row does NOT carry raw emails / SSNs / + // phone numbers even though the framework's per-key redactor (Scan) + // doesn't cover those prompt-shaped patterns. Sourced by the + // synthesiser from the account's redact_pii toggle. + RedactPii bool `json:"redact_pii,omitempty"` + // CapturePrompt gates emission of llm.request_prompt_raw. A nil pointer + // preserves the legacy default (emit), so callers that don't know about + // the toggle (or pre-existing tests with empty config) keep working. + // The synthesiser sets this explicitly to the account's + // enable_prompt_collection toggle: false here suppresses the key + // entirely so the access-log row carries no prompt content at all, + // independent of redact_pii (which only controls the form of the + // content when it IS emitted). + CapturePrompt *bool `json:"capture_prompt,omitempty"` +} + +// Factory builds llm_request_parser instances from raw config bytes. +type Factory struct{} + +// ID returns the registry identifier. +func (Factory) ID() string { return ID } + +// New constructs a middleware instance. Empty, null, and {} configs are +// accepted; non-empty rawConfig that fails to unmarshal is rejected so +// misconfigurations surface at chain build time. +func (Factory) New(rawConfig []byte) (middleware.Middleware, error) { + var cfg config + if len(bytes.TrimSpace(rawConfig)) > 0 { + // Strict decode: a typo'd field (e.g. "capture_prompts") must fail + // chain build rather than silently fall back to the emit-everything + // default and leak prompts. + dec := json.NewDecoder(bytes.NewReader(rawConfig)) + dec.DisallowUnknownFields() + if err := dec.Decode(&cfg); err != nil { + return nil, fmt.Errorf("decode config: %w", err) + } + } + // Default capturePrompt to true (legacy emission) when the field is + // absent so non-agent-network callers and pre-toggle tests keep working. + capturePrompt := true + if cfg.CapturePrompt != nil { + capturePrompt = *cfg.CapturePrompt + } + return middlewareImpl{providerID: cfg.ProviderID, redactPii: cfg.RedactPii, capturePrompt: capturePrompt}, nil +} + +func init() { + builtin.Register(Factory{}) +} diff --git a/proxy/internal/middleware/builtin/llm_request_parser/middleware.go b/proxy/internal/middleware/builtin/llm_request_parser/middleware.go new file mode 100644 index 000000000..64ca04e6a --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_request_parser/middleware.go @@ -0,0 +1,453 @@ +// Package llm_request_parser implements the SlotOnRequest middleware +// that detects the LLM provider from the request URL, parses the JSON +// request body for model and streaming flags, and extracts the user +// prompt text. Emitted metadata feeds downstream middlewares (guardrail, +// cost meter) and the access-log terminal sink. +package llm_request_parser + +import ( + "context" + "net/url" + "regexp" + "strconv" + "strings" + "unicode/utf8" + + "github.com/netbirdio/netbird/proxy/internal/llm" + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_guardrail" +) + +// ID is the registry key for this middleware. +const ID = "llm_request_parser" + +// Version is reported via Middleware.Version(). +const Version = "1.0.0" + +// maxPromptBytes caps llm.request_prompt_raw at a size that fits within +// MaxMetadataValueBytes with headroom. Truncation is rune-safe. +const maxPromptBytes = 3500 + +// middlewareImpl is the concrete implementation. providerID, when set, +// names the parser to use directly (bypasses URL sniffing). It is empty +// for non-agent-network targets, which fall back to DetectParser on the +// request path. +type middlewareImpl struct { + providerID string + redactPii bool + capturePrompt bool +} + +// ID returns the registry identifier. +func (middlewareImpl) ID() string { return ID } + +// Version returns the implementation version. +func (middlewareImpl) Version() string { return Version } + +// Slot reports the request slot. +func (middlewareImpl) Slot() middleware.Slot { return middleware.SlotOnRequest } + +// AcceptedContentTypes restricts body inspection to JSON. +func (middlewareImpl) AcceptedContentTypes() []string { + return []string{"application/json"} +} + +// MetadataKeys lists the closed allowlist of keys this middleware emits. +func (middlewareImpl) MetadataKeys() []string { + return []string{ + middleware.KeyLLMProvider, + middleware.KeyLLMModel, + middleware.KeyLLMStream, + middleware.KeyLLMRequestPromptRaw, + middleware.KeyLLMCaptureTruncated, + middleware.KeyLLMSessionID, + } +} + +// MutationsSupported reports that this middleware never mutates. +func (middlewareImpl) MutationsSupported() bool { return false } + +// Close is a no-op; the middleware is stateless. +func (middlewareImpl) Close() error { return nil } + +// Invoke detects the LLM provider, parses request facts, and emits +// metadata. Always returns DecisionAllow; never errors. Provider +// selection prefers the configured providerID (synthesiser-stamped on +// agent-network targets) so requests routed to a custom upstream URL +// still resolve. Falls back to URL sniffing when no providerID is set. +func (m middlewareImpl) Invoke(_ context.Context, in *middleware.Input) (*middleware.Output, error) { + out := &middleware.Output{Decision: middleware.DecisionAllow} + if in == nil { + return out, nil + } + + // Google Vertex AI carries the model + publisher (vendor) in the URL path, + // not the body, so it needs a dedicated extraction path. + if vx, okv := parseVertexPath(extractPath(in.URL)); okv { + return m.invokeVertex(in, vx), nil + } + + // AWS Bedrock likewise carries the model in the URL path (/model/{id}/{action}). + if br, okb := parseBedrockPath(extractPath(in.URL)); okb { + return m.invokeBedrock(in, br), nil + } + + parser, ok := llm.ParserByName(m.providerID) + if !ok { + parser, ok = llm.DetectParser(extractPath(in.URL)) + } + if !ok { + return out, nil + } + + md := []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: parser.ProviderName()}, + } + + // Session id is an opaque grouping identifier, not prompt content, so + // it's emitted regardless of the prompt-collection toggle — session + // grouping must work even when prompt capture is off. Prefer a header + // (Codex sends the session as an HTTP header, and headers survive an + // oversized request whose body capture was bypassed) and resolve it + // before ParseRequest so a malformed body still keeps the header id. + sessionID := sessionIDFromHeaders(in.Headers) + if sessionID == "" { + sessionID = parser.ExtractSessionID(in.Body) + } + appendSessionID := func(md []middleware.KV) []middleware.KV { + if sessionID != "" { + return append(md, middleware.KV{Key: middleware.KeyLLMSessionID, Value: sessionID}) + } + return md + } + + facts, err := parser.ParseRequest(in.Body) + if err != nil { + if logger := builtin.Context().Logger; logger != nil { + logger.Debugf("llm_request_parser: parse request body: %v", err) + } + md = appendSessionID(md) + md = appendCaptureTruncated(md, false, in.BodyTruncated) + out.Metadata = md + return out, nil + } + + if facts.Model != "" { + md = append(md, middleware.KV{Key: middleware.KeyLLMModel, Value: facts.Model}) + } + md = append(md, middleware.KV{Key: middleware.KeyLLMStream, Value: strconv.FormatBool(facts.Stream)}) + md = appendSessionID(md) + + prompt, promptTruncated := truncatePrompt(parser.ExtractPrompt(in.Body)) + if prompt != "" && m.capturePrompt { + if m.redactPii { + // Apply redaction BEFORE the value lands in the metadata bag, so + // the access-log row never carries raw emails / SSNs / phones. + // The downstream llm_guardrail middleware reads this key to + // produce llm.request_prompt; RedactPII is idempotent so its + // second pass is a no-op. Redaction can grow the text, so + // re-truncate to keep the value within the metadata cap. + prompt = llm_guardrail.RedactPII(prompt) + var redactedTruncated bool + prompt, redactedTruncated = truncatePrompt(prompt) + promptTruncated = promptTruncated || redactedTruncated + } + md = append(md, middleware.KV{Key: middleware.KeyLLMRequestPromptRaw, Value: prompt}) + } + + md = appendCaptureTruncated(md, promptTruncated, in.BodyTruncated) + out.Metadata = md + return out, nil +} + +// sessionIDHeaders are request header names that may carry a client +// session identifier, checked in order, case-insensitively. Matching is +// against Go's canonical header form, so use the hyphenated names the +// clients actually send: "x-claude-code-session-id" (Claude Code), +// "session-id" (OpenAI Codex — confirmed on the wire as "Session-Id"), +// and "x-session-id" as a generic convention. +var sessionIDHeaders = []string{"x-claude-code-session-id", "session-id", "x-session-id"} + +// sessionIDFromHeaders returns the first non-empty value among the known +// session header names, or "" when none is present. Headers arrive in +// canonical form, so the match is case-insensitive. +func sessionIDFromHeaders(headers []middleware.KV) string { + for _, want := range sessionIDHeaders { + for _, kv := range headers { + if strings.EqualFold(kv.Key, want) && kv.Value != "" { + return kv.Value + } + } + } + return "" +} + +// appendCaptureTruncated stamps the capture_truncated marker reflecting +// either prompt-side truncation or upstream body truncation. +func appendCaptureTruncated(md []middleware.KV, promptTruncated, bodyTruncated bool) []middleware.KV { + value := "false" + if promptTruncated || bodyTruncated { + value = "true" + } + return append(md, middleware.KV{Key: middleware.KeyLLMCaptureTruncated, Value: value}) +} + +// truncatePrompt clamps a prompt string to maxPromptBytes on a UTF-8 +// rune boundary. Returns the clamped string and whether truncation +// occurred. +func truncatePrompt(s string) (string, bool) { + if len(s) <= maxPromptBytes { + return s, false + } + cut := maxPromptBytes + for cut > 0 && !utf8.RuneStart(s[cut]) { + cut-- + } + return s[:cut], true +} + +// extractPath returns the path component of a URL that may be absolute +// or already a path. Parse errors fall back to the raw input. +func extractPath(raw string) string { + if raw == "" { + return "" + } + u, err := url.Parse(raw) + if err != nil || u.Path == "" { + return raw + } + return u.Path +} + +// vertexRequest is the model + vendor extracted from a Vertex AI publisher +// path (the model is in the URL, not the body). +type vertexRequest struct { + publisher string + model string + stream bool +} + +// parseVertexPath extracts the publisher, model, and streaming flag from a +// Vertex publisher endpoint: +// +// /v1/projects/{project}/locations/{region}/publishers/{publisher}/models/{model}:{action} +// +// The model's "@version" suffix is stripped so it matches catalog/pricing. +func parseVertexPath(reqPath string) (vertexRequest, bool) { + const pubSep, modSep = "/publishers/", "/models/" + if !strings.HasPrefix(reqPath, "/v1/projects/") { + return vertexRequest{}, false + } + pubIdx := strings.Index(reqPath, pubSep) + modIdx := strings.Index(reqPath, modSep) + if pubIdx < 0 || modIdx <= pubIdx { + return vertexRequest{}, false + } + publisher := reqPath[pubIdx+len(pubSep) : modIdx] + rest := reqPath[modIdx+len(modSep):] // {model}:{action} + if publisher == "" || rest == "" { + return vertexRequest{}, false + } + model, action := rest, "" + if c := strings.LastIndex(rest, ":"); c >= 0 { + model, action = rest[:c], rest[c+1:] + } + if at := strings.Index(model, "@"); at >= 0 { + model = model[:at] + } + if model == "" { + return vertexRequest{}, false + } + return vertexRequest{publisher: publisher, model: model, stream: strings.HasPrefix(action, "stream")}, true +} + +// vertexPublisherVendor maps a Vertex publisher to the parser surface its +// requests/responses speak. Empty for publishers without a parser yet +// (e.g. google/gemini) — the request still routes, but isn't metered. +func vertexPublisherVendor(publisher string) string { + switch strings.ToLower(publisher) { + case "anthropic": + return "anthropic" + case "openai": + return "openai" + default: + return "" + } +} + +// invokeVertex emits the model/vendor/session/prompt for a Vertex publisher +// request, using the publisher's parser to read the (vendor-native) body. +func (m middlewareImpl) invokeVertex(in *middleware.Input, vx vertexRequest) *middleware.Output { + out := &middleware.Output{Decision: middleware.DecisionAllow} + vendor := vertexPublisherVendor(vx.publisher) + + md := []middleware.KV{} + if vendor != "" { + md = append(md, middleware.KV{Key: middleware.KeyLLMProvider, Value: vendor}) + } + md = append(md, middleware.KV{Key: middleware.KeyLLMModel, Value: vx.model}) + md = append(md, middleware.KV{Key: middleware.KeyLLMStream, Value: strconv.FormatBool(vx.stream)}) + + var parser llm.Parser + if vendor != "" { + parser, _ = llm.ParserByName(vendor) + } + + sessionID := sessionIDFromHeaders(in.Headers) + if sessionID == "" && parser != nil { + sessionID = parser.ExtractSessionID(in.Body) + } + if sessionID != "" { + md = append(md, middleware.KV{Key: middleware.KeyLLMSessionID, Value: sessionID}) + } + + promptTruncated := false + if parser != nil && m.capturePrompt { + var prompt string + prompt, promptTruncated = truncatePrompt(parser.ExtractPrompt(in.Body)) + if prompt != "" { + if m.redactPii { + prompt = llm_guardrail.RedactPII(prompt) + var rt bool + prompt, rt = truncatePrompt(prompt) + promptTruncated = promptTruncated || rt + } + md = append(md, middleware.KV{Key: middleware.KeyLLMRequestPromptRaw, Value: prompt}) + } + } + md = appendCaptureTruncated(md, promptTruncated, in.BodyTruncated) + out.Metadata = md + return out +} + +// bedrockRequest is the model + streaming flag extracted from an AWS Bedrock +// model path. The InvokeModel vs Converse distinction is recovered downstream +// from the response body shape, so only the streaming flag is carried here. +type bedrockRequest struct { + model string + stream bool +} + +// bedrockNamespacePrefix is an optional gateway-namespace prefix some clients +// put before the native Bedrock path to disambiguate it from other providers +// that also use "/model/...". +const bedrockNamespacePrefix = "/bedrock" + +// trimBedrockNamespace removes an optional "/bedrock" namespace prefix, leaving +// the native Bedrock path ("/model/..."). +func trimBedrockNamespace(reqPath string) string { + if strings.HasPrefix(reqPath, bedrockNamespacePrefix+"/") { + return strings.TrimPrefix(reqPath, bedrockNamespacePrefix) + } + return reqPath +} + +// bedrockRegionPrefixes are the cross-region inference-profile prefixes that +// front a Bedrock model id (e.g. "eu.anthropic.claude-..."). +var bedrockRegionPrefixes = []string{"us.", "eu.", "apac.", "global."} + +// bedrockVersionSuffix matches the trailing "-vN[:N]" or "-YYYYMMDD-vN[:N]" +// version/throughput suffix of a Bedrock model id. +var bedrockVersionSuffix = regexp.MustCompile(`-(\d{8}-)?v\d+(:\d+)?$`) + +// parseBedrockPath extracts the model and streaming/converse flags from an AWS +// Bedrock runtime model endpoint: +// +// /model/{modelId}/{action} +// +// action ∈ {invoke, invoke-with-response-stream, converse, converse-stream}. +// The modelId may be URL-encoded and may carry a cross-region inference-profile +// prefix and a version suffix; normalizeBedrockModel strips both so the model +// matches catalog pricing. +func parseBedrockPath(reqPath string) (bedrockRequest, bool) { + reqPath = trimBedrockNamespace(reqPath) + const prefix = "/model/" + if !strings.HasPrefix(reqPath, prefix) { + return bedrockRequest{}, false + } + rest := reqPath[len(prefix):] + slash := strings.LastIndex(rest, "/") + if slash <= 0 || slash == len(rest)-1 { + return bedrockRequest{}, false + } + rawModel, action := rest[:slash], rest[slash+1:] + if decoded, err := url.PathUnescape(rawModel); err == nil { + rawModel = decoded + } + model := normalizeBedrockModel(rawModel) + if model == "" { + return bedrockRequest{}, false + } + switch action { + case "invoke", "converse": + return bedrockRequest{model: model}, true + case "invoke-with-response-stream", "converse-stream": + return bedrockRequest{model: model, stream: true}, true + default: + return bedrockRequest{}, false + } +} + +// normalizeBedrockModel strips an ARN wrapper, a cross-region inference-profile +// prefix, and the version/throughput suffix from a Bedrock model id so it +// matches the catalog/pricing key, e.g. +// "eu.anthropic.claude-sonnet-4-5-20250929-v1:0" -> "anthropic.claude-sonnet-4-5" +// and "arn:aws:bedrock:eu-central-1:123:inference-profile/eu.anthropic.claude-sonnet-4-5-20250929-v1:0" +// -> "anthropic.claude-sonnet-4-5". +func normalizeBedrockModel(modelID string) string { + m := modelID + // A full ARN (inference-profile / provisioned-throughput / foundation-model) + // carries the model id in its last path segment. + if strings.HasPrefix(m, "arn:") { + if i := strings.LastIndex(m, "/"); i >= 0 { + m = m[i+1:] + } + } + for _, p := range bedrockRegionPrefixes { + if strings.HasPrefix(m, p) { + m = m[len(p):] + break + } + } + return bedrockVersionSuffix.ReplaceAllString(m, "") +} + +// invokeBedrock emits the model/provider/session/prompt for an AWS Bedrock +// request. Bedrock is metered under the dedicated "bedrock" parser, which reads +// both the InvokeModel and Converse response shapes. +func (m middlewareImpl) invokeBedrock(in *middleware.Input, br bedrockRequest) *middleware.Output { + out := &middleware.Output{Decision: middleware.DecisionAllow} + md := []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: llm.ProviderNameBedrock}, + {Key: middleware.KeyLLMModel, Value: br.model}, + {Key: middleware.KeyLLMStream, Value: strconv.FormatBool(br.stream)}, + } + + parser, _ := llm.ParserByName(llm.ProviderNameBedrock) + sessionID := sessionIDFromHeaders(in.Headers) + if sessionID == "" && parser != nil { + sessionID = parser.ExtractSessionID(in.Body) + } + if sessionID != "" { + md = append(md, middleware.KV{Key: middleware.KeyLLMSessionID, Value: sessionID}) + } + + promptTruncated := false + if parser != nil && m.capturePrompt { + var prompt string + prompt, promptTruncated = truncatePrompt(parser.ExtractPrompt(in.Body)) + if prompt != "" { + if m.redactPii { + prompt = llm_guardrail.RedactPII(prompt) + var rt bool + prompt, rt = truncatePrompt(prompt) + promptTruncated = promptTruncated || rt + } + md = append(md, middleware.KV{Key: middleware.KeyLLMRequestPromptRaw, Value: prompt}) + } + } + md = appendCaptureTruncated(md, promptTruncated, in.BodyTruncated) + out.Metadata = md + return out +} diff --git a/proxy/internal/middleware/builtin/llm_request_parser/middleware_test.go b/proxy/internal/middleware/builtin/llm_request_parser/middleware_test.go new file mode 100644 index 000000000..bc185b295 --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_request_parser/middleware_test.go @@ -0,0 +1,418 @@ +package llm_request_parser + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +func metaValue(t *testing.T, kvs []middleware.KV, key string) (string, bool) { + t.Helper() + for _, kv := range kvs { + if kv.Key == key { + return kv.Value, true + } + } + return "", false +} + +func newMiddleware(t *testing.T) middleware.Middleware { + t.Helper() + mw, err := Factory{}.New(nil) + require.NoError(t, err, "factory must accept nil config") + return mw +} + +func TestMiddleware_StaticSurface(t *testing.T) { + mw := newMiddleware(t) + assert.Equal(t, ID, mw.ID(), "ID must match the registered constant") + assert.Equal(t, Version, mw.Version(), "Version must match the constant") + assert.Equal(t, middleware.SlotOnRequest, mw.Slot(), "must run in the request slot") + assert.Equal(t, []string{"application/json"}, mw.AcceptedContentTypes(), "only JSON bodies are needed") + assert.False(t, mw.MutationsSupported(), "request parser never mutates") + assert.NoError(t, mw.Close(), "Close on stateless middleware is a no-op") + + keys := mw.MetadataKeys() + expected := []string{ + middleware.KeyLLMProvider, + middleware.KeyLLMModel, + middleware.KeyLLMStream, + middleware.KeyLLMRequestPromptRaw, + middleware.KeyLLMCaptureTruncated, + middleware.KeyLLMSessionID, + } + assert.Equal(t, expected, keys, "metadata key allowlist must match the spec") +} + +func TestFactory_AcceptsEmptyAndJSONConfig(t *testing.T) { + cases := [][]byte{nil, {}, []byte("null"), []byte("{}"), []byte(" ")} + for _, raw := range cases { + mw, err := Factory{}.New(raw) + require.NoError(t, err, "empty/null/object config must be accepted") + require.NotNil(t, mw, "factory must return a middleware instance") + } +} + +func TestFactory_RejectsMalformedConfig(t *testing.T) { + mw, err := Factory{}.New([]byte("{not json")) + require.Error(t, err, "malformed config must surface at construction") + assert.Nil(t, mw, "no instance is returned on error") +} + +func TestInvoke_OpenAIBufferedChatCompletion(t *testing.T) { + mw := newMiddleware(t) + body := []byte(`{"model":"gpt-4o-mini","stream":false,"messages":[{"role":"user","content":"Hello, world!"}]}`) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + URL: "/v1/chat/completions", + Body: body, + }) + require.NoError(t, err) + require.NotNil(t, out, "output must be returned") + assert.Equal(t, middleware.DecisionAllow, out.Decision, "request parser always allows") + + provider, ok := metaValue(t, out.Metadata, middleware.KeyLLMProvider) + require.True(t, ok, "provider metadata must be set") + assert.Equal(t, "openai", provider, "OpenAI provider detected from path") + + model, ok := metaValue(t, out.Metadata, middleware.KeyLLMModel) + require.True(t, ok, "model metadata must be set") + assert.Equal(t, "gpt-4o-mini", model, "model echoed from request body") + + stream, ok := metaValue(t, out.Metadata, middleware.KeyLLMStream) + require.True(t, ok, "stream metadata must be set") + assert.Equal(t, "false", stream, "buffered request reports stream=false") + + prompt, ok := metaValue(t, out.Metadata, middleware.KeyLLMRequestPromptRaw) + require.True(t, ok, "prompt metadata must be set when extractable") + assert.Contains(t, prompt, "Hello, world!", "extracted prompt carries the user message") + + truncated, ok := metaValue(t, out.Metadata, middleware.KeyLLMCaptureTruncated) + require.True(t, ok, "capture_truncated must always be emitted on success") + assert.Equal(t, "false", truncated, "no truncation on a small body") +} + +func TestInvoke_EmitsSessionID(t *testing.T) { + mw := newMiddleware(t) + + t.Run("codex session from client_metadata", func(t *testing.T) { + body := []byte(`{"model":"gpt-5.5","client_metadata":{"session_id":"sess-codex-1"},"input":[]}`) + out, err := mw.Invoke(context.Background(), &middleware.Input{URL: "/v1/responses", Body: body}) + require.NoError(t, err) + sid, ok := metaValue(t, out.Metadata, middleware.KeyLLMSessionID) + require.True(t, ok, "session id must be emitted for Codex requests") + assert.Equal(t, "sess-codex-1", sid, "session id must come from client_metadata.session_id") + }) + + t.Run("no session id key when absent", func(t *testing.T) { + body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`) + out, err := mw.Invoke(context.Background(), &middleware.Input{URL: "/v1/chat/completions", Body: body}) + require.NoError(t, err) + _, ok := metaValue(t, out.Metadata, middleware.KeyLLMSessionID) + assert.False(t, ok, "no session id key emitted when the request carries none") + }) + + t.Run("claude code session header", func(t *testing.T) { + body := []byte(`{"model":"claude-opus-4-8","messages":[{"role":"user","content":"hi"}]}`) + out, err := mw.Invoke(context.Background(), &middleware.Input{ + URL: "/v1/messages", + Body: body, + Headers: []middleware.KV{{Key: "X-Claude-Code-Session-Id", Value: "cc-sess-1"}}, + }) + require.NoError(t, err) + sid, ok := metaValue(t, out.Metadata, middleware.KeyLLMSessionID) + require.True(t, ok, "Claude Code session id must be read from X-Claude-Code-Session-Id") + assert.Equal(t, "cc-sess-1", sid, "session id must come from the Claude Code session header") + }) + + t.Run("codex Session-Id header", func(t *testing.T) { + // Codex sends the session as the canonical header "Session-Id". + body := []byte(`{"model":"gpt-5.5","input":[]}`) + out, err := mw.Invoke(context.Background(), &middleware.Input{ + URL: "/v1/responses", + Body: body, + Headers: []middleware.KV{{Key: "Session-Id", Value: "sess-hdr-1"}}, + }) + require.NoError(t, err) + sid, ok := metaValue(t, out.Metadata, middleware.KeyLLMSessionID) + require.True(t, ok, "session id must be read from the Session-Id header") + assert.Equal(t, "sess-hdr-1", sid, "session id must come from the Codex Session-Id header") + }) + + t.Run("header wins over body and survives bypassed body", func(t *testing.T) { + // Oversized request: body was bypassed to a routing stub with no + // client_metadata, but the header still carries the session. + out, err := mw.Invoke(context.Background(), &middleware.Input{ + URL: "/v1/responses", + Body: []byte(`{"model":"gpt-5.5","stream":true}`), + Headers: []middleware.KV{{Key: "X-Session-Id", Value: "sess-hdr-2"}}, + }) + require.NoError(t, err) + sid, _ := metaValue(t, out.Metadata, middleware.KeyLLMSessionID) + assert.Equal(t, "sess-hdr-2", sid, "x-session-id header must be honoured when the body carries no marker") + }) +} + +func TestInvoke_OpenAIStreamingChatCompletion(t *testing.T) { + mw := newMiddleware(t) + body := []byte(`{"model":"gpt-4o-mini","stream":true,"messages":[{"role":"user","content":"hi"}]}`) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + URL: "/v1/chat/completions", + Body: body, + }) + require.NoError(t, err) + + stream, ok := metaValue(t, out.Metadata, middleware.KeyLLMStream) + require.True(t, ok, "stream metadata must be set") + assert.Equal(t, "true", stream, "stream flag echoed for SSE-bound request") +} + +func TestInvoke_AnthropicMessages(t *testing.T) { + mw := newMiddleware(t) + body := []byte(`{"model":"claude-sonnet-4-5","stream":false,"messages":[{"role":"user","content":"What is the weather?"}]}`) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + URL: "/v1/messages", + Body: body, + }) + require.NoError(t, err) + + provider, ok := metaValue(t, out.Metadata, middleware.KeyLLMProvider) + require.True(t, ok, "provider metadata must be set") + assert.Equal(t, "anthropic", provider, "Anthropic provider detected from path") + + model, _ := metaValue(t, out.Metadata, middleware.KeyLLMModel) + assert.Equal(t, "claude-sonnet-4-5", model, "anthropic model echoed") + + prompt, ok := metaValue(t, out.Metadata, middleware.KeyLLMRequestPromptRaw) + require.True(t, ok, "prompt metadata must be set") + assert.Contains(t, prompt, "What is the weather?", "anthropic message text extracted") +} + +func TestInvoke_UnknownURLNoMetadata(t *testing.T) { + mw := newMiddleware(t) + out, err := mw.Invoke(context.Background(), &middleware.Input{ + URL: "/healthz", + Body: []byte(`{"model":"x"}`), + }) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "unknown paths still allow") + assert.Empty(t, out.Metadata, "no metadata is emitted when no parser matches") +} + +func TestInvoke_ProviderIDConfigBypassesURLSniff(t *testing.T) { + mw, err := Factory{}.New([]byte(`{"provider_id":"openai"}`)) + require.NoError(t, err, "factory must accept provider_id config") + + // URL doesn't match any of the OpenAI path hints — the provider_id + // config is the only signal the middleware has. + out, err := mw.Invoke(context.Background(), &middleware.Input{ + URL: "/custom/gateway/foo/bar", + Body: []byte(`{"model":"gpt-4o-mini","messages":[{"role":"user","content":"Hi"}]}`), + }) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision) + + provider, ok := metaValue(t, out.Metadata, middleware.KeyLLMProvider) + require.True(t, ok, "provider must be emitted when provider_id is configured even on unknown URLs") + assert.Equal(t, "openai", provider, "provider_id config selects the OpenAI parser") + + model, ok := metaValue(t, out.Metadata, middleware.KeyLLMModel) + require.True(t, ok, "model still extracted from the body") + assert.Equal(t, "gpt-4o-mini", model) +} + +func TestInvoke_UnknownProviderIDFallsBackToURL(t *testing.T) { + mw, err := Factory{}.New([]byte(`{"provider_id":"not-a-real-parser"}`)) + require.NoError(t, err, "factory must accept any provider_id string") + + // URL hits the OpenAI surface, so URL sniffing should still resolve + // even though the configured provider_id doesn't match a parser. + out, err := mw.Invoke(context.Background(), &middleware.Input{ + URL: "/v1/chat/completions", + Body: []byte(`{"model":"gpt-4o-mini"}`), + }) + require.NoError(t, err) + require.NotNil(t, out) + + provider, ok := metaValue(t, out.Metadata, middleware.KeyLLMProvider) + require.True(t, ok, "fallback URL sniffing must populate the provider") + assert.Equal(t, "openai", provider) +} + +func TestInvoke_MalformedBodyAllowsWithProvider(t *testing.T) { + mw := newMiddleware(t) + out, err := mw.Invoke(context.Background(), &middleware.Input{ + URL: "/v1/chat/completions", + Body: []byte(`{not json`), + }) + require.NoError(t, err, "malformed body must not error") + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "decision is always allow") + + provider, ok := metaValue(t, out.Metadata, middleware.KeyLLMProvider) + require.True(t, ok, "provider metadata is emitted before body parse") + assert.Equal(t, "openai", provider, "provider stays even when body parse fails") + + _, hasModel := metaValue(t, out.Metadata, middleware.KeyLLMModel) + assert.False(t, hasModel, "no model metadata when parse fails") + + truncated, ok := metaValue(t, out.Metadata, middleware.KeyLLMCaptureTruncated) + require.True(t, ok, "capture_truncated is emitted on parse error path") + assert.Equal(t, "false", truncated, "no truncation marker without truncated body or prompt") +} + +func TestInvoke_TruncatesLongPrompt(t *testing.T) { + mw := newMiddleware(t) + long := strings.Repeat("x", maxPromptBytes*2) + body := []byte(`{"model":"gpt-4o-mini","messages":[{"role":"user","content":"` + long + `"}]}`) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + URL: "/v1/chat/completions", + Body: body, + }) + require.NoError(t, err) + + prompt, ok := metaValue(t, out.Metadata, middleware.KeyLLMRequestPromptRaw) + require.True(t, ok, "prompt metadata must be set") + assert.LessOrEqual(t, len(prompt), maxPromptBytes, "prompt must respect the byte budget") + + truncated, ok := metaValue(t, out.Metadata, middleware.KeyLLMCaptureTruncated) + require.True(t, ok, "capture_truncated must be set") + assert.Equal(t, "true", truncated, "truncation marker raised when prompt is clipped") +} + +func TestInvoke_TruncatesOnRuneBoundary(t *testing.T) { + mw := newMiddleware(t) + // Each ☃ is 3 bytes in UTF-8; build a string whose byte length exceeds + // maxPromptBytes with snowmen straddling the cut point. + rune3 := "☃" + repeats := (maxPromptBytes / len(rune3)) + 5 + long := strings.Repeat(rune3, repeats) + body := []byte(`{"model":"gpt-4o-mini","messages":[{"role":"user","content":"` + long + `"}]}`) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + URL: "/v1/chat/completions", + Body: body, + }) + require.NoError(t, err) + + prompt, ok := metaValue(t, out.Metadata, middleware.KeyLLMRequestPromptRaw) + require.True(t, ok, "prompt metadata must be set") + assert.LessOrEqual(t, len(prompt), maxPromptBytes, "prompt must respect the byte budget") + assert.True(t, strings.HasSuffix(prompt, rune3) || !strings.ContainsRune(prompt[len(prompt)-1:], 0xFFFD), + "truncation must not split a multi-byte rune") +} + +func TestInvoke_BodyTruncatedRaisesCaptureTruncated(t *testing.T) { + mw := newMiddleware(t) + body := []byte(`{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}]}`) + + out, err := mw.Invoke(context.Background(), &middleware.Input{ + URL: "/v1/chat/completions", + Body: body, + BodyTruncated: true, + }) + require.NoError(t, err) + + truncated, ok := metaValue(t, out.Metadata, middleware.KeyLLMCaptureTruncated) + require.True(t, ok, "capture_truncated must be set") + assert.Equal(t, "true", truncated, "BodyTruncated input flips the marker even when prompt fits") +} + +// TestInvoke_RedactPii_RedactsBeforeEmittingRawPrompt covers the GC contract: +// when the synthesiser sets redact_pii=true on the parser config, the value +// emitted as llm.request_prompt_raw must already be redacted, so the +// access-log row never carries raw emails / SSNs / phones — even though the +// downstream llm_guardrail middleware also runs. +func TestInvoke_RedactPii_RedactsBeforeEmittingRawPrompt(t *testing.T) { + mw, err := Factory{}.New([]byte(`{"redact_pii":true}`)) + require.NoError(t, err) + + body := []byte(`{"model":"gpt-4o-mini","stream":false,"messages":[{"role":"user","content":"contact alice.johnson@example.com SSN 123-45-6789 phone (202) 555-0147 and bob 202/555/0108"}]}`) + out, err := mw.Invoke(context.Background(), &middleware.Input{URL: "/v1/chat/completions", Body: body}) + require.NoError(t, err) + require.NotNil(t, out) + + raw, ok := metaValue(t, out.Metadata, middleware.KeyLLMRequestPromptRaw) + require.True(t, ok, "raw prompt key must still be emitted") + assert.Contains(t, raw, "[REDACTED:email]", "email must be redacted before emit") + assert.Contains(t, raw, "[REDACTED:ssn]", "ssn must be redacted before emit") + assert.Contains(t, raw, "[REDACTED:phone]", "phone must be redacted before emit") + assert.NotContains(t, raw, "alice.johnson@example.com", "raw email must not survive") + assert.NotContains(t, raw, "123-45-6789", "raw SSN must not survive") + assert.NotContains(t, raw, "(202) 555-0147", "parenthesised phone must not survive") + assert.NotContains(t, raw, "202/555/0108", "slash-separated phone must not survive") +} + +// TestInvoke_CapturePromptOff_DoesNotEmitRawPrompt covers the contract for +// the account-level enable_prompt_collection toggle: when the synthesiser sets +// capture_prompt=false (operator hasn't opted in to prompt content), the +// parser MUST NOT emit llm.request_prompt_raw at all — otherwise the access +// log carries the user's input even though log collection is meant to be +// metadata-only (provider, model, tokens, cost). The other facts the parser +// emits (provider, model, stream, capture_truncated) stay. +func TestInvoke_CapturePromptOff_DoesNotEmitRawPrompt(t *testing.T) { + mw, err := Factory{}.New([]byte(`{"capture_prompt":false}`)) + require.NoError(t, err) + body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"contact alice@example.com SSN 123-45-6789"}]}`) + out, err := mw.Invoke(context.Background(), &middleware.Input{URL: "/v1/chat/completions", Body: body}) + require.NoError(t, err) + require.NotNil(t, out) + + _, ok := metaValue(t, out.Metadata, middleware.KeyLLMRequestPromptRaw) + assert.False(t, ok, "llm.request_prompt_raw must NOT be emitted when capture_prompt is false") + // Non-content facts must still flow. + _, ok = metaValue(t, out.Metadata, middleware.KeyLLMModel) + assert.True(t, ok, "model fact must still be emitted") + _, ok = metaValue(t, out.Metadata, middleware.KeyLLMProvider) + assert.True(t, ok, "provider fact must still be emitted") +} + +// TestInvoke_CapturePromptUnset_PreservesLegacyEmission documents the default +// behavior: an empty / legacy config (no capture_prompt field) keeps the +// existing emission, so non-agent-network callers and pre-toggle tests don't +// suddenly lose data. +func TestInvoke_CapturePromptUnset_PreservesLegacyEmission(t *testing.T) { + mw, err := Factory{}.New([]byte(`{}`)) + require.NoError(t, err) + body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"hello"}]}`) + out, err := mw.Invoke(context.Background(), &middleware.Input{URL: "/v1/chat/completions", Body: body}) + require.NoError(t, err) + _, ok := metaValue(t, out.Metadata, middleware.KeyLLMRequestPromptRaw) + assert.True(t, ok, "absent capture_prompt must preserve emission (backwards-compatible default)") +} + +// TestInvoke_RedactPii_OffShipsRawPrompt is the inverse: when redact_pii is +// false (default) the operator opted out and the raw prompt is shipped +// verbatim, so audit / debugging consumers still get the full body. +func TestInvoke_RedactPii_OffShipsRawPrompt(t *testing.T) { + mw, err := Factory{}.New([]byte(`{}`)) + require.NoError(t, err) + + body := []byte(`{"model":"gpt-4o-mini","messages":[{"role":"user","content":"alice.johnson@example.com"}]}`) + out, err := mw.Invoke(context.Background(), &middleware.Input{URL: "/v1/chat/completions", Body: body}) + require.NoError(t, err) + + raw, ok := metaValue(t, out.Metadata, middleware.KeyLLMRequestPromptRaw) + require.True(t, ok) + assert.Contains(t, raw, "alice.johnson@example.com", "redact off → raw email passes through") + assert.NotContains(t, raw, "[REDACTED:", "redact off → no markers") +} + +func TestInvoke_NilInputAllows(t *testing.T) { + mw := newMiddleware(t) + out, err := mw.Invoke(context.Background(), nil) + require.NoError(t, err, "nil input must not panic or error") + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "nil input still allows") + assert.Empty(t, out.Metadata, "nil input emits no metadata") +} diff --git a/proxy/internal/middleware/builtin/llm_response_parser/factory.go b/proxy/internal/middleware/builtin/llm_response_parser/factory.go new file mode 100644 index 000000000..e7d634109 --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_response_parser/factory.go @@ -0,0 +1,43 @@ +package llm_response_parser + +import ( + "bytes" + "encoding/json" + "fmt" + + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" +) + +// Factory constructs configured Middleware instances for the registry. +type Factory struct{} + +// ID returns the registry identifier. +func (Factory) ID() string { return ID } + +// New decodes RawConfig (empty / null / "{}" all accepted) and returns +// a configured Middleware. Construction never fails on a well-formed +// empty config; only structurally invalid JSON is rejected. +func (Factory) New(rawConfig []byte) (middleware.Middleware, error) { + cfg, err := decodeConfig(rawConfig) + if err != nil { + return nil, fmt.Errorf("decode config: %w", err) + } + return New(cfg), nil +} + +func decodeConfig(raw []byte) (config, error) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) { + return config{}, nil + } + var cfg config + if err := json.Unmarshal(trimmed, &cfg); err != nil { + return config{}, err + } + return cfg, nil +} + +func init() { + builtin.Register(Factory{}) +} diff --git a/proxy/internal/middleware/builtin/llm_response_parser/gzip_test.go b/proxy/internal/middleware/builtin/llm_response_parser/gzip_test.go new file mode 100644 index 000000000..017153e4c --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_response_parser/gzip_test.go @@ -0,0 +1,133 @@ +package llm_response_parser + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "compress/zlib" + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +// gzipBytes returns data gzip-compressed — the wire shape Anthropic +// returns when the client (Claude Code) negotiated Accept-Encoding: gzip. +func gzipBytes(t *testing.T, data []byte) []byte { + t.Helper() + var buf bytes.Buffer + w := gzip.NewWriter(&buf) + _, err := w.Write(data) + require.NoError(t, err, "gzip write must succeed") + require.NoError(t, w.Close(), "gzip close must succeed") + return buf.Bytes() +} + +// TestInvoke_AnthropicStreaming_Gzip is the regression guard for the live +// bug: Claude Code negotiates gzip, Anthropic gzips the SSE stream, the +// proxy captures the compressed bytes, and the parser must decompress +// before accumulating — otherwise token usage is silently dropped and +// cost_meter skips with missing_tokens. +func TestInvoke_AnthropicStreaming_Gzip(t *testing.T) { + m := newTestMiddleware(t) + body := gzipBytes(t, loadFixture(t, "anthropic_stream.txt")) + + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{ + {Key: "Content-Type", Value: "text/event-stream; charset=utf-8"}, + {Key: "Content-Encoding", Value: "gzip"}, + }, + RespBody: body, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "anthropic"}, + {Key: middleware.KeyLLMModel, Value: "claude-opus-4-8"}, + }, + } + + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "Invoke must not error on a gzip-encoded streaming body") + + in123, ok := metaValue(out.Metadata, middleware.KeyLLMInputTokens) + require.True(t, ok, "input tokens must be emitted from a gzip SSE stream") + assert.Equal(t, "123", in123, "input tokens must survive gzip decompression") + + outTok, _ := metaValue(out.Metadata, middleware.KeyLLMOutputTokens) + assert.Equal(t, "45", outTok, "output tokens must survive gzip decompression") + + totTok, _ := metaValue(out.Metadata, middleware.KeyLLMTotalTokens) + assert.Equal(t, "168", totTok, "total tokens must survive gzip decompression") +} + +// TestInvoke_AnthropicBuffered_Gzip covers the non-streaming JSON path +// under gzip — the same decode must happen before ParseResponse. +func TestInvoke_AnthropicBuffered_Gzip(t *testing.T) { + m := newTestMiddleware(t) + body := gzipBytes(t, loadFixture(t, "anthropic_messages.json")) + + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{ + {Key: "Content-Type", Value: "application/json"}, + {Key: "Content-Encoding", Value: "gzip"}, + }, + RespBody: body, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "anthropic"}, + {Key: middleware.KeyLLMModel, Value: "claude-opus-4-8"}, + }, + } + + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "Invoke must not error on a gzip-encoded buffered body") + + _, ok := metaValue(out.Metadata, middleware.KeyLLMInputTokens) + require.True(t, ok, "input tokens must be emitted from a gzip JSON body") +} + +// TestDecodeResponseBody covers the encoding matrix directly. +func TestDecodeResponseBody(t *testing.T) { + plain := []byte(`{"hello":"world"}`) + + t.Run("identity passthrough", func(t *testing.T) { + assert.Equal(t, plain, decodeResponseBody(plain, "")) + assert.Equal(t, plain, decodeResponseBody(plain, "identity")) + }) + + t.Run("gzip", func(t *testing.T) { + assert.Equal(t, plain, decodeResponseBody(gzipBytes(t, plain), "gzip")) + }) + + t.Run("gzip with multi-coding header takes outermost", func(t *testing.T) { + assert.Equal(t, plain, decodeResponseBody(gzipBytes(t, plain), "identity, gzip")) + }) + + t.Run("deflate zlib-wrapped", func(t *testing.T) { + var buf bytes.Buffer + zw := zlib.NewWriter(&buf) + _, _ = zw.Write(plain) + _ = zw.Close() + assert.Equal(t, plain, decodeResponseBody(buf.Bytes(), "deflate")) + }) + + t.Run("deflate raw flate fallback", func(t *testing.T) { + var buf bytes.Buffer + fw, _ := flate.NewWriter(&buf, flate.DefaultCompression) + _, _ = fw.Write(plain) + _ = fw.Close() + assert.Equal(t, plain, decodeResponseBody(buf.Bytes(), "deflate")) + }) + + t.Run("gzip header but not actually gzip falls back to raw", func(t *testing.T) { + assert.Equal(t, plain, decodeResponseBody(plain, "gzip")) + }) + + t.Run("unknown encoding (br) returns raw", func(t *testing.T) { + assert.Equal(t, plain, decodeResponseBody(plain, "br")) + }) +} diff --git a/proxy/internal/middleware/builtin/llm_response_parser/middleware.go b/proxy/internal/middleware/builtin/llm_response_parser/middleware.go new file mode 100644 index 000000000..a204460cb --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_response_parser/middleware.go @@ -0,0 +1,339 @@ +// Package llm_response_parser implements the SlotOnResponse middleware +// that decodes OpenAI- and Anthropic-shaped LLM responses (buffered or +// streaming) and emits token usage and completion metadata. Provider +// and model are read from the request-side metadata bag emitted by +// llm_request_parser; without that context the middleware is a no-op. +package llm_response_parser + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "compress/zlib" + "context" + "io" + "strconv" + "strings" + "unicode/utf8" + + "github.com/netbirdio/netbird/proxy/internal/llm" + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_guardrail" +) + +// ID is the registry identifier for this middleware. +const ID = "llm_response_parser" + +const version = "1.0.0" + +// maxCompletionBytes is the rune-safe cap applied to the extracted +// completion text before emitting it as metadata. +const maxCompletionBytes = 3500 + +// maxDecodedBytes bounds the inflated size of a compressed response body +// so a small gzip/deflate payload can't expand into a memory blow-up. The +// captured input is already capped (per-direction body cap), so this only +// bounds the decompression ratio; the parser is best-effort and tolerates a +// truncated decode. +const maxDecodedBytes = 16 << 20 + +var ( + acceptedContentTypes = []string{"application/json", "text/event-stream"} + metadataKeys = []string{ + middleware.KeyLLMInputTokens, + middleware.KeyLLMOutputTokens, + middleware.KeyLLMTotalTokens, + middleware.KeyLLMCachedInputTokens, + middleware.KeyLLMCacheCreationTokens, + middleware.KeyLLMResponseCompletion, + } +) + +// config is the wire-side configuration for this middleware. RedactPii, when +// true, runs PII redaction on the extracted completion text BEFORE it is +// emitted as llm.response_completion — keeping the access-log row free of +// emails / SSNs / phone numbers the model itself generated. CaptureCompletion +// gates emission of the completion key entirely: a nil pointer preserves +// legacy emission (so callers without the toggle aren't broken), an explicit +// false suppresses the key so the access-log row carries token / cost facts +// only. Both are sourced by the synthesiser from the account's redact_pii +// and enable_prompt_collection toggles respectively. +type config struct { + RedactPii bool `json:"redact_pii,omitempty"` + CaptureCompletion *bool `json:"capture_completion,omitempty"` +} + +// Middleware implements middleware.Middleware. +type Middleware struct { + parsers []llm.Parser + redactPii bool + captureCompletion bool +} + +// New constructs a configured Middleware instance. +func New(cfg config) *Middleware { + capture := true + if cfg.CaptureCompletion != nil { + capture = *cfg.CaptureCompletion + } + return &Middleware{parsers: llm.Parsers(), redactPii: cfg.RedactPii, captureCompletion: capture} +} + +// ID returns the registry identifier. +func (m *Middleware) ID() string { return ID } + +// Version returns the implementation version. +func (m *Middleware) Version() string { return version } + +// Slot reports that the middleware runs after the upstream call. +func (m *Middleware) Slot() middleware.Slot { return middleware.SlotOnResponse } + +// AcceptedContentTypes lists the response content types the middleware +// inspects. +func (m *Middleware) AcceptedContentTypes() []string { + return append([]string(nil), acceptedContentTypes...) +} + +// MetadataKeys returns the closed allowlist of keys this middleware +// may emit. +func (m *Middleware) MetadataKeys() []string { + return append([]string(nil), metadataKeys...) +} + +// MutationsSupported reports that this middleware never mutates the +// response. +func (m *Middleware) MutationsSupported() bool { return false } + +// Close releases any resources held by the middleware. The parser-set +// is stateless so this is a no-op. +func (m *Middleware) Close() error { return nil } + +// Invoke decodes the response body and emits token-usage and completion +// metadata. The decision is always DecisionAllow; parse errors degrade +// silently to omitted metadata rather than chain failures. +func (m *Middleware) Invoke(_ context.Context, in *middleware.Input) (*middleware.Output, error) { + out := &middleware.Output{Decision: middleware.DecisionAllow} + if in == nil { + return out, nil + } + + provider := lookupKV(in.Metadata, middleware.KeyLLMProvider) + if provider == "" { + return out, nil + } + + parser := m.parserByName(provider) + if parser == nil { + return out, nil + } + + // Upstreams compress the response when the client negotiated it + // (Claude Code sends Accept-Encoding: gzip). The transport leaves it + // compressed because the request carried an explicit Accept-Encoding, + // so the captured copy is gzip/deflate bytes — decompress it before + // parsing or token usage is silently lost. The forwarded client + // stream is untouched; this only affects our parse copy. + body := decodeResponseBody(in.RespBody, headerLookup(in.RespHeaders, "Content-Encoding")) + + contentType := headerLookup(in.RespHeaders, "Content-Type") + switch { + case isEventStream(contentType), isAWSEventStream(contentType): + out.Metadata = m.invokeStreaming(parser, body) + case isJSON(contentType): + out.Metadata = m.invokeBuffered(parser, in, contentType, body) + } + + return out, nil +} + +// invokeBuffered decodes a non-streaming JSON response body. Status +// codes >= 400 short-circuit because providers don't include usage on +// error responses. +func (m *Middleware) invokeBuffered(parser llm.Parser, in *middleware.Input, contentType string, body []byte) []middleware.KV { + if in.Status >= 400 { + return nil + } + + var md []middleware.KV + + usage, err := parser.ParseResponse(in.Status, contentType, body) + if err == nil { + md = appendUsage(md, usage) + } + + if completion := truncateCompletion(parser.ExtractCompletion(in.Status, contentType, body)); completion != "" && m.captureCompletion { + if m.redactPii { + completion = llm_guardrail.RedactPII(completion) + } + md = append(md, middleware.KV{Key: middleware.KeyLLMResponseCompletion, Value: completion}) + } + + return md +} + +// invokeStreaming walks the buffered SSE prefix and accumulates token +// deltas plus completion text. Truncated bodies are processed +// best-effort; partial usage is preferred over no metadata. +func (m *Middleware) invokeStreaming(parser llm.Parser, body []byte) []middleware.KV { + if len(body) == 0 { + return nil + } + + usage, completion := accumulateStream(parser.ProviderName(), body) + + var md []middleware.KV + if usage.InputTokens > 0 || usage.OutputTokens > 0 || usage.TotalTokens > 0 { + md = appendUsage(md, usage) + } + if c := truncateCompletion(completion); c != "" && m.captureCompletion { + if m.redactPii { + c = llm_guardrail.RedactPII(c) + } + md = append(md, middleware.KV{Key: middleware.KeyLLMResponseCompletion, Value: c}) + } + return md +} + +// parserByName returns the parser matching the provider label emitted +// by llm_request_parser, or nil when none claims it. +func (m *Middleware) parserByName(name string) llm.Parser { + for _, p := range m.parsers { + if p.ProviderName() == name { + return p + } + } + return nil +} + +// appendUsage emits the three baseline token-count metadata keys plus +// optional cached / cache-creation bucket counts when nonzero. Total +// is computed when the provider omitted one but reported per-direction +// counts; cache buckets are excluded from the legacy total because +// llm.input_tokens already absorbs the OpenAI cached subset and the +// sum-of-everything is a separate downstream concern. +func appendUsage(md []middleware.KV, usage llm.Usage) []middleware.KV { + total := usage.TotalTokens + if total == 0 && (usage.InputTokens > 0 || usage.OutputTokens > 0) { + total = usage.InputTokens + usage.OutputTokens + } + md = append(md, + middleware.KV{Key: middleware.KeyLLMInputTokens, Value: strconv.FormatInt(usage.InputTokens, 10)}, + middleware.KV{Key: middleware.KeyLLMOutputTokens, Value: strconv.FormatInt(usage.OutputTokens, 10)}, + middleware.KV{Key: middleware.KeyLLMTotalTokens, Value: strconv.FormatInt(total, 10)}, + ) + if usage.CachedInputTokens > 0 { + md = append(md, middleware.KV{ + Key: middleware.KeyLLMCachedInputTokens, + Value: strconv.FormatInt(usage.CachedInputTokens, 10), + }) + } + if usage.CacheCreationTokens > 0 { + md = append(md, middleware.KV{ + Key: middleware.KeyLLMCacheCreationTokens, + Value: strconv.FormatInt(usage.CacheCreationTokens, 10), + }) + } + return md +} + +// truncateCompletion clamps an extracted completion to maxCompletionBytes. +// The cut is rune-safe so we never split a multi-byte UTF-8 sequence. +func truncateCompletion(s string) string { + if len(s) <= maxCompletionBytes { + return s + } + cut := maxCompletionBytes + for cut > 0 && !utf8.RuneStart(s[cut]) { + cut-- + } + return s[:cut] +} + +func lookupKV(kvs []middleware.KV, key string) string { + for _, kv := range kvs { + if kv.Key == key { + return kv.Value + } + } + return "" +} + +func headerLookup(h []middleware.KV, name string) string { + lower := strings.ToLower(name) + for _, kv := range h { + if strings.ToLower(kv.Key) == lower { + return kv.Value + } + } + return "" +} + +func isEventStream(contentType string) bool { + return strings.Contains(strings.ToLower(contentType), "text/event-stream") +} + +// isAWSEventStream reports whether contentType is the AWS binary event-stream +// framing used by Bedrock's streaming endpoints. +func isAWSEventStream(contentType string) bool { + return strings.Contains(strings.ToLower(contentType), "application/vnd.amazon.eventstream") +} + +func isJSON(contentType string) bool { + lower := strings.ToLower(contentType) + return strings.Contains(lower, "application/json") || strings.Contains(lower, "+json") +} + +// decodeResponseBody returns body decompressed per its Content-Encoding, +// or the original bytes when the encoding is identity, unrecognised +// (e.g. br — no stdlib decoder), or the body isn't actually compressed. +// Decoding is best-effort: a truncated stream (capture hit the byte cap) +// yields the decompressed prefix rather than an error, which is enough to +// recover the leading message_start usage on Anthropic SSE. +func decodeResponseBody(body []byte, contentEncoding string) []byte { + enc := strings.ToLower(strings.TrimSpace(contentEncoding)) + // Content-Encoding may list multiple codings; the last applied is + // the outermost on the wire. + if idx := strings.LastIndex(enc, ","); idx >= 0 { + enc = strings.TrimSpace(enc[idx+1:]) + } + switch enc { + case "", "identity": + return body + case "gzip", "x-gzip": + zr, err := gzip.NewReader(bytes.NewReader(body)) + if err != nil { + return body + } + defer zr.Close() + if out := readCapped(zr); len(out) > 0 { + return out + } + return body + case "deflate": + // "deflate" on the wire is usually zlib-wrapped; fall back to raw + // flate when there's no zlib header. + if zr, err := zlib.NewReader(bytes.NewReader(body)); err == nil { + defer zr.Close() + if out := readCapped(zr); len(out) > 0 { + return out + } + return body + } + fr := flate.NewReader(bytes.NewReader(body)) + defer fr.Close() + if out := readCapped(fr); len(out) > 0 { + return out + } + return body + default: + return body + } +} + +// readCapped reads at most maxDecodedBytes from r, discarding any excess. +// Best-effort: a read error returns whatever was decoded so far, which is +// enough for the parser to recover leading usage events. +func readCapped(r io.Reader) []byte { + out, _ := io.ReadAll(io.LimitReader(r, maxDecodedBytes)) + return out +} diff --git a/proxy/internal/middleware/builtin/llm_response_parser/middleware_test.go b/proxy/internal/middleware/builtin/llm_response_parser/middleware_test.go new file mode 100644 index 000000000..084118802 --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_response_parser/middleware_test.go @@ -0,0 +1,433 @@ +package llm_response_parser + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +func loadFixture(t *testing.T, name string) []byte { + t.Helper() + root, err := os.Getwd() + require.NoError(t, err, "must resolve cwd to locate fixture") + + dir := root + for i := 0; i < 8; i++ { + candidate := filepath.Join(dir, "proxy", "internal", "llm", "fixtures", name) + if data, err := os.ReadFile(candidate); err == nil { + return data + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + t.Fatalf("fixture %q not found relative to %q", name, root) + return nil +} + +func metaValue(kvs []middleware.KV, key string) (string, bool) { + for _, kv := range kvs { + if kv.Key == key { + return kv.Value, true + } + } + return "", false +} + +func newTestMiddleware(t *testing.T) *Middleware { + t.Helper() + mw, err := Factory{}.New(nil) + require.NoError(t, err, "factory must accept empty config") + concrete, ok := mw.(*Middleware) + require.True(t, ok, "factory must return *Middleware") + return concrete +} + +func TestMiddleware_StaticSurface(t *testing.T) { + m := newTestMiddleware(t) + assert.Equal(t, ID, m.ID(), "ID must match registry constant") + assert.Equal(t, "1.0.0", m.Version(), "Version must be 1.0.0") + assert.Equal(t, middleware.SlotOnResponse, m.Slot(), "Slot must be SlotOnResponse") + assert.False(t, m.MutationsSupported(), "response parser does not mutate") + assert.ElementsMatch(t, + []string{"application/json", "text/event-stream"}, + m.AcceptedContentTypes(), + "AcceptedContentTypes must list JSON and SSE", + ) + assert.ElementsMatch(t, + []string{ + middleware.KeyLLMInputTokens, + middleware.KeyLLMOutputTokens, + middleware.KeyLLMTotalTokens, + middleware.KeyLLMCachedInputTokens, + middleware.KeyLLMCacheCreationTokens, + middleware.KeyLLMResponseCompletion, + }, + m.MetadataKeys(), + "MetadataKeys must be the documented response-side keys, including the optional cache buckets emitted only when nonzero", + ) + require.NoError(t, m.Close(), "Close must be a no-op") +} + +func TestFactory_AcceptsEmptyAndNullConfig(t *testing.T) { + for name, raw := range map[string][]byte{ + "nil": nil, + "empty": {}, + "null": []byte("null"), + "obj": []byte("{}"), + "ws": []byte(" "), + } { + t.Run(name, func(t *testing.T) { + mw, err := Factory{}.New(raw) + require.NoError(t, err, "factory must accept %s config", name) + require.NotNil(t, mw, "factory must return middleware for %s", name) + }) + } +} + +func TestFactory_RejectsMalformedJSON(t *testing.T) { + _, err := Factory{}.New([]byte("not-json")) + require.Error(t, err, "malformed config must surface a decode error") +} + +func TestInvoke_OpenAIBuffered(t *testing.T) { + m := newTestMiddleware(t) + body := loadFixture(t, "openai_chat_completion.json") + + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "application/json"}}, + RespBody: body, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o-mini"}, + }, + } + + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "Invoke must not error on a valid buffered response") + require.Equal(t, middleware.DecisionAllow, out.Decision, "decision must be Allow") + + in123, ok := metaValue(out.Metadata, middleware.KeyLLMInputTokens) + require.True(t, ok, "input tokens must be emitted") + assert.Equal(t, "123", in123, "input tokens must match fixture prompt_tokens") + + outTok, ok := metaValue(out.Metadata, middleware.KeyLLMOutputTokens) + require.True(t, ok, "output tokens must be emitted") + assert.Equal(t, "45", outTok, "output tokens must match fixture completion_tokens") + + totTok, ok := metaValue(out.Metadata, middleware.KeyLLMTotalTokens) + require.True(t, ok, "total tokens must be emitted") + assert.Equal(t, "168", totTok, "total tokens must match fixture") + + completion, ok := metaValue(out.Metadata, middleware.KeyLLMResponseCompletion) + require.True(t, ok, "completion must be emitted") + assert.Equal(t, "Hello, world!", completion, "completion text must match fixture") +} + +func TestInvoke_AnthropicBuffered(t *testing.T) { + m := newTestMiddleware(t) + body := loadFixture(t, "anthropic_messages.json") + + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "application/json"}}, + RespBody: body, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "anthropic"}, + {Key: middleware.KeyLLMModel, Value: "claude-sonnet-4-5"}, + }, + } + + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "Invoke must not error on a valid buffered response") + require.Equal(t, middleware.DecisionAllow, out.Decision, "decision must be Allow") + + in123, _ := metaValue(out.Metadata, middleware.KeyLLMInputTokens) + assert.Equal(t, "123", in123, "input tokens must match anthropic fixture") + + outTok, _ := metaValue(out.Metadata, middleware.KeyLLMOutputTokens) + assert.Equal(t, "45", outTok, "output tokens must match anthropic fixture") + + totTok, _ := metaValue(out.Metadata, middleware.KeyLLMTotalTokens) + assert.Equal(t, "168", totTok, "total tokens must be input+output for anthropic") + + completion, ok := metaValue(out.Metadata, middleware.KeyLLMResponseCompletion) + require.True(t, ok, "completion must be emitted for anthropic") + assert.Equal(t, "Hello, world!", completion, "completion text must match fixture") +} + +// TestInvoke_OpenAICachedTokensSurfaceOnMetadata covers the +// end-to-end path from the JSON usage block to the +// llm.cached_input_tokens metadata key the cost meter consumes. +// llm.cache_creation_tokens is NOT emitted for OpenAI because +// OpenAI has no cache_creation analogue. +func TestInvoke_OpenAICachedTokensSurfaceOnMetadata(t *testing.T) { + m := newTestMiddleware(t) + body := []byte(`{"usage":{"prompt_tokens":1024,"completion_tokens":200,"total_tokens":1224,"prompt_tokens_details":{"cached_tokens":768}}}`) + + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "application/json"}}, + RespBody: body, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o"}, + }, + } + + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err) + cached, ok := metaValue(out.Metadata, middleware.KeyLLMCachedInputTokens) + require.True(t, ok, "cached_input_tokens must land on the bag when the OpenAI response carries cached_tokens") + assert.Equal(t, "768", cached) + + _, hasCreation := metaValue(out.Metadata, middleware.KeyLLMCacheCreationTokens) + assert.False(t, hasCreation, "cache_creation_tokens must NOT be emitted for OpenAI — no analogue in the OpenAI shape") +} + +// TestInvoke_AnthropicCacheBucketsSurfaceOnMetadata covers the +// Anthropic shape: both cache_read and cache_creation values flow +// onto the metadata bag so the cost meter can apply per-bucket +// rates. +func TestInvoke_AnthropicCacheBucketsSurfaceOnMetadata(t *testing.T) { + m := newTestMiddleware(t) + body := []byte(`{"usage":{"input_tokens":256,"output_tokens":200,"cache_read_input_tokens":768,"cache_creation_input_tokens":512}}`) + + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "application/json"}}, + RespBody: body, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "anthropic"}, + {Key: middleware.KeyLLMModel, Value: "claude-sonnet-4-5"}, + }, + } + + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err) + + cached, ok := metaValue(out.Metadata, middleware.KeyLLMCachedInputTokens) + require.True(t, ok, "cache_read_input_tokens lands under cached_input_tokens — same key carries OpenAI cached subset and Anthropic cache reads, meter switches formula on provider") + assert.Equal(t, "768", cached) + + creation, ok := metaValue(out.Metadata, middleware.KeyLLMCacheCreationTokens) + require.True(t, ok, "cache_creation_input_tokens lands under cache_creation_tokens for Anthropic") + assert.Equal(t, "512", creation) +} + +func TestInvoke_NoProviderMetadata_NoOp(t *testing.T) { + m := newTestMiddleware(t) + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "application/json"}}, + RespBody: loadFixture(t, "openai_chat_completion.json"), + } + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "missing provider metadata is not an error") + assert.Equal(t, middleware.DecisionAllow, out.Decision, "decision must be Allow") + assert.Empty(t, out.Metadata, "no metadata when provider context is missing") +} + +func TestInvoke_UnknownProvider_NoOp(t *testing.T) { + m := newTestMiddleware(t) + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "application/json"}}, + RespBody: loadFixture(t, "openai_chat_completion.json"), + Metadata: []middleware.KV{{Key: middleware.KeyLLMProvider, Value: "cohere"}}, + } + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "unknown provider must not surface an error") + assert.Empty(t, out.Metadata, "unknown providers emit no metadata") +} + +func TestInvoke_ErrorStatus_NoUsageEmitted(t *testing.T) { + m := newTestMiddleware(t) + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 500, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "application/json"}}, + RespBody: []byte(`{"error":{"message":"upstream blew up"}}`), + Metadata: []middleware.KV{{Key: middleware.KeyLLMProvider, Value: "openai"}}, + } + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "error responses must not surface as middleware error") + _, ok := metaValue(out.Metadata, middleware.KeyLLMInputTokens) + assert.False(t, ok, "no usage metadata on >=400 responses") +} + +func TestInvoke_NonInspectedContentType_NoOp(t *testing.T) { + m := newTestMiddleware(t) + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "text/plain"}}, + RespBody: []byte("not json"), + Metadata: []middleware.KV{{Key: middleware.KeyLLMProvider, Value: "openai"}}, + } + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "Invoke must tolerate non-inspected content types") + assert.Empty(t, out.Metadata, "no metadata for non-JSON, non-SSE bodies") +} + +func TestInvoke_NilInput(t *testing.T) { + m := newTestMiddleware(t) + out, err := m.Invoke(context.Background(), nil) + require.NoError(t, err, "nil input must not error") + require.Equal(t, middleware.DecisionAllow, out.Decision, "decision must be Allow even on nil input") + assert.Empty(t, out.Metadata, "no metadata for nil input") +} + +func TestInvoke_CompletionTruncatedAt3500Bytes(t *testing.T) { + m := newTestMiddleware(t) + long := strings.Repeat("x", 5000) + body := []byte(`{"id":"x","choices":[{"message":{"role":"assistant","content":"` + long + `"}}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`) + + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "application/json"}}, + RespBody: body, + Metadata: []middleware.KV{{Key: middleware.KeyLLMProvider, Value: "openai"}}, + } + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "long-completion body must parse cleanly") + + completion, ok := metaValue(out.Metadata, middleware.KeyLLMResponseCompletion) + require.True(t, ok, "completion must be emitted for long body") + assert.LessOrEqual(t, len(completion), maxCompletionBytes, "completion must be truncated to <=3500 bytes") + assert.Equal(t, maxCompletionBytes, len(completion), "completion must be truncated exactly at the cap when input is ASCII and longer") +} + +// TestInvoke_RedactPii_RedactsCompletionBeforeEmit covers the GC contract on +// the response leg: when the synthesiser sets redact_pii=true, the value +// emitted as llm.response_completion must already be redacted, so the +// access-log row never carries raw emails / SSNs / phones the model generated. +// Without this, the response side leaked dozens of raw PII tokens per request. +func TestInvoke_RedactPii_RedactsCompletionBeforeEmit(t *testing.T) { + mw, err := Factory{}.New([]byte(`{"redact_pii":true}`)) + require.NoError(t, err) + + piiCompletion := "Sample record: Alice Johnson, alice.johnson@example.com, SSN 123-45-6789, phone (202) 555-0147. Bob: 202/555/0108." + body := []byte(`{"id":"x","choices":[{"message":{"role":"assistant","content":"` + piiCompletion + `"}}],"usage":{"prompt_tokens":10,"completion_tokens":50,"total_tokens":60}}`) + + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "application/json"}}, + RespBody: body, + Metadata: []middleware.KV{{Key: middleware.KeyLLMProvider, Value: "openai"}}, + } + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out) + + completion, ok := metaValue(out.Metadata, middleware.KeyLLMResponseCompletion) + require.True(t, ok, "completion key must be emitted") + assert.Contains(t, completion, "[REDACTED:email]", "email must be redacted before emit") + assert.Contains(t, completion, "[REDACTED:ssn]", "ssn must be redacted before emit") + assert.Contains(t, completion, "[REDACTED:phone]", "phone must be redacted before emit") + assert.NotContains(t, completion, "alice.johnson@example.com", "raw email must not survive") + assert.NotContains(t, completion, "123-45-6789", "raw SSN must not survive") + assert.NotContains(t, completion, "(202) 555-0147", "parens-phone must not survive") + assert.NotContains(t, completion, "202/555/0108", "slash-phone must not survive") +} + +// TestInvoke_CaptureCompletionOff_DoesNotEmitCompletion mirrors the request +// parser test: when capture_completion=false (operator has enable_prompt_ +// collection off), llm.response_completion MUST NOT appear in the access log. +// The token / cost / usage facts the response parser also emits stay so +// operators still get billing data on log-only mode. +func TestInvoke_CaptureCompletionOff_DoesNotEmitCompletion(t *testing.T) { + mw, err := Factory{}.New([]byte(`{"capture_completion":false}`)) + require.NoError(t, err) + body := []byte(`{"id":"x","choices":[{"message":{"role":"assistant","content":"alice@example.com 123-45-6789"}}],"usage":{"prompt_tokens":10,"completion_tokens":20,"total_tokens":30}}`) + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "application/json"}}, + RespBody: body, + Metadata: []middleware.KV{{Key: middleware.KeyLLMProvider, Value: "openai"}}, + } + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + + _, ok := metaValue(out.Metadata, middleware.KeyLLMResponseCompletion) + assert.False(t, ok, "llm.response_completion must NOT be emitted when capture_completion is false") + + // Token facts must still flow. + _, ok = metaValue(out.Metadata, middleware.KeyLLMInputTokens) + assert.True(t, ok, "input tokens fact must still be emitted") + _, ok = metaValue(out.Metadata, middleware.KeyLLMOutputTokens) + assert.True(t, ok, "output tokens fact must still be emitted") +} + +// TestInvoke_CaptureCompletionUnset_PreservesLegacyEmission documents the +// default behavior: empty config keeps emitting completion, so callers +// without the toggle aren't broken. +func TestInvoke_CaptureCompletionUnset_PreservesLegacyEmission(t *testing.T) { + mw, err := Factory{}.New([]byte(`{}`)) + require.NoError(t, err) + body := []byte(`{"id":"x","choices":[{"message":{"role":"assistant","content":"hello"}}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`) + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "application/json"}}, + RespBody: body, + Metadata: []middleware.KV{{Key: middleware.KeyLLMProvider, Value: "openai"}}, + } + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + _, ok := metaValue(out.Metadata, middleware.KeyLLMResponseCompletion) + assert.True(t, ok, "absent capture_completion must preserve emission (backwards-compatible default)") +} + +// TestInvoke_RedactPii_OffShipsRawCompletion covers the inverse: with +// redact_pii=false (default) the model output is shipped verbatim. +func TestInvoke_RedactPii_OffShipsRawCompletion(t *testing.T) { + mw, err := Factory{}.New(nil) + require.NoError(t, err) + + body := []byte(`{"id":"x","choices":[{"message":{"role":"assistant","content":"alice@example.com 123-45-6789"}}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`) + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "application/json"}}, + RespBody: body, + Metadata: []middleware.KV{{Key: middleware.KeyLLMProvider, Value: "openai"}}, + } + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + + completion, ok := metaValue(out.Metadata, middleware.KeyLLMResponseCompletion) + require.True(t, ok) + assert.Contains(t, completion, "alice@example.com", "redact off → raw email passes through") + assert.Contains(t, completion, "123-45-6789", "redact off → raw SSN passes through") + assert.NotContains(t, completion, "[REDACTED:", "redact off → no markers") +} + +func TestInvoke_CompletionTruncationRuneSafe(t *testing.T) { + rune4 := "\xf0\x9f\x98\x80" // 4-byte emoji + body := strings.Repeat("a", maxCompletionBytes-1) + rune4 + require.Greater(t, len(body), maxCompletionBytes, "test setup must exceed the cap") + + got := truncateCompletion(body) + assert.True(t, len(got) < maxCompletionBytes, "truncated bytes must drop the partial rune entirely") + assert.NotContains(t, got, "\x80", "truncated text must not end on a continuation byte") +} diff --git a/proxy/internal/middleware/builtin/llm_response_parser/responses_stream_test.go b/proxy/internal/middleware/builtin/llm_response_parser/responses_stream_test.go new file mode 100644 index 000000000..0475b8cca --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_response_parser/responses_stream_test.go @@ -0,0 +1,69 @@ +package llm_response_parser + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +// TestInvoke_OpenAIResponsesStreaming is the regression guard for the live +// bug where Codex hits /v1/responses (the OpenAI Responses API), whose SSE +// shape differs from chat.completions: completion text rides +// response.output_text.delta and usage rides response.completed under +// response.usage. The old parser only knew the chat.completions shape, so +// resp_meta came back empty (no tokens, no cost). +func TestInvoke_OpenAIResponsesStreaming(t *testing.T) { + m := newTestMiddleware(t) + body := loadFixture(t, "openai_responses_stream.txt") + + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "text/event-stream; charset=utf-8"}}, + RespBody: body, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-5.5"}, + }, + } + + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "Invoke must not error on a Responses-API streaming body") + + inTok, ok := metaValue(out.Metadata, middleware.KeyLLMInputTokens) + require.True(t, ok, "input tokens must be emitted from a Responses-API stream") + assert.Equal(t, "123", inTok, "input_tokens must come from response.completed usage") + + outTok, _ := metaValue(out.Metadata, middleware.KeyLLMOutputTokens) + assert.Equal(t, "45", outTok, "output_tokens must come from response.completed usage") + + totTok, _ := metaValue(out.Metadata, middleware.KeyLLMTotalTokens) + assert.Equal(t, "168", totTok, "total_tokens must come from response.completed usage") + + cached, ok := metaValue(out.Metadata, middleware.KeyLLMCachedInputTokens) + require.True(t, ok, "cached input tokens must surface from input_tokens_details") + assert.Equal(t, "40", cached, "cached_tokens subset must surface for cost discounting") + + completion, ok := metaValue(out.Metadata, middleware.KeyLLMResponseCompletion) + require.True(t, ok, "completion must be emitted for Responses-API streams") + assert.Equal(t, "Hello, world!", completion, "output_text.delta events must concatenate") +} + +// TestAccumulateOpenAIStream_ResponsesNoUsage confirms that a Responses-API +// stream with text but no terminal usage frame still yields the completion +// and leaves tokens at zero rather than erroring. +func TestAccumulateOpenAIStream_ResponsesNoUsage(t *testing.T) { + body := []byte(`event: response.output_text.delta +data: {"type":"response.output_text.delta","delta":"partial"} + +`) + + usage, completion := accumulateOpenAIStream(body) + assert.Equal(t, int64(0), usage.InputTokens, "no usage frame leaves input tokens at zero") + assert.Equal(t, int64(0), usage.OutputTokens, "no usage frame leaves output tokens at zero") + assert.Equal(t, "partial", completion, "output_text deltas accumulate even without a usage frame") +} diff --git a/proxy/internal/middleware/builtin/llm_response_parser/streaming.go b/proxy/internal/middleware/builtin/llm_response_parser/streaming.go new file mode 100644 index 000000000..6462ab4ae --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_response_parser/streaming.go @@ -0,0 +1,270 @@ +package llm_response_parser + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "strings" + + "github.com/netbirdio/netbird/proxy/internal/llm" +) + +// openAIDoneSentinel is the OpenAI end-of-stream marker. The scanner +// stops once this data frame is observed. +const openAIDoneSentinel = "[DONE]" + +// accumulateStream walks the SSE byte slice, dispatches per provider, +// and returns the running token-usage and concatenated completion text. +// Errors from the scanner short-circuit accumulation but never panic +// — partial results are returned for truncated bodies. +func accumulateStream(provider string, body []byte) (llm.Usage, string) { + switch provider { + case "openai": + return accumulateOpenAIStream(body) + case "anthropic": + return accumulateAnthropicStream(body) + case llm.ProviderNameBedrock: + return accumulateBedrockStream(body) + default: + return llm.Usage{}, "" + } +} + +// openAIStreamUsage is the usage block shared by both OpenAI streaming +// envelopes. Pointer fields tell "absent" from zero; the chat.completions +// (prompt_/completion_) and Responses-API (input_/output_) names are both +// accepted so a single decode covers either endpoint. +type openAIStreamUsage struct { + PromptTokens *int64 `json:"prompt_tokens"` + CompletionTokens *int64 `json:"completion_tokens"` + InputTokens *int64 `json:"input_tokens"` + OutputTokens *int64 `json:"output_tokens"` + TotalTokens *int64 `json:"total_tokens"` + PromptTokensDetails *struct { + CachedTokens *int64 `json:"cached_tokens"` + } `json:"prompt_tokens_details"` + InputTokensDetails *struct { + CachedTokens *int64 `json:"cached_tokens"` + } `json:"input_tokens_details"` +} + +// openAIStreamChunk matches both OpenAI streaming envelopes. The +// chat.completions chunk carries text in choices[].delta.content and a +// trailing top-level usage block. The Responses API (/v1/responses) emits +// typed events instead: completion text rides response.output_text.delta +// (top-level "delta" string) and the final usage rides response.completed +// under response.usage. Only fields used for accumulation are declared. +type openAIStreamChunk struct { + Choices []struct { + Delta struct { + Content string `json:"content"` + } `json:"delta"` + } `json:"choices"` + Usage *openAIStreamUsage `json:"usage"` + + Type string `json:"type"` + Delta json.RawMessage `json:"delta"` + Response *struct { + Usage *openAIStreamUsage `json:"usage"` + } `json:"response"` +} + +// accumulateOpenAIStream sums per-chunk content deltas and lifts the usage +// block off the final frame, handling both the chat.completions and the +// Responses-API event shapes. Clients without stream_options.include_usage +// (chat.completions) and any provider that omits the final usage simply +// leave tokens at zero; the caller chooses what to emit. +func accumulateOpenAIStream(body []byte) (llm.Usage, string) { + var ( + usage llm.Usage + completion strings.Builder + ) + scanner := llm.NewScanner(bytes.NewReader(body)) + for { + ev, err := scanner.Next() + if err != nil { + break + } + if ev.Data == openAIDoneSentinel { + break + } + if ev.Data == "" { + continue + } + + var chunk openAIStreamChunk + if err := json.Unmarshal([]byte(ev.Data), &chunk); err != nil { + continue + } + for _, c := range chunk.Choices { + completion.WriteString(c.Delta.Content) + } + if chunk.Type == "response.output_text.delta" { + if s, ok := decodeJSONString(chunk.Delta); ok { + completion.WriteString(s) + } + } + + u := chunk.Usage + if u == nil && chunk.Response != nil { + u = chunk.Response.Usage + } + applyOpenAIStreamUsage(u, &usage) + } + return usage, completion.String() +} + +// applyOpenAIStreamUsage lifts the token counts off a final-frame usage +// block into the running usage, normalising the chat.completions +// (prompt_/completion_) and Responses-API (input_/output_) names and +// backfilling total tokens when the provider omits them. +func applyOpenAIStreamUsage(u *openAIStreamUsage, usage *llm.Usage) { + if u == nil { + return + } + usage.InputTokens = pickInt64(u.InputTokens, u.PromptTokens) + usage.OutputTokens = pickInt64(u.OutputTokens, u.CompletionTokens) + usage.TotalTokens = derefInt64(u.TotalTokens) + if u.InputTokensDetails != nil { + if v := derefInt64(u.InputTokensDetails.CachedTokens); v > 0 { + usage.CachedInputTokens = v + } + } + if usage.CachedInputTokens == 0 && u.PromptTokensDetails != nil { + usage.CachedInputTokens = derefInt64(u.PromptTokensDetails.CachedTokens) + } + if usage.TotalTokens == 0 && (usage.InputTokens > 0 || usage.OutputTokens > 0) { + usage.TotalTokens = usage.InputTokens + usage.OutputTokens + } +} + +// decodeJSONString unmarshals a JSON-encoded string value, returning +// ok=false when the raw message is empty or not a string. +func decodeJSONString(raw json.RawMessage) (string, bool) { + if len(raw) == 0 { + return "", false + } + var s string + if err := json.Unmarshal(raw, &s); err != nil { + return "", false + } + return s, true +} + +// anthropicStreamEvent captures the union of Messages-API stream event +// payloads we care about. Each named event on the wire fills only its +// shape's fields; unknown keys are ignored. +type anthropicStreamUsage struct { + InputTokens *int64 `json:"input_tokens"` + OutputTokens *int64 `json:"output_tokens"` + CacheReadInputTokens *int64 `json:"cache_read_input_tokens"` + CacheCreationInputTokens *int64 `json:"cache_creation_input_tokens"` +} + +type anthropicStreamEvent struct { + Type string `json:"type"` + Message *struct { + Usage *anthropicStreamUsage `json:"usage"` + } `json:"message"` + Delta *struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"delta"` + Usage *anthropicStreamUsage `json:"usage"` +} + +// accumulateAnthropicStream tracks input_tokens from message_start, +// output_tokens from message_delta, and concatenates text_delta payloads +// from content_block_delta events. Final usage prefers message_delta +// values which carry the post-completion totals. +func accumulateAnthropicStream(body []byte) (llm.Usage, string) { + var ( + usage llm.Usage + completion strings.Builder + ) + scanner := llm.NewScanner(bytes.NewReader(body)) + for { + ev, err := scanner.Next() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + break + } + if ev.Data == "" { + continue + } + + var payload anthropicStreamEvent + if err := json.Unmarshal([]byte(ev.Data), &payload); err != nil { + continue + } + + eventType := ev.Type + if eventType == "" { + eventType = payload.Type + } + applyAnthropicStreamEvent(eventType, payload, &usage, &completion) + } + if usage.InputTokens > 0 || usage.OutputTokens > 0 { + usage.TotalTokens = usage.InputTokens + usage.OutputTokens + usage.CachedInputTokens + usage.CacheCreationTokens + } + return usage, completion.String() +} + +// applyAnthropicStreamEvent folds one parsed Anthropic Messages stream event +// into the running usage/completion. Shared by the SSE accumulator and the +// Bedrock InvokeModel event-stream, whose chunks wrap the same event JSON. +func applyAnthropicStreamEvent(eventType string, payload anthropicStreamEvent, usage *llm.Usage, completion *strings.Builder) { + switch eventType { + case "message_start": + if payload.Message != nil { + applyAnthropicStreamUsage(payload.Message.Usage, usage) + } + case "content_block_delta": + if payload.Delta != nil && payload.Delta.Type == "text_delta" { + completion.WriteString(payload.Delta.Text) + } + case "message_delta": + applyAnthropicStreamUsage(payload.Usage, usage) + case "message_stop": + // No-op; Anthropic does not emit usage here. + } +} + +// applyAnthropicStreamUsage folds a non-nil Anthropic usage block into the +// running totals. Each field overwrites only when present and positive, so +// message_delta's post-completion counts supersede the message_start seed +// without zeroing dimensions a later event omits. +func applyAnthropicStreamUsage(u *anthropicStreamUsage, usage *llm.Usage) { + if u == nil { + return + } + if v := derefInt64(u.InputTokens); v > 0 { + usage.InputTokens = v + } + if v := derefInt64(u.OutputTokens); v > 0 { + usage.OutputTokens = v + } + if v := derefInt64(u.CacheReadInputTokens); v > 0 { + usage.CachedInputTokens = v + } + if v := derefInt64(u.CacheCreationInputTokens); v > 0 { + usage.CacheCreationTokens = v + } +} + +func pickInt64(preferred, fallback *int64) int64 { + if preferred != nil { + return *preferred + } + return derefInt64(fallback) +} + +func derefInt64(v *int64) int64 { + if v == nil { + return 0 + } + return *v +} diff --git a/proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go b/proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go new file mode 100644 index 000000000..a82a9cdbc --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock.go @@ -0,0 +1,110 @@ +package llm_response_parser + +import ( + "bytes" + "encoding/json" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" + + "github.com/netbirdio/netbird/proxy/internal/llm" +) + +// bedrockEventTypeHeader names each AWS event-stream frame's event type. +const bedrockEventTypeHeader = ":event-type" + +// accumulateBedrockStream decodes the AWS binary event-stream returned by +// Bedrock's streaming endpoints and folds it into running usage/completion. +// Two framings are handled: +// - InvokeModel (invoke-with-response-stream): each "chunk" frame's payload is +// {"bytes":""} wrapping a vendor-native (Anthropic) stream event. +// - Converse (converse-stream): native frames (contentBlockDelta, metadata, …) +// whose payload JSON carries text deltas and a final usage block. +// +// A truncated stream (cut at the capture cap) decodes best-effort: frames up to +// the cut are applied and the partial usage is returned. +func accumulateBedrockStream(body []byte) (llm.Usage, string) { + var ( + usage llm.Usage + completion strings.Builder + ) + dec := eventstream.NewDecoder() + r := bytes.NewReader(body) + for { + msg, err := dec.Decode(r, nil) + if err != nil { + break // EOF or a partial trailing frame — return what we have. + } + eventType := "" + if v := msg.Headers.Get(bedrockEventTypeHeader); v != nil { + eventType = v.String() + } + if eventType == "chunk" { + applyBedrockInvokeChunk(msg.Payload, &usage, &completion) + continue + } + applyConverseStreamEvent(eventType, msg.Payload, &usage, &completion) + } + if usage.TotalTokens == 0 && (usage.InputTokens > 0 || usage.OutputTokens > 0) { + usage.TotalTokens = usage.InputTokens + usage.OutputTokens + usage.CachedInputTokens + usage.CacheCreationTokens + } + return usage, completion.String() +} + +// applyBedrockInvokeChunk decodes an InvokeModel stream "chunk" frame +// ({"bytes":""}) and folds the wrapped Anthropic event +// into usage/completion via the shared accumulator. +func applyBedrockInvokeChunk(payload []byte, usage *llm.Usage, completion *strings.Builder) { + var wrap struct { + Bytes []byte `json:"bytes"` // base64 string — encoding/json decodes it + } + if err := json.Unmarshal(payload, &wrap); err != nil || len(wrap.Bytes) == 0 { + return + } + var ev anthropicStreamEvent + if err := json.Unmarshal(wrap.Bytes, &ev); err != nil { + return + } + applyAnthropicStreamEvent(ev.Type, ev, usage, completion) +} + +// converseStreamEvent captures the Converse stream frames carrying completion +// text (contentBlockDelta) and the final token usage (metadata). +type converseStreamEvent struct { + Delta *struct { + Text string `json:"text"` + } `json:"delta"` + Usage *struct { + InputTokens int64 `json:"inputTokens"` + OutputTokens int64 `json:"outputTokens"` + TotalTokens int64 `json:"totalTokens"` + } `json:"usage"` +} + +// applyConverseStreamEvent folds one native Converse stream frame into the +// running usage/completion: contentBlockDelta carries assistant text, and the +// trailing metadata frame carries the final usage block. +func applyConverseStreamEvent(eventType string, payload []byte, usage *llm.Usage, completion *strings.Builder) { + var ev converseStreamEvent + if err := json.Unmarshal(payload, &ev); err != nil { + return + } + switch eventType { + case "contentBlockDelta": + if ev.Delta != nil { + completion.WriteString(ev.Delta.Text) + } + case "metadata": + if ev.Usage != nil { + if ev.Usage.InputTokens > 0 { + usage.InputTokens = ev.Usage.InputTokens + } + if ev.Usage.OutputTokens > 0 { + usage.OutputTokens = ev.Usage.OutputTokens + } + if ev.Usage.TotalTokens > 0 { + usage.TotalTokens = ev.Usage.TotalTokens + } + } + } +} diff --git a/proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock_test.go b/proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock_test.go new file mode 100644 index 000000000..f93505882 --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_response_parser/streaming_bedrock_test.go @@ -0,0 +1,74 @@ +package llm_response_parser + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" + "github.com/stretchr/testify/require" +) + +// bedrockFrame encodes a single AWS event-stream frame with the given +// :event-type header and JSON payload, mirroring what Bedrock sends. +func bedrockFrame(t *testing.T, eventType string, payload []byte) []byte { + t.Helper() + var buf bytes.Buffer + enc := eventstream.NewEncoder() + err := enc.Encode(&buf, eventstream.Message{ + Headers: eventstream.Headers{{Name: ":event-type", Value: eventstream.StringValue(eventType)}}, + Payload: payload, + }) + require.NoError(t, err, "encode event-stream frame") + return buf.Bytes() +} + +func mustJSON(t *testing.T, v any) []byte { + t.Helper() + b, err := json.Marshal(v) + require.NoError(t, err) + return b +} + +func TestAccumulateBedrockStream_Invoke(t *testing.T) { + // invoke-with-response-stream: each "chunk" frame wraps a base64-encoded + // Anthropic stream event under {"bytes": ...}. + events := [][]byte{ + mustJSON(t, map[string]any{"type": "message_start", "message": map[string]any{"usage": map[string]any{"input_tokens": 13}}}), + mustJSON(t, map[string]any{"type": "content_block_delta", "delta": map[string]any{"type": "text_delta", "text": "po"}}), + mustJSON(t, map[string]any{"type": "content_block_delta", "delta": map[string]any{"type": "text_delta", "text": "ng"}}), + mustJSON(t, map[string]any{"type": "message_delta", "usage": map[string]any{"output_tokens": 5}}), + } + var body bytes.Buffer + for _, ev := range events { + wrap := mustJSON(t, map[string]any{"bytes": base64.StdEncoding.EncodeToString(ev)}) + body.Write(bedrockFrame(t, "chunk", wrap)) + } + + usage, completion := accumulateBedrockStream(body.Bytes()) + require.Equal(t, int64(13), usage.InputTokens, "input tokens from message_start") + require.Equal(t, int64(5), usage.OutputTokens, "output tokens from message_delta") + require.Equal(t, int64(18), usage.TotalTokens, "total is additive") + require.Equal(t, "pong", completion, "text deltas concatenated") +} + +func TestAccumulateBedrockStream_Converse(t *testing.T) { + var body bytes.Buffer + body.Write(bedrockFrame(t, "contentBlockDelta", mustJSON(t, map[string]any{"delta": map[string]any{"text": "po"}}))) + body.Write(bedrockFrame(t, "contentBlockDelta", mustJSON(t, map[string]any{"delta": map[string]any{"text": "ng"}}))) + body.Write(bedrockFrame(t, "metadata", mustJSON(t, map[string]any{"usage": map[string]any{"inputTokens": 11, "outputTokens": 3, "totalTokens": 14}}))) + + usage, completion := accumulateBedrockStream(body.Bytes()) + require.Equal(t, int64(11), usage.InputTokens, "input tokens from metadata frame") + require.Equal(t, int64(3), usage.OutputTokens, "output tokens from metadata frame") + require.Equal(t, int64(14), usage.TotalTokens, "total from metadata frame") + require.Equal(t, "pong", completion, "converse text deltas concatenated") +} + +func TestAccumulateBedrockStream_Truncated(t *testing.T) { + // A body cut mid-frame must not panic; partial usage is returned. + full := bedrockFrame(t, "metadata", mustJSON(t, map[string]any{"usage": map[string]any{"inputTokens": 11, "outputTokens": 3}})) + usage, _ := accumulateBedrockStream(full[:len(full)-4]) + require.Zero(t, usage.OutputTokens, "truncated trailing frame is dropped, not panicked on") +} diff --git a/proxy/internal/middleware/builtin/llm_response_parser/streaming_test.go b/proxy/internal/middleware/builtin/llm_response_parser/streaming_test.go new file mode 100644 index 000000000..400aac0bd --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_response_parser/streaming_test.go @@ -0,0 +1,169 @@ +package llm_response_parser + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +func TestInvoke_OpenAIStreamingWithUsage(t *testing.T) { + m := newTestMiddleware(t) + body := loadFixture(t, "openai_stream.txt") + + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "text/event-stream"}}, + RespBody: body, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "openai"}, + {Key: middleware.KeyLLMModel, Value: "gpt-4o-mini"}, + }, + } + + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "Invoke must not error on streaming OpenAI body") + + in123, _ := metaValue(out.Metadata, middleware.KeyLLMInputTokens) + assert.Equal(t, "123", in123, "input tokens must come from final-chunk usage block") + + outTok, _ := metaValue(out.Metadata, middleware.KeyLLMOutputTokens) + assert.Equal(t, "45", outTok, "output tokens must come from final-chunk usage block") + + totTok, _ := metaValue(out.Metadata, middleware.KeyLLMTotalTokens) + assert.Equal(t, "168", totTok, "total tokens must come from final-chunk usage block") + + completion, ok := metaValue(out.Metadata, middleware.KeyLLMResponseCompletion) + require.True(t, ok, "completion must be emitted for streaming responses") + assert.Equal(t, "Hello, world!", completion, "deltas must concatenate into the buffered fixture's text") +} + +func TestInvoke_OpenAIStreamingWithoutUsage(t *testing.T) { + body := []byte(`data: {"choices":[{"delta":{"content":"Hi"}}]} + +data: {"choices":[{"delta":{"content":" there"}}]} + +data: [DONE] + +`) + + usage, completion := accumulateOpenAIStream(body) + assert.Equal(t, int64(0), usage.InputTokens, "input tokens must stay zero without a usage frame") + assert.Equal(t, int64(0), usage.OutputTokens, "output tokens must stay zero without a usage frame") + assert.Equal(t, int64(0), usage.TotalTokens, "total tokens must stay zero without a usage frame") + assert.Equal(t, "Hi there", completion, "deltas must still accumulate when usage is absent") +} + +func TestInvoke_OpenAIStreamingNoUsage_OmitsUsageMetadata(t *testing.T) { + m := newTestMiddleware(t) + body := []byte(`data: {"choices":[{"delta":{"content":"Hello"}}]} + +data: [DONE] + +`) + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "text/event-stream"}}, + RespBody: body, + Metadata: []middleware.KV{{Key: middleware.KeyLLMProvider, Value: "openai"}}, + } + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "Invoke must not error on usage-less streams") + + _, hasIn := metaValue(out.Metadata, middleware.KeyLLMInputTokens) + _, hasOut := metaValue(out.Metadata, middleware.KeyLLMOutputTokens) + _, hasTot := metaValue(out.Metadata, middleware.KeyLLMTotalTokens) + assert.False(t, hasIn, "input tokens omitted when no usage frame") + assert.False(t, hasOut, "output tokens omitted when no usage frame") + assert.False(t, hasTot, "total tokens omitted when no usage frame") + + completion, ok := metaValue(out.Metadata, middleware.KeyLLMResponseCompletion) + require.True(t, ok, "completion must still be emitted from deltas") + assert.Equal(t, "Hello", completion, "completion must come from delta accumulation") +} + +func TestInvoke_AnthropicStreaming(t *testing.T) { + m := newTestMiddleware(t) + body := loadFixture(t, "anthropic_stream.txt") + + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "text/event-stream"}}, + RespBody: body, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: "anthropic"}, + {Key: middleware.KeyLLMModel, Value: "claude-sonnet-4-5"}, + }, + } + + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "Invoke must not error on streaming Anthropic body") + + in123, _ := metaValue(out.Metadata, middleware.KeyLLMInputTokens) + assert.Equal(t, "123", in123, "input tokens must come from message_start usage") + + outTok, _ := metaValue(out.Metadata, middleware.KeyLLMOutputTokens) + assert.Equal(t, "45", outTok, "output tokens must come from message_delta usage") + + totTok, _ := metaValue(out.Metadata, middleware.KeyLLMTotalTokens) + assert.Equal(t, "168", totTok, "total tokens must be input+output for anthropic streaming") + + completion, ok := metaValue(out.Metadata, middleware.KeyLLMResponseCompletion) + require.True(t, ok, "completion must be emitted from text_delta accumulation") + assert.Equal(t, "Hello, world!", completion, "anthropic streaming text must accumulate across content_block_delta events") +} + +func TestInvoke_StreamingTruncatedBody_BestEffort(t *testing.T) { + m := newTestMiddleware(t) + full := loadFixture(t, "anthropic_stream.txt") + cut := len(full) / 2 + truncated := full[:cut] + + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "text/event-stream"}}, + RespBody: truncated, + RespBodyTruncated: true, + Metadata: []middleware.KV{{Key: middleware.KeyLLMProvider, Value: "anthropic"}}, + } + + require.NotPanics(t, func() { + _, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "truncated streaming body must not surface as error") + }, "Invoke must never panic on a truncated SSE body") +} + +func TestInvoke_StreamingEmptyBody(t *testing.T) { + m := newTestMiddleware(t) + in := &middleware.Input{ + Slot: middleware.SlotOnResponse, + Status: 200, + RespHeaders: []middleware.KV{{Key: "Content-Type", Value: "text/event-stream"}}, + RespBody: nil, + Metadata: []middleware.KV{{Key: middleware.KeyLLMProvider, Value: "openai"}}, + } + out, err := m.Invoke(context.Background(), in) + require.NoError(t, err, "empty SSE body must not surface as error") + assert.Empty(t, out.Metadata, "no metadata for empty SSE body") +} + +func TestAccumulateAnthropicStream_PartialUsage(t *testing.T) { + body := []byte(`event: message_start +data: {"type":"message_start","message":{"usage":{"input_tokens":10}}} + +event: content_block_delta +data: {"type":"content_block_delta","delta":{"type":"text_delta","text":"hi"}} + +`) + usage, completion := accumulateAnthropicStream(body) + assert.Equal(t, int64(10), usage.InputTokens, "partial input_tokens must survive truncated stream") + assert.Equal(t, int64(0), usage.OutputTokens, "output_tokens stays zero without message_delta") + assert.Equal(t, "hi", completion, "completion must come from observed text_delta events") +} diff --git a/proxy/internal/middleware/builtin/llm_router/factory.go b/proxy/internal/middleware/builtin/llm_router/factory.go new file mode 100644 index 000000000..3c3b607ac --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_router/factory.go @@ -0,0 +1,106 @@ +package llm_router + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" +) + +// ProviderRoute describes one upstream LLM provider the router can +// hand a request to. Models lists the model identifiers the provider +// claims; UpstreamScheme + UpstreamHost replace the synth target's +// placeholder URL on a match. UpstreamPath is the path component of +// the configured upstream URL — the router uses it to disambiguate +// providers that claim the same model: when more than one provider +// matches the model, the route whose UpstreamPath is a prefix of the +// incoming request path is preferred (longest match wins, empty path +// is the catchall). AuthHeaderName + AuthHeaderValue are the +// per-provider credential the router injects after stripping the +// vendor auth headers from the inbound request. +// +// AllowedGroupIDs is the union of source-group IDs across every +// enabled policy that authorises this provider. The router treats it +// as a hard filter: a route whose AllowedGroupIDs has no intersection +// with the caller's UserGroups is removed from the candidate list +// before the path-prefix tiebreak. A route with empty AllowedGroupIDs +// is unreachable; the synthesiser only emits policy-bound routes. +type ProviderRoute struct { + ID string `json:"id"` + // Vendor is the parser surface this provider speaks ("openai", + // "anthropic", …), matching the llm.provider value llm_request_parser + // emits from the request. When set, the router keeps a vendor-tagged + // request on a same-vendor route so catch-all gateways of a different + // vendor can't swallow it. Empty disables vendor filtering for this + // route. + Vendor string `json:"vendor,omitempty"` + Models []string `json:"models"` + UpstreamScheme string `json:"upstream_scheme"` + UpstreamHost string `json:"upstream_host"` + UpstreamPath string `json:"upstream_path,omitempty"` + AuthHeaderName string `json:"auth_header_name"` + AuthHeaderValue string `json:"auth_header_value"` + AllowedGroupIDs []string `json:"allowed_group_ids"` + // Vertex marks a Google Vertex AI provider. Vertex requests carry the + // model in the URL path, so the router selects this route by path + // (isVertexPath) and bypasses the model/vendor table entirely. + Vertex bool `json:"vertex,omitempty"` + // Bedrock marks an AWS Bedrock provider. Bedrock requests carry the model + // in the URL path (/model/{id}/{action}), so the router selects this route + // by path (isBedrockPath) and bypasses the model/vendor table; auth is the + // static AuthHeaderValue bearer token (no token minting). + Bedrock bool `json:"bedrock,omitempty"` + // GCPServiceAccountKeyB64 is a base64-encoded GCP service-account JSON + // key. When set, the router mints + refreshes a short-lived OAuth2 access + // token from it at request time and injects it as the auth header value + // (instead of the static AuthHeaderValue) — so the gateway holds a durable + // Vertex credential rather than a 1-hour token. + GCPServiceAccountKeyB64 string `json:"gcp_sa_key_b64,omitempty"` +} + +// Config is the on-wire configuration accepted by the factory. An +// empty Providers slice yields a router that denies every request as +// not-routable; the synthesiser is responsible for stamping the +// account's enabled providers into this slice. +type Config struct { + Providers []ProviderRoute `json:"providers"` +} + +// Factory builds llm_router instances from raw config bytes. +type Factory struct{} + +// ID returns the registry identifier. +func (Factory) ID() string { return ID } + +// New constructs a middleware instance. Empty, null, and {} configs +// yield a router with an empty Providers slice — every request denies +// with model_not_routable. Non-empty payloads must parse cleanly so +// misconfigurations surface at chain build time. +func (Factory) New(rawConfig []byte) (middleware.Middleware, error) { + cfg := Config{} + if !isEmptyJSON(rawConfig) { + if err := json.Unmarshal(rawConfig, &cfg); err != nil { + return nil, fmt.Errorf("decode config: %w", err) + } + } + return New(cfg), nil +} + +// isEmptyJSON reports whether the payload is whitespace, null, or an +// empty object/array. The caller skips Unmarshal in that case so the +// zero-value Config flows through unchanged. +func isEmptyJSON(raw []byte) bool { + trimmed := strings.TrimSpace(string(bytes.TrimSpace(raw))) + switch trimmed { + case "", "null", "{}", "[]": + return true + } + return false +} + +func init() { + builtin.Register(Factory{}) +} diff --git a/proxy/internal/middleware/builtin/llm_router/middleware.go b/proxy/internal/middleware/builtin/llm_router/middleware.go new file mode 100644 index 000000000..73cc59c95 --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_router/middleware.go @@ -0,0 +1,793 @@ +// Package llm_router implements the SlotOnRequest middleware that +// routes a request to an upstream LLM provider based on the model name +// emitted upstream by llm_request_parser. The router rewrites the +// request's outbound target (scheme + host), strips known LLM-vendor +// auth headers, and injects the per-provider auth header from the +// matched route. Unknown or unconfigured models deny with a 403 and +// the canonical llm_policy.model_not_routable code. +package llm_router + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "time" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +// gcpScope is the OAuth2 scope minted for Vertex AI service-account auth. +const gcpScope = "https://www.googleapis.com/auth/cloud-platform" + +// gcpTokenTimeout bounds each GCP token mint/refresh HTTP call so a slow or +// unreachable token endpoint can't block the request indefinitely. +const gcpTokenTimeout = 10 * time.Second + +// ID is the registry key for this middleware. +const ID = "llm_router" + +// Version is reported via Middleware.Version(). +const Version = "1.0.0" + +const ( + denyCodeNotRoutable = "llm_policy.model_not_routable" + denyReasonNotRoutable = "model_not_routable" + denyCodeNoAuthorisedRoute = "llm_policy.no_authorised_provider" + denyReasonNoAuthorisedRoute = "no_authorised_provider" + //nolint:gosec // deny code label, not a credential + denyCodeUpstreamAuth = "llm_policy.upstream_auth_failed" + denyCodeUnmeterable = "llm_policy.unmeterable_publisher" + denyReasonUnmeterable = "unmeterable_publisher" +) + +// strippedAuthHeaders is the closed list of vendor authentication +// credentials the router clears before injecting the provider-specific +// credential. Strictly auth headers — vendor-specific metadata +// (anthropic-version, openai-organization, openai-project, etc.) is +// NOT stripped because the client SDK sets those and the upstream +// requires them (e.g. Anthropic returns 400 without +// anthropic-version). Each entry is canonicalised by Go's +// http.Header.Del/Set, so listing the canonical shapes here is +// sufficient. +var strippedAuthHeaders = []string{ + "Authorization", // OpenAI, OpenAI-compatible, most vendors, Bedrock bearer + "Proxy-Authorization", // upstream proxy auth (defense-in-depth) + "x-api-key", // Anthropic + "api-key", // Azure OpenAI + "X-Amz-Date", // AWS SigV4 — strip client-supplied AWS signing material + "X-Amz-Security-Token", + "X-Amz-Content-Sha256", +} + +// Middleware routes requests to upstream LLM providers based on the +// llm.model metadata emitted by llm_request_parser. +type Middleware struct { + cfg Config + // tokenSrc caches one auto-refreshing OAuth2 TokenSource per GCP + // service-account key (keyed by a hash of the key material), so Vertex + // token minting happens once and refreshes are amortised across requests. + tokenMu sync.Mutex + tokenSrc map[string]oauth2.TokenSource +} + +// New constructs a Middleware with the supplied configuration. Empty +// or nil Providers slice yields a router that denies every request as +// not-routable. +func New(cfg Config) *Middleware { + return &Middleware{cfg: cfg, tokenSrc: map[string]oauth2.TokenSource{}} +} + +// ID returns the registry identifier. +func (m *Middleware) ID() string { return ID } + +// Version returns the implementation version. +func (m *Middleware) Version() string { return Version } + +// Slot reports the chain slot the middleware lives in. +func (m *Middleware) Slot() middleware.Slot { return middleware.SlotOnRequest } + +// AcceptedContentTypes returns nil because the router only consults +// the metadata emitted by llm_request_parser. +func (m *Middleware) AcceptedContentTypes() []string { return nil } + +// MetadataKeys is the closed set of metadata keys this middleware may +// emit. The accumulator drops anything outside this allowlist. +func (m *Middleware) MetadataKeys() []string { + return []string{ + middleware.KeyLLMResolvedProviderID, + middleware.KeyLLMAuthorisingGroups, + middleware.KeyLLMPolicyDecision, + middleware.KeyLLMPolicyReason, + } +} + +// MutationsSupported reports that the middleware emits header and +// upstream-rewrite mutations. +func (m *Middleware) MutationsSupported() bool { return true } + +// Close releases resources owned by the middleware. The router is +// stateless, so this is a no-op. +func (m *Middleware) Close() error { return nil } + +// matchOutcome captures why matchRoute returned what it did so the +// caller can distinguish "no provider knows this model" from "providers +// know it but none authorise this peer's groups". +type matchOutcome int + +const ( + matchOutcomeFound matchOutcome = iota + matchOutcomeUnknownModel + matchOutcomeUnauthorised +) + +// Invoke resolves the model to a provider authorised for the caller's +// groups, strips known vendor auth headers, and injects the route's +// auth header. Unknown models deny with model_not_routable; models +// known to a provider that no policy authorises for the caller deny +// with no_authorised_provider. +func (m *Middleware) Invoke(_ context.Context, in *middleware.Input) (*middleware.Output, error) { + // Vertex AI carries the model in the URL path, not the body, and is + // selected by path rather than by the model/vendor table. Route it before + // the model lookup so a model the parser extracted from the path can't be + // claimed by a same-vendor direct provider (e.g. claude-* on api.anthropic.com). + reqPath := requestPath(in.URL) + if isVertexPath(reqPath) { + model, _ := lookupMetadata(in.Metadata, middleware.KeyLLMModel) + // The request parser emits no llm.provider for a Vertex publisher it + // can't parse (e.g. google/gemini). Forwarding such a request would + // bypass token/budget metering, so deny it rather than serve it + // unmetered. + if vendor, _ := lookupMetadata(in.Metadata, middleware.KeyLLMProvider); vendor == "" { + return denyUnmeterable(), nil + } + route, outcome := m.matchVertex(reqPath, model, in.UserGroups) + switch outcome { + case matchOutcomeFound: + return m.allowWithRoute(route, in.UserGroups), nil + case matchOutcomeUnauthorised: + return denyNoAuthorisedRoute(model), nil + default: + return denyUnknownModel(model), nil + } + } + + // Bedrock likewise carries the model in the URL path (/model/{id}/{action}), + // optionally behind a "/bedrock" gateway-namespace prefix. Route it by path + // before the model lookup; when the prefix is present, strip it from the + // forwarded path so the real Bedrock endpoint receives its native path. + if isBedrockPath(reqPath) { + model, _ := lookupMetadata(in.Metadata, middleware.KeyLLMModel) + native, hadPrefix := splitBedrockNamespace(reqPath) + route, outcome := m.matchBedrock(native, model, in.UserGroups) + switch outcome { + case matchOutcomeFound: + out := m.allowWithRoute(route, in.UserGroups) + if hadPrefix && out.Mutations != nil && out.Mutations.RewriteUpstream != nil { + out.Mutations.RewriteUpstream.StripPathPrefix = bedrockNamespacePrefix + } + return out, nil + case matchOutcomeUnauthorised: + return denyNoAuthorisedRoute(model), nil + default: + return denyUnknownModel(model), nil + } + } + + model, ok := lookupMetadata(in.Metadata, middleware.KeyLLMModel) + if !ok || model == "" { + // Non-inference endpoints (model listing) carry no model but still + // need rewriting from the synth placeholder to a real upstream; + // clients such as Codex call GET /v1/models at startup to enumerate + // availability and read a 403 as "model unavailable". + route, outcome := m.matchModelless(requestPath(in.URL), in.UserGroups) + switch outcome { + case matchOutcomeFound: + return m.allowWithRoute(route, in.UserGroups), nil + case matchOutcomeUnauthorised: + // A recognised model-less endpoint exists but no provider + // authorises the caller — deny as an authorisation failure + // rather than masking it as a missing model. + return denyNoAuthorisedRoute(model), nil + default: + return denyMissingModel(), nil + } + } + + vendor, _ := lookupMetadata(in.Metadata, middleware.KeyLLMProvider) + route, outcome := m.matchRoute(model, vendor, requestPath(in.URL), in.UserGroups) + switch outcome { + case matchOutcomeFound: + return m.allowWithRoute(route, in.UserGroups), nil + case matchOutcomeUnauthorised: + return denyNoAuthorisedRoute(model), nil + default: + return denyUnknownModel(model), nil + } +} + +// matchRoute returns the ProviderRoute that should serve the given +// model + request path for a caller in the given user-groups. Selection +// is: +// +// 1. Filter the configured providers to those whose Models list +// contains the model. +// 2. Filter the model-matched candidates to those whose +// AllowedGroupIDs intersect the caller's UserGroups. A route with +// no AllowedGroupIDs is the catch-all: it stays in the list. If +// the model was known but no candidate is authorised for this +// peer, return matchOutcomeUnauthorised so the caller can emit +// the dedicated no_authorised_provider deny code. +// 3. Vendor precedence: when the request carries a detected vendor +// (llm.provider) and at least one candidate is the same vendor, +// drop the rest — a vendor-tagged request must never cross to +// another vendor's route (e.g. an Anthropic call landing on an +// OpenAI-compatible gateway that also claims the model). +// 4. Model precedence over path: a route that explicitly lists the +// model beats a catch-all (empty Models) gateway. +// 5. Disambiguate the survivors by URL path prefix: longest +// UpstreamPath that prefix-matches the request path wins; an empty +// UpstreamPath is the catchall. If none prefix-matches, fall back +// to declaration order so the model stays routable. +func (m *Middleware) matchRoute(model, vendor, reqPath string, userGroups []string) (ProviderRoute, matchOutcome) { + var modelMatched []ProviderRoute + for _, route := range m.cfg.Providers { + if routeClaimsModel(route, model) { + modelMatched = append(modelMatched, route) + } + } + if len(modelMatched) == 0 { + return ProviderRoute{}, matchOutcomeUnknownModel + } + + // Vendor pinning runs BEFORE the group filter so a request the parser + // tagged with a vendor can never cross to another vendor's route — not + // even an authorised one. Narrow to same-vendor routes when any + // model-matched route declares that vendor; setups with no vendor tag on + // any route fall through unchanged. After narrowing, if no same-vendor + // route authorises the caller, that's matchOutcomeUnauthorised (no + // cross-vendor fallback). + if vendor != "" { + if vendorMatched := matchingVendor(modelMatched, vendor); len(vendorMatched) > 0 { + modelMatched = vendorMatched + } + } + + var candidates []ProviderRoute + for _, route := range modelMatched { + if routeAuthorisesGroups(route, userGroups) { + candidates = append(candidates, route) + } + } + if len(candidates) == 0 { + return ProviderRoute{}, matchOutcomeUnauthorised + } + + // Model routing takes precedence over path. A route that explicitly + // lists the model must beat a catch-all (empty Models) gateway that + // claims every model — otherwise an Anthropic request can fall through + // to an OpenAI-compatible gateway declared earlier. Only when no + // candidate explicitly claims the model do the catch-alls compete, and + // the path-prefix tiebreak applies within whichever tier wins. + if explicit := explicitlyClaiming(candidates, model); len(explicit) > 0 { + candidates = explicit + } + if len(candidates) == 1 { + return candidates[0], matchOutcomeFound + } + + best := candidates[0] + bestLen := -1 + for _, c := range candidates { + if !pathPrefixMatches(c.UpstreamPath, reqPath) { + continue + } + if len(c.UpstreamPath) > bestLen { + best = c + bestLen = len(c.UpstreamPath) + } + } + return best, matchOutcomeFound +} + +// isModelLessPath reports whether reqPath is a known OpenAI-shaped +// non-inference endpoint that legitimately carries no model in its +// request (the model-listing endpoints). These must route to an upstream +// rather than deny, so model enumeration works end to end. +func isModelLessPath(reqPath string) bool { + return reqPath == "/v1/models" || strings.HasPrefix(reqPath, "/v1/models/") +} + +// isVertexPath reports whether reqPath is a Google Vertex AI publisher +// endpoint: /v1/projects/{project}/locations/{region}/publishers/{publisher}/ +// models/{model}:{action}. The model + vendor live in the path, so these +// requests are routed by path to the Vertex provider rather than by model. +func isVertexPath(reqPath string) bool { + return strings.HasPrefix(reqPath, "/v1/projects/") && + strings.Contains(reqPath, "/publishers/") && + strings.Contains(reqPath, "/models/") +} + +// bedrockNamespacePrefix is an optional gateway-namespace prefix some clients +// place before the native Bedrock path to disambiguate it from other providers +// that also use "/model/...". It is stripped before forwarding upstream. +const bedrockNamespacePrefix = "/bedrock" + +// splitBedrockNamespace removes an optional "/bedrock" namespace prefix, +// returning the native Bedrock path and whether the prefix was present. +func splitBedrockNamespace(reqPath string) (string, bool) { + if strings.HasPrefix(reqPath, bedrockNamespacePrefix+"/") { + return strings.TrimPrefix(reqPath, bedrockNamespacePrefix), true + } + return reqPath, false +} + +// isBedrockPath reports whether reqPath is an AWS Bedrock runtime model +// endpoint: /model/{modelId}/{action} where action is invoke, +// invoke-with-response-stream, converse, or converse-stream — optionally behind +// a "/bedrock" gateway-namespace prefix. The model lives in the path, so these +// requests are routed by path to the Bedrock provider. +func isBedrockPath(reqPath string) bool { + native, _ := splitBedrockNamespace(reqPath) + if !strings.HasPrefix(native, "/model/") { + return false + } + return strings.HasSuffix(native, "/invoke") || + strings.HasSuffix(native, "/invoke-with-response-stream") || + strings.HasSuffix(native, "/converse") || + strings.HasSuffix(native, "/converse-stream") +} + +// matchVertex selects the Vertex provider authorised for the caller's groups +// and claiming the requested model. +func (m *Middleware) matchVertex(reqPath, model string, userGroups []string) (ProviderRoute, matchOutcome) { + return m.matchPathRoute(reqPath, model, userGroups, func(r ProviderRoute) bool { return r.Vertex }) +} + +// matchBedrock selects the Bedrock provider authorised for the caller's groups +// and claiming the requested model. +func (m *Middleware) matchBedrock(reqPath, model string, userGroups []string) (ProviderRoute, matchOutcome) { + return m.matchPathRoute(reqPath, model, userGroups, func(r ProviderRoute) bool { return r.Bedrock }) +} + +// matchPathRoute selects a path-routed provider (Vertex/Bedrock). These carry +// the model in the URL, so the model/vendor table is bypassed — but the route's +// configured Models allowlist is still enforced (empty Models = catch-all) so a +// provider credential can't be used for models the operator didn't authorise. +// Returns matchOutcomeUnauthorised when no style route authorises the caller's +// groups, matchOutcomeUnknownModel when an authorised route exists but none +// claims the model (or no style route exists at all), else the chosen route +// (longest UpstreamPath prefix-match wins among multiple). +func (m *Middleware) matchPathRoute(reqPath, model string, userGroups []string, isStyle func(ProviderRoute) bool) (ProviderRoute, matchOutcome) { + var styled []ProviderRoute + for _, route := range m.cfg.Providers { + if isStyle(route) { + styled = append(styled, route) + } + } + if len(styled) == 0 { + return ProviderRoute{}, matchOutcomeUnknownModel + } + + var authorised []ProviderRoute + for _, route := range styled { + if routeAuthorisesGroups(route, userGroups) { + authorised = append(authorised, route) + } + } + if len(authorised) == 0 { + return ProviderRoute{}, matchOutcomeUnauthorised + } + + var candidates []ProviderRoute + for _, route := range authorised { + if routeClaimsModel(route, model) { + candidates = append(candidates, route) + } + } + if len(candidates) == 0 { + return ProviderRoute{}, matchOutcomeUnknownModel + } + if len(candidates) == 1 { + return candidates[0], matchOutcomeFound + } + + best := candidates[0] + bestLen := -1 + for _, c := range candidates { + if !pathPrefixMatches(c.UpstreamPath, reqPath) { + continue + } + if len(c.UpstreamPath) > bestLen { + best = c + bestLen = len(c.UpstreamPath) + } + } + return best, matchOutcomeFound +} + +// matchModelless selects a route for a non-inference, model-less request. +// It mirrors matchRoute's group-authorisation filter and path-prefix +// tiebreak but skips the per-model filter, since any provider the caller's +// groups authorise can serve a model-listing request. Returns +// matchOutcomeFound with the chosen route (single authorised provider wins +// outright; multiple fall to the longest UpstreamPath prefix-match, then +// declaration order), matchOutcomeUnauthorised when no provider authorises +// the caller, or matchOutcomeUnknownModel when the path isn't a recognised +// model-less endpoint. +func (m *Middleware) matchModelless(reqPath string, userGroups []string) (ProviderRoute, matchOutcome) { + if !isModelLessPath(reqPath) { + return ProviderRoute{}, matchOutcomeUnknownModel + } + var candidates []ProviderRoute + for _, route := range m.cfg.Providers { + // Vertex/Bedrock are path-routed and don't serve OpenAI-style + // model-listing endpoints; including them here could rewrite a + // GET /v1/models to an upstream that 404s it. + if route.Vertex || route.Bedrock { + continue + } + if routeAuthorisesGroups(route, userGroups) { + candidates = append(candidates, route) + } + } + if len(candidates) == 0 { + return ProviderRoute{}, matchOutcomeUnauthorised + } + if len(candidates) == 1 { + return candidates[0], matchOutcomeFound + } + + best := candidates[0] + bestLen := -1 + for _, c := range candidates { + if !pathPrefixMatches(c.UpstreamPath, reqPath) { + continue + } + if len(c.UpstreamPath) > bestLen { + best = c + bestLen = len(c.UpstreamPath) + } + } + return best, matchOutcomeFound +} + +// routeAuthorisesGroups reports whether the route's AllowedGroupIDs +// intersect the caller's userGroups. A route with empty AllowedGroupIDs +// is unreachable: the synthesiser only emits routes bound to at least +// one enabled policy, so an empty list signals a misconfiguration that +// must not be allowed to fall through. +func routeAuthorisesGroups(r ProviderRoute, userGroups []string) bool { + for _, ug := range userGroups { + for _, ag := range r.AllowedGroupIDs { + if ug == ag { + return true + } + } + } + return false +} + +// authorisingGroupsCSV returns the sorted, deduplicated comma-separated +// intersection of routeGroups and userGroups — i.e. the groups that +// actually authorise the resolved route for this caller. Returns the +// empty string when the intersection is empty (shouldn't happen on the +// allow path, but defensive). +func authorisingGroupsCSV(routeGroups, userGroups []string) string { + if len(routeGroups) == 0 || len(userGroups) == 0 { + return "" + } + allowed := make(map[string]struct{}, len(routeGroups)) + for _, g := range routeGroups { + allowed[g] = struct{}{} + } + seen := make(map[string]struct{}, len(userGroups)) + out := make([]string, 0, len(userGroups)) + for _, ug := range userGroups { + if _, ok := allowed[ug]; !ok { + continue + } + if _, dup := seen[ug]; dup { + continue + } + seen[ug] = struct{}{} + out = append(out, ug) + } + if len(out) == 0 { + return "" + } + sort.Strings(out) + return strings.Join(out, ",") +} + +// matchingVendor returns the subset of routes whose Vendor equals the +// request's detected vendor. Routes with an empty Vendor never match — an +// untagged route can't be asserted to speak the request's surface, so it +// stays out of the vendor-filtered set (but remains eligible via the +// fall-through when no route matches the vendor at all). +func matchingVendor(routes []ProviderRoute, vendor string) []ProviderRoute { + var out []ProviderRoute + for _, r := range routes { + if r.Vendor == vendor { + out = append(out, r) + } + } + return out +} + +// explicitlyClaiming returns the subset of routes whose Models list +// names the model exactly. Catch-all routes (empty Models) are excluded, +// so callers can prefer a provider that genuinely declares the model over +// a gateway that claims everything. +func explicitlyClaiming(routes []ProviderRoute, model string) []ProviderRoute { + var out []ProviderRoute + for _, r := range routes { + for _, candidate := range r.Models { + if candidate == model { + out = append(out, r) + break + } + } + } + return out +} + +// routeClaimsModel reports whether the route's Models list contains +// the given model identifier. An empty Models list is treated as +// "claim every model" — used by gateway-style providers (LiteLLM, +// custom OpenAI-compatible endpoints) that proxy an open-ended set of +// upstream models the operator can't enumerate in NetBird's provider +// config. +func routeClaimsModel(route ProviderRoute, model string) bool { + if len(route.Models) == 0 { + return true + } + for _, candidate := range route.Models { + if candidate == model { + return true + } + } + return false +} + +// pathPrefixMatches reports whether upstreamPath matches reqPath on a path- +// segment boundary: an exact match, or reqPath continuing after +// upstreamPath at a "/" separator. This avoids a sibling base like +// "/openai" spuriously matching "/openai-test". An empty (or "/") +// upstreamPath always matches (catchall). +func pathPrefixMatches(upstreamPath, reqPath string) bool { + if upstreamPath == "" || upstreamPath == "/" { + return true + } + upstreamPath = strings.TrimRight(upstreamPath, "/") + return reqPath == upstreamPath || strings.HasPrefix(reqPath, upstreamPath+"/") +} + +// requestPath extracts the path component from an Input.URL string +// (which is r.URL.String() — typically "/path?query"). Returns the +// raw input on parse failure so the prefix check can still operate on +// the unparsed value. +func requestPath(raw string) string { + if raw == "" { + return "" + } + parsed, err := url.Parse(raw) + if err != nil { + return raw + } + return parsed.Path +} + +// allowWithRoute builds the Output for a successful route match. The +// returned Mutations carry the upstream rewrite plus — riding on it — +// the StripHeaders list and the AuthHeader to inject. +// +// The strip + inject MUST go through UpstreamRewrite (not HeadersAdd / +// HeadersRemove) because the framework's mutation gate runs every +// header change through a denylist that blocks Authorization, +// Cookie, etc. — exactly the headers the router is replacing. The +// proxy's upstream-build path applies AuthHeader / StripHeaders +// directly, bypassing the denylist by virtue of being a trusted +// proxy operation rather than an arbitrary middleware mutation. +// +// Emits the authorising-groups intersection alongside the resolved +// provider id so identity-stamping middlewares (llm_identity_inject) +// tag the request with ONLY the groups that authorised this specific +// route — not every group the peer happens to be in. +func (m *Middleware) allowWithRoute(route ProviderRoute, userGroups []string) *middleware.Output { + rewrite := &middleware.UpstreamRewrite{ + Scheme: route.UpstreamScheme, + Host: route.UpstreamHost, + // UpstreamPath is the path component the operator pasted on + // the provider record (e.g. "/v1/{account}/{gateway}/compat" + // for Cloudflare AI Gateway). Carrying it on the rewrite so + // the proxy's URL composer joins it with the agent's request + // path — without this, the operator's configured upstream + // path is silently dropped and the gateway returns a 4xx for + // the malformed URL. Empty value leaves the original + // target's path untouched. + Path: route.UpstreamPath, + StripHeaders: append([]string(nil), strippedAuthHeaders...), + } + authValue := route.AuthHeaderValue + if route.GCPServiceAccountKeyB64 != "" { + // Mint a short-lived OAuth2 token from the service-account key at + // request time (cached + auto-refreshed) instead of a static value. + bearer, err := m.gcpBearer(route.GCPServiceAccountKeyB64) + if err != nil { + return denyUpstreamAuth() + } + authValue = bearer + } + if route.AuthHeaderName != "" && authValue != "" { + rewrite.AuthHeader = &middleware.AuthHeader{ + Name: route.AuthHeaderName, + Value: authValue, + } + } + return &middleware.Output{ + Decision: middleware.DecisionAllow, + Mutations: &middleware.Mutations{RewriteUpstream: rewrite}, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMResolvedProviderID, Value: route.ID}, + {Key: middleware.KeyLLMAuthorisingGroups, Value: authorisingGroupsCSV(route.AllowedGroupIDs, userGroups)}, + {Key: middleware.KeyLLMPolicyDecision, Value: "allow"}, + }, + } +} + +// gcpBearer returns a "Bearer " value minted from a base64-encoded GCP +// service-account key, using a cached, auto-refreshing token source. +func (m *Middleware) gcpBearer(saKeyB64 string) (string, error) { + ts, err := m.gcpTokenSource(saKeyB64) + if err != nil { + return "", err + } + tok, err := ts.Token() + if err != nil { + return "", fmt.Errorf("mint gcp token: %w", err) + } + return "Bearer " + tok.AccessToken, nil +} + +// gcpTokenSource returns the cached TokenSource for the given service-account +// key, building it (decode base64 → parse JSON → cloud-platform scope) on first +// use. The returned source caches the token and refreshes it before expiry. +func (m *Middleware) gcpTokenSource(saKeyB64 string) (oauth2.TokenSource, error) { + sum := sha256.Sum256([]byte(saKeyB64)) + key := hex.EncodeToString(sum[:]) + + m.tokenMu.Lock() + defer m.tokenMu.Unlock() + if m.tokenSrc == nil { + m.tokenSrc = map[string]oauth2.TokenSource{} + } + if ts, ok := m.tokenSrc[key]; ok { + return ts, nil + } + jsonKey, err := base64.StdEncoding.DecodeString(strings.TrimSpace(saKeyB64)) + if err != nil { + return nil, fmt.Errorf("decode gcp service-account key: %w", err) + } + conf, err := google.JWTConfigFromJSON(jsonKey, gcpScope) + if err != nil { + return nil, fmt.Errorf("parse gcp service-account key: %w", err) + } + // Bound mint/refresh with a timeout HTTP client so a slow token endpoint + // can't hang the request. The oauth2 library uses this client for the + // lifetime of the (auto-refreshing) source. + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Timeout: gcpTokenTimeout}) + ts := conf.TokenSource(ctx) + m.tokenSrc[key] = ts + return ts, nil +} + +// denyUpstreamAuth is returned when the router cannot obtain the upstream +// credential (e.g. a malformed service-account key or an unreachable token +// endpoint). It surfaces as a 502 — an upstream problem, not a policy denial. +func denyUpstreamAuth() *middleware.Output { + return &middleware.Output{ + Decision: middleware.DecisionDeny, + DenyStatus: 502, + DenyReason: &middleware.DenyReason{ + Code: denyCodeUpstreamAuth, + Message: "could not obtain upstream credential", + }, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMPolicyDecision, Value: "deny"}, + {Key: middleware.KeyLLMPolicyReason, Value: "upstream_auth_failed"}, + }, + } +} + +// denyUnmeterable returns the deny envelope for a path-routed request whose +// publisher has no parser surface, so its usage can't be metered. Serving it +// would bypass token/budget caps, so it is rejected with a 403. +func denyUnmeterable() *middleware.Output { + return &middleware.Output{ + Decision: middleware.DecisionDeny, + DenyStatus: 403, + DenyReason: &middleware.DenyReason{ + Code: denyCodeUnmeterable, + Message: "request publisher is not supported for metering", + }, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMPolicyDecision, Value: "deny"}, + {Key: middleware.KeyLLMPolicyReason, Value: denyReasonUnmeterable}, + }, + } +} + +// denyMissingModel returns the deny envelope for a request whose +// envelope has no llm.model metadata. +func denyMissingModel() *middleware.Output { + return &middleware.Output{ + Decision: middleware.DecisionDeny, + DenyStatus: 403, + DenyReason: &middleware.DenyReason{ + Code: denyCodeNotRoutable, + Message: "missing llm.model on request envelope", + }, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMPolicyDecision, Value: "deny"}, + {Key: middleware.KeyLLMPolicyReason, Value: denyReasonNotRoutable}, + }, + } +} + +// denyUnknownModel returns the deny envelope for a model that no +// configured provider claims. +func denyUnknownModel(model string) *middleware.Output { + return &middleware.Output{ + Decision: middleware.DecisionDeny, + DenyStatus: 403, + DenyReason: &middleware.DenyReason{ + Code: denyCodeNotRoutable, + Message: fmt.Sprintf("no provider configured for model %s", model), + Details: map[string]string{"model": model}, + }, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMPolicyDecision, Value: "deny"}, + {Key: middleware.KeyLLMPolicyReason, Value: denyReasonNotRoutable}, + }, + } +} + +// denyNoAuthorisedRoute returns the deny envelope for a model that one +// or more providers claim, but where no policy authorises the caller's +// groups for any of those providers. +func denyNoAuthorisedRoute(model string) *middleware.Output { + return &middleware.Output{ + Decision: middleware.DecisionDeny, + DenyStatus: 403, + DenyReason: &middleware.DenyReason{ + Code: denyCodeNoAuthorisedRoute, + Message: fmt.Sprintf("no policy authorises model %s for the caller's groups", model), + Details: map[string]string{"model": model}, + }, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMPolicyDecision, Value: "deny"}, + {Key: middleware.KeyLLMPolicyReason, Value: denyReasonNoAuthorisedRoute}, + }, + } +} + +// lookupMetadata returns the value for key plus a presence flag so +// callers can distinguish absent from empty. +func lookupMetadata(meta []middleware.KV, key string) (string, bool) { + for _, kv := range meta { + if kv.Key == key { + return kv.Value, true + } + } + return "", false +} diff --git a/proxy/internal/middleware/builtin/llm_router/middleware_test.go b/proxy/internal/middleware/builtin/llm_router/middleware_test.go new file mode 100644 index 000000000..8ae03c5ba --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_router/middleware_test.go @@ -0,0 +1,840 @@ +package llm_router + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +// metaValue returns the value for the first KV with the given key. +func metaValue(t *testing.T, kvs []middleware.KV, key string) (string, bool) { + t.Helper() + for _, kv := range kvs { + if kv.Key == key { + return kv.Value, true + } + } + return "", false +} + +// defaultTestGroup is the group id used by routes and inputs in tests +// that don't specifically exercise the group-filter logic. Pairing it +// with the same id on every test route keeps the legacy assertions +// focused on routing/path behaviour without each one having to bake in +// its own ACL. +const defaultTestGroup = "grp-test" + +// newInputWithModel returns an Input carrying llm.model in its metadata +// bag, mimicking the post-llm_request_parser state the router observes +// in production. UserGroups is populated with defaultTestGroup so the +// router's group-filter pass authorises any test route whose +// AllowedGroupIDs contains the same id. +func newInputWithModel(model string) *middleware.Input { + return &middleware.Input{ + Slot: middleware.SlotOnRequest, + Metadata: []middleware.KV{{Key: middleware.KeyLLMModel, Value: model}}, + UserGroups: []string{defaultTestGroup}, + } +} + +// newInputWithModelAndURL returns an Input carrying both llm.model and +// a request URL so router tests can exercise path-based disambiguation. +func newInputWithModelAndURL(model, reqURL string) *middleware.Input { + in := newInputWithModel(model) + in.URL = reqURL + return in +} + +func TestMiddlewareIdentity(t *testing.T) { + mw := New(Config{}) + assert.Equal(t, ID, mw.ID(), "middleware ID must be llm_router") + assert.Equal(t, Version, mw.Version(), "version must match the constant") + assert.Equal(t, middleware.SlotOnRequest, mw.Slot(), "router must run in SlotOnRequest") + assert.True(t, mw.MutationsSupported(), "router must declare mutations support") + assert.Nil(t, mw.AcceptedContentTypes(), "router does not inspect bodies") + assert.ElementsMatch(t, + []string{ + middleware.KeyLLMResolvedProviderID, + middleware.KeyLLMAuthorisingGroups, + middleware.KeyLLMPolicyDecision, + middleware.KeyLLMPolicyReason, + }, + mw.MetadataKeys(), + "metadata key allowlist must match the spec", + ) + require.NoError(t, mw.Close()) +} + +func TestRouter_HappyPath(t *testing.T) { + route := ProviderRoute{ + ID: "openai-prod", + Models: []string{"gpt-4o"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer sk-test-123", + } + mw := New(Config{Providers: []ProviderRoute{route}}) + + out, err := mw.Invoke(context.Background(), newInputWithModel("gpt-4o")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "matched model must allow") + + require.NotNil(t, out.Mutations, "matched route must emit mutations") + rewrite := out.Mutations.RewriteUpstream + require.NotNil(t, rewrite, "matched route must emit upstream rewrite") + assert.Equal(t, "https", rewrite.Scheme, "rewrite scheme must come from the matched route") + assert.Equal(t, "api.openai.com", rewrite.Host, "rewrite host must come from the matched route") + + assert.ElementsMatch(t, strippedAuthHeaders, rewrite.StripHeaders, + "strip list rides on UpstreamRewrite (bypasses framework denylist) and must cover every known vendor auth header") + require.NotNil(t, rewrite.AuthHeader, "router must inject the auth header via the rewrite (not HeadersAdd) so the proxy bypasses the denylist") + assert.Equal(t, "Authorization", rewrite.AuthHeader.Name, "injected header name must come from the route") + assert.Equal(t, "Bearer sk-test-123", rewrite.AuthHeader.Value, "injected header value must come from the route") + assert.Empty(t, out.Mutations.HeadersAdd, "router must not use HeadersAdd; auth flows through UpstreamRewrite.AuthHeader") + assert.Empty(t, out.Mutations.HeadersRemove, "router must not use HeadersRemove; strip flows through UpstreamRewrite.StripHeaders") + + resolved, ok := metaValue(t, out.Metadata, middleware.KeyLLMResolvedProviderID) + require.True(t, ok, "router must emit llm.resolved_provider_id on a match") + assert.Equal(t, "openai-prod", resolved, "resolved provider id must be the matched route's ID") + dec, _ := metaValue(t, out.Metadata, middleware.KeyLLMPolicyDecision) + assert.Equal(t, "allow", dec, "decision metadata must be allow on a match") +} + +func TestRouter_MissingModel(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{{ + ID: "openai-prod", + Models: []string{"gpt-4o"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + }}}) + + out, err := mw.Invoke(context.Background(), &middleware.Input{Slot: middleware.SlotOnRequest}) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionDeny, out.Decision, "missing llm.model must deny") + assert.Equal(t, 403, out.DenyStatus, "deny status must be 403") + require.NotNil(t, out.DenyReason, "deny reason must be populated") + assert.Equal(t, "llm_policy.model_not_routable", out.DenyReason.Code, "deny code must be model_not_routable") + assert.Equal(t, "missing llm.model on request envelope", out.DenyReason.Message, "deny message must match spec") + + dec, _ := metaValue(t, out.Metadata, middleware.KeyLLMPolicyDecision) + assert.Equal(t, "deny", dec, "decision metadata must be deny") + reason, _ := metaValue(t, out.Metadata, middleware.KeyLLMPolicyReason) + assert.Equal(t, "model_not_routable", reason, "reason metadata must be model_not_routable") +} + +// newModellessInput returns an Input with no llm.model and the given +// request path, mimicking a GET /v1/models call (which carries no body +// from which a model could be parsed). UserGroups matches defaultTestGroup. +func newModellessInput(reqURL string) *middleware.Input { + return &middleware.Input{ + Slot: middleware.SlotOnRequest, + URL: reqURL, + UserGroups: []string{defaultTestGroup}, + } +} + +func TestRouter_ModelLessPath_RoutesToAuthorisedProvider(t *testing.T) { + route := ProviderRoute{ + ID: "openai-prod", + Models: []string{"gpt-4o"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + } + mw := New(Config{Providers: []ProviderRoute{route}}) + + out, err := mw.Invoke(context.Background(), newModellessInput("/v1/models?client_version=1")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "GET /v1/models must pass through, not deny") + require.NotNil(t, out.Mutations, "a pass-through must rewrite the upstream") + require.NotNil(t, out.Mutations.RewriteUpstream, "model-less route must still rewrite to the real upstream") + assert.Equal(t, "api.openai.com", out.Mutations.RewriteUpstream.Host, "must target the authorised provider's host") + + provider, _ := metaValue(t, out.Metadata, middleware.KeyLLMResolvedProviderID) + assert.Equal(t, "openai-prod", provider, "resolved provider must be the authorised route") +} + +func TestRouter_ModelLessPath_MultiProviderDeclarationOrder(t *testing.T) { + first := ProviderRoute{ + ID: "openai-a", + Models: []string{"gpt-4o"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "a.example.com", + } + second := ProviderRoute{ + ID: "openai-b", + Models: []string{"gpt-4o-mini"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "b.example.com", + } + mw := New(Config{Providers: []ProviderRoute{first, second}}) + + out, err := mw.Invoke(context.Background(), newModellessInput("/v1/models")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "model-less path must pass through with multiple providers") + provider, _ := metaValue(t, out.Metadata, middleware.KeyLLMResolvedProviderID) + assert.Equal(t, "openai-a", provider, "no path-prefix match falls back to declaration order") +} + +func TestRouter_ModelLessPath_UnauthorisedDenies(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{{ + ID: "openai-prod", + Models: []string{"gpt-4o"}, + AllowedGroupIDs: []string{"some-other-group"}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + }}}) + + out, err := mw.Invoke(context.Background(), newModellessInput("/v1/models")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionDeny, out.Decision, "no provider authorising the caller must still deny") +} + +func TestRouter_NonModelLessBodilessStillDenies(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{{ + ID: "openai-prod", + Models: []string{"gpt-4o"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + }}}) + + // A bodiless POST to an inference path has no model and is NOT a + // model-less endpoint, so it must keep denying. + out, err := mw.Invoke(context.Background(), newModellessInput("/v1/responses")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionDeny, out.Decision, "bodiless inference request must still deny") + assert.Equal(t, "llm_policy.model_not_routable", out.DenyReason.Code, "deny code stays model_not_routable") +} + +// TestRouter_ExplicitModelBeatsCatchallGateway is the regression guard +// for multi-provider misrouting: a catch-all (empty Models) OpenAI-compat +// gateway declared first must NOT swallow a model an explicit provider +// claims. Anthropic's claude request must reach the Anthropic route even +// though the gateway claims every model and wins declaration order. +func TestRouter_ExplicitModelBeatsCatchallGateway(t *testing.T) { + gateway := ProviderRoute{ + ID: "openai-gateway", + Models: nil, // catch-all: claims every model + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + } + anthropic := ProviderRoute{ + ID: "anthropic-prod", + Models: []string{"claude-opus-4"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.anthropic.com", + } + // Gateway declared first to prove explicit claim beats declaration order. + mw := New(Config{Providers: []ProviderRoute{gateway, anthropic}}) + + out, err := mw.Invoke(context.Background(), newInputWithModelAndURL("claude-opus-4", "/v1/messages")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "explicit-model request must route, not deny") + require.NotNil(t, out.Mutations) + require.NotNil(t, out.Mutations.RewriteUpstream) + assert.Equal(t, "api.anthropic.com", out.Mutations.RewriteUpstream.Host, "claude must reach the explicit Anthropic route, not the catch-all gateway") + + provider, _ := metaValue(t, out.Metadata, middleware.KeyLLMResolvedProviderID) + assert.Equal(t, "anthropic-prod", provider, "resolved provider must be the explicit Anthropic route") +} + +// TestRouter_CatchallStillServesUnlistedModel confirms the catch-all +// gateway still wins models no explicit provider claims (its whole point). +func TestRouter_CatchallStillServesUnlistedModel(t *testing.T) { + gateway := ProviderRoute{ + ID: "openai-gateway", + Models: nil, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "gateway.example.com", + } + anthropic := ProviderRoute{ + ID: "anthropic-prod", + Models: []string{"claude-opus-4"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.anthropic.com", + } + mw := New(Config{Providers: []ProviderRoute{gateway, anthropic}}) + + out, err := mw.Invoke(context.Background(), newInputWithModelAndURL("some-exotic-model", "/v1/chat/completions")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "unlisted model must still route via the catch-all") + require.NotNil(t, out.Mutations.RewriteUpstream) + assert.Equal(t, "gateway.example.com", out.Mutations.RewriteUpstream.Host, "unlisted model falls to the catch-all gateway") +} + +// newInputVendorModelURL returns an Input carrying both the detected +// vendor (llm.provider) and the model, plus a request URL — mimicking the +// post-llm_request_parser state for a real inference call. +func newInputVendorModelURL(vendor, model, reqURL string) *middleware.Input { + return &middleware.Input{ + Slot: middleware.SlotOnRequest, + URL: reqURL, + Metadata: []middleware.KV{ + {Key: middleware.KeyLLMProvider, Value: vendor}, + {Key: middleware.KeyLLMModel, Value: model}, + }, + UserGroups: []string{defaultTestGroup}, + } +} + +// TestRouter_VendorKeepsAnthropicOffOpenAIGateway is the regression guard +// for the reported multi-provider break: two catch-all providers (neither +// enumerates models), the OpenAI one declared first. Without vendor +// awareness, a claude request matches both, no path prefixes, and +// declaration order sends it to OpenAI → 502. The detected vendor must +// pin it to the Anthropic route. +func TestRouter_VendorKeepsAnthropicOffOpenAIGateway(t *testing.T) { + openai := ProviderRoute{ + ID: "openai-gw", + Vendor: "openai", + Models: nil, // catch-all + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + } + anthropic := ProviderRoute{ + ID: "anthropic-gw", + Vendor: "anthropic", + Models: nil, // catch-all + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.anthropic.com", + } + mw := New(Config{Providers: []ProviderRoute{openai, anthropic}}) // openai first + + out, err := mw.Invoke(context.Background(), newInputVendorModelURL("anthropic", "claude-opus-4-8", "/v1/messages")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "claude request must route, not deny") + require.NotNil(t, out.Mutations.RewriteUpstream) + assert.Equal(t, "api.anthropic.com", out.Mutations.RewriteUpstream.Host, "anthropic vendor must pin to the anthropic route despite openai being declared first") + + provider, _ := metaValue(t, out.Metadata, middleware.KeyLLMResolvedProviderID) + assert.Equal(t, "anthropic-gw", provider) +} + +// TestRouter_VendorKeepsOpenAIOffAnthropic is the reciprocal: an OpenAI +// request must stay on the OpenAI route even when the Anthropic catch-all +// is declared first. +func TestRouter_VendorKeepsOpenAIOffAnthropic(t *testing.T) { + anthropic := ProviderRoute{ + ID: "anthropic-gw", + Vendor: "anthropic", + Models: nil, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.anthropic.com", + } + openai := ProviderRoute{ + ID: "openai-gw", + Vendor: "openai", + Models: nil, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + } + mw := New(Config{Providers: []ProviderRoute{anthropic, openai}}) // anthropic first + + out, err := mw.Invoke(context.Background(), newInputVendorModelURL("openai", "gpt-5.5", "/v1/responses")) + require.NoError(t, err) + require.NotNil(t, out) + require.NotNil(t, out.Mutations.RewriteUpstream) + assert.Equal(t, "api.openai.com", out.Mutations.RewriteUpstream.Host, "openai vendor must pin to the openai route despite anthropic being declared first") +} + +// TestRouter_VendorAbsentFallsBackToModelPath confirms vendor filtering is +// inert when the request carries no detected vendor: routing then relies on +// model/path as before. +func TestRouter_VendorAbsentFallsBackToModelPath(t *testing.T) { + openai := ProviderRoute{ + ID: "openai-gw", + Vendor: "openai", + Models: []string{"gpt-5.5"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + } + mw := New(Config{Providers: []ProviderRoute{openai}}) + + // No llm.provider in metadata — only the model. + out, err := mw.Invoke(context.Background(), newInputWithModel("gpt-5.5")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "explicit-model match must still route with no vendor present") +} + +func TestRouter_UnknownModel(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{{ + ID: "openai-prod", + Models: []string{"gpt-4o"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + }}}) + + out, err := mw.Invoke(context.Background(), newInputWithModel("claude-opus-4")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionDeny, out.Decision, "unrouted model must deny") + assert.Equal(t, 403, out.DenyStatus, "deny status must be 403") + require.NotNil(t, out.DenyReason, "deny reason must be populated") + assert.Equal(t, "llm_policy.model_not_routable", out.DenyReason.Code, "deny code must be model_not_routable") + assert.Equal(t, "no provider configured for model claude-opus-4", out.DenyReason.Message, "deny message must reference the offending model") + assert.Equal(t, "claude-opus-4", out.DenyReason.Details["model"], "deny details must include the offending model") +} + +func TestRouter_HeaderStripList(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{{ + ID: "openai-prod", + Models: []string{"gpt-4o"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer sk-test-123", + }}}) + + out, err := mw.Invoke(context.Background(), newInputWithModel("gpt-4o")) + require.NoError(t, err) + require.NotNil(t, out) + require.NotNil(t, out.Mutations, "matched route must emit mutations") + + expected := []string{ + "Authorization", + "Proxy-Authorization", + "x-api-key", + "api-key", + } + require.NotNil(t, out.Mutations.RewriteUpstream, "matched route must emit upstream rewrite") + for _, header := range expected { + assert.Contains(t, out.Mutations.RewriteUpstream.StripHeaders, header, + "strip list (on UpstreamRewrite) must include the well-known vendor auth header %s", header) + } + + // Vendor metadata headers MUST NOT be stripped: the client SDK sets them + // and the upstream requires them. Anthropic returns 400 "anthropic-version: + // header is required" if we drop it. Lock the regression. + preserved := []string{"anthropic-version", "openai-organization", "openai-project"} + for _, header := range preserved { + assert.NotContains(t, out.Mutations.RewriteUpstream.StripHeaders, header, + "vendor metadata header %s must NOT be stripped — upstreams require it", header) + } +} + +func TestRouter_FirstMatchWins(t *testing.T) { + first := ProviderRoute{ + ID: "first", + Models: []string{"gpt-4o"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "first.test", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer first", + } + second := ProviderRoute{ + ID: "second", + Models: []string{"gpt-4o"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "second.test", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer second", + } + mw := New(Config{Providers: []ProviderRoute{first, second}}) + + out, err := mw.Invoke(context.Background(), newInputWithModel("gpt-4o")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "duplicate-model match must still allow") + require.NotNil(t, out.Mutations, "matched route must emit mutations") + require.NotNil(t, out.Mutations.RewriteUpstream, "matched route must emit upstream rewrite") + assert.Equal(t, "first.test", out.Mutations.RewriteUpstream.Host, "first-match-wins must pick the earlier route") + + resolved, _ := metaValue(t, out.Metadata, middleware.KeyLLMResolvedProviderID) + assert.Equal(t, "first", resolved, "resolved provider id must be the earlier route's ID") +} + +// TestRouter_PathDisambiguation_PrefixWinsOverCatchall locks in the +// rule the user nailed down: two providers claim the same model, one +// has an UpstreamPath that prefixes the incoming URL, the other has +// no path. The path-prefixed provider wins because the path is a +// strictly more specific match than the empty catchall. +func TestRouter_PathDisambiguation_PrefixWinsOverCatchall(t *testing.T) { + corp := ProviderRoute{ + ID: "corp-openai-compat", + Models: []string{"gpt-5"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "corp.example.com", + UpstreamPath: "/openai", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer corp", + } + openai := ProviderRoute{ + ID: "openai", + Models: []string{"gpt-5"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer openai", + } + mw := New(Config{Providers: []ProviderRoute{openai, corp}}) // openai listed first to prove path beats declaration order + + out, err := mw.Invoke(context.Background(), newInputWithModelAndURL("gpt-5", "/openai/v1/chat/completions")) + require.NoError(t, err) + require.NotNil(t, out) + require.Equal(t, middleware.DecisionAllow, out.Decision, "path-prefix match must allow") + require.NotNil(t, out.Mutations.RewriteUpstream) + assert.Equal(t, "corp.example.com", out.Mutations.RewriteUpstream.Host, + "path-prefixed provider must beat the catchall when its UpstreamPath is a prefix of the request path") + resolved, _ := metaValue(t, out.Metadata, middleware.KeyLLMResolvedProviderID) + assert.Equal(t, "corp-openai-compat", resolved, "resolved provider id must reflect the path-prefix winner, not the first declared") +} + +// TestRouter_PathDisambiguation_CatchallWhenNoPrefixMatches is the +// inverse: the path-prefixed provider does NOT match the incoming +// path, so the empty-path catchall takes the request. +func TestRouter_PathDisambiguation_CatchallWhenNoPrefixMatches(t *testing.T) { + corp := ProviderRoute{ + ID: "corp-openai-compat", + Models: []string{"gpt-5"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "corp.example.com", + UpstreamPath: "/openai", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer corp", + } + openai := ProviderRoute{ + ID: "openai", + Models: []string{"gpt-5"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer openai", + } + mw := New(Config{Providers: []ProviderRoute{corp, openai}}) + + out, err := mw.Invoke(context.Background(), newInputWithModelAndURL("gpt-5", "/v1/chat/completions")) + require.NoError(t, err) + require.NotNil(t, out) + require.Equal(t, middleware.DecisionAllow, out.Decision, "catchall must allow when no path prefix matches") + require.NotNil(t, out.Mutations.RewriteUpstream) + assert.Equal(t, "api.openai.com", out.Mutations.RewriteUpstream.Host, + "empty-path catchall must win when the path-prefixed provider's UpstreamPath does not match the request") + resolved, _ := metaValue(t, out.Metadata, middleware.KeyLLMResolvedProviderID) + assert.Equal(t, "openai", resolved, "resolved provider id must be the catchall") +} + +// TestRouter_PathDisambiguation_LongestPrefixWins covers the case +// where multiple providers have non-empty UpstreamPath values that +// both prefix the request — the longer (more specific) one wins. +func TestRouter_PathDisambiguation_LongestPrefixWins(t *testing.T) { + short := ProviderRoute{ + ID: "short-prefix", + Models: []string{"gpt-5"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "short.example.com", + UpstreamPath: "/openai", + } + long := ProviderRoute{ + ID: "long-prefix", + Models: []string{"gpt-5"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "long.example.com", + UpstreamPath: "/openai/v1", + } + mw := New(Config{Providers: []ProviderRoute{short, long}}) + + out, err := mw.Invoke(context.Background(), newInputWithModelAndURL("gpt-5", "/openai/v1/chat/completions")) + require.NoError(t, err) + require.NotNil(t, out) + require.NotNil(t, out.Mutations.RewriteUpstream) + assert.Equal(t, "long.example.com", out.Mutations.RewriteUpstream.Host, + "longest matching UpstreamPath must win — most specific match") +} + +// TestRouter_SingleMatchIgnoresPath proves the path-prefix rule is a +// disambiguation pass, not a gate: when only one provider claims the +// model, it wins regardless of UpstreamPath. Otherwise a path-scoped +// provider would 403 every request whose URL doesn't include the +// path, which would break SDKs configured to hit the gateway root. +func TestRouter_SingleMatchIgnoresPath(t *testing.T) { + only := ProviderRoute{ + ID: "only", + Models: []string{"gpt-5"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "only.example.com", + UpstreamPath: "/openai", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer only", + } + mw := New(Config{Providers: []ProviderRoute{only}}) + + out, err := mw.Invoke(context.Background(), newInputWithModelAndURL("gpt-5", "/v1/chat/completions")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, + "single model-matching provider must serve the request even when UpstreamPath doesn't prefix the URL — path is a tiebreaker, not a gate") + require.NotNil(t, out.Mutations.RewriteUpstream) + assert.Equal(t, "only.example.com", out.Mutations.RewriteUpstream.Host, "the only model-matching provider should be selected") +} + +// TestRouter_PathDisambiguation_FallbackWhenNoPrefixMatches covers +// the multi-candidate edge case where every candidate has a +// non-matching non-empty UpstreamPath. The router falls back to +// declaration order so the model is still routable rather than 403'd. +func TestRouter_PathDisambiguation_FallbackWhenNoPrefixMatches(t *testing.T) { + first := ProviderRoute{ + ID: "first", + Models: []string{"gpt-5"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "first.example.com", + UpstreamPath: "/openai", + } + second := ProviderRoute{ + ID: "second", + Models: []string{"gpt-5"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "second.example.com", + UpstreamPath: "/anthropic", + } + mw := New(Config{Providers: []ProviderRoute{first, second}}) + + out, err := mw.Invoke(context.Background(), newInputWithModelAndURL("gpt-5", "/v1/chat/completions")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "no path match among multi-candidates must still allow") + require.NotNil(t, out.Mutations.RewriteUpstream) + assert.Equal(t, "first.example.com", out.Mutations.RewriteUpstream.Host, + "when no candidate's UpstreamPath prefix-matches the request, fall back to declaration order") +} + +func TestRouter_FactoryRejectsBadJSON(t *testing.T) { + _, err := Factory{}.New([]byte("{not json")) + require.Error(t, err, "malformed JSON config must be rejected at chain build time") +} + +func TestRouter_FactoryAcceptsEmptyShapes(t *testing.T) { + cases := [][]byte{nil, []byte(""), []byte(" "), []byte("null"), []byte("{}"), []byte("[]")} + for _, raw := range cases { + mw, err := Factory{}.New(raw) + require.NoError(t, err, "empty-shaped config must yield a router with an empty Providers slice") + require.NotNil(t, mw, "factory must return a non-nil middleware on empty config") + + out, invErr := mw.Invoke(context.Background(), newInputWithModel("gpt-4o")) + require.NoError(t, invErr) + assert.Equal(t, middleware.DecisionDeny, out.Decision, + "router with no providers must deny every model as not-routable") + } +} + +// newInputWithModelAndGroups returns an Input carrying llm.model + the +// caller's UserGroups, mimicking the post-auth, post-llm_request_parser +// state the router observes. +func newInputWithModelAndGroups(model string, groups []string) *middleware.Input { + in := newInputWithModel(model) + in.UserGroups = append([]string(nil), groups...) + return in +} + +// TestRouter_GroupFilter_PicksAuthorisedAmongDuplicates pins the Fix A +// behaviour: when two providers claim the same model but each +// authorises a different group, the router must pick the route the +// caller's groups intersect, regardless of declaration order. +func TestRouter_GroupFilter_PicksAuthorisedAmongDuplicates(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{ + { + ID: "openai-marketing", + Models: []string{"gpt-4o-mini"}, + UpstreamScheme: "https", + UpstreamHost: "mkt-openai.example.com", + AllowedGroupIDs: []string{"grp-mkt"}, + }, + { + ID: "openai-engineering", + Models: []string{"gpt-4o-mini"}, + UpstreamScheme: "https", + UpstreamHost: "eng-openai.example.com", + AllowedGroupIDs: []string{"grp-eng"}, + }, + }}) + + out, err := mw.Invoke(context.Background(), + newInputWithModelAndGroups("gpt-4o-mini", []string{"grp-eng"})) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, + "authorised candidate exists; must allow") + + resolved, ok := metaValue(t, out.Metadata, middleware.KeyLLMResolvedProviderID) + require.True(t, ok) + assert.Equal(t, "openai-engineering", resolved, + "router must pick the route whose AllowedGroupIDs intersects the caller's groups, ignoring declaration order") +} + +// TestRouter_GroupFilter_NoIntersection_DeniesNoAuthorisedRoute pins +// the dedicated deny code that fires when the model is known to a +// provider but no candidate is authorised for the caller's groups. +func TestRouter_GroupFilter_NoIntersection_DeniesNoAuthorisedRoute(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{{ + ID: "openai-marketing", + Models: []string{"gpt-4o-mini"}, + UpstreamScheme: "https", + UpstreamHost: "mkt-openai.example.com", + AllowedGroupIDs: []string{"grp-mkt"}, + }}}) + + out, err := mw.Invoke(context.Background(), + newInputWithModelAndGroups("gpt-4o-mini", []string{"grp-eng"})) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionDeny, out.Decision, + "model exists but no route authorises grp-eng; must deny") + require.NotNil(t, out.DenyReason) + assert.Equal(t, "llm_policy.no_authorised_provider", out.DenyReason.Code, + "deny code must be no_authorised_provider, not model_not_routable") + assert.Equal(t, "gpt-4o-mini", out.DenyReason.Details["model"], + "deny details must reference the offending model") + + dec, _ := metaValue(t, out.Metadata, middleware.KeyLLMPolicyDecision) + assert.Equal(t, "deny", dec) + reason, _ := metaValue(t, out.Metadata, middleware.KeyLLMPolicyReason) + assert.Equal(t, "no_authorised_provider", reason) +} + +// TestRouter_GroupFilter_EmptyAllowedGroupsIsUnreachable pins the +// strict semantics: a route with no AllowedGroupIDs is unreachable. +// The synthesiser only emits policy-bound routes, so an empty ACL +// signals a misconfiguration that must not silently fall through. +func TestRouter_GroupFilter_EmptyAllowedGroupsIsUnreachable(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{{ + ID: "openai-shared", + Models: []string{"gpt-4o"}, + UpstreamScheme: "https", + UpstreamHost: "api.openai.com", + // AllowedGroupIDs intentionally left empty. + }}}) + + out, err := mw.Invoke(context.Background(), newInputWithModel("gpt-4o")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionDeny, out.Decision, + "empty AllowedGroupIDs must deny — there is no catch-all for routes without an authorising policy") + require.NotNil(t, out.DenyReason) + assert.Equal(t, "llm_policy.no_authorised_provider", out.DenyReason.Code, + "empty ACL fails the group-filter pass; deny code must reflect that") +} + +// TestRouter_GroupFilter_OverlapTiebreakUnchanged pins that when more +// than one route is authorised for the caller's groups, the existing +// path-prefix tiebreak still decides. Group filtering is a hard gate +// before the tiebreak; it does not change the tiebreak semantics. +func TestRouter_GroupFilter_OverlapTiebreakUnchanged(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{ + { + ID: "openai-a", + Models: []string{"gpt-4o-mini"}, + UpstreamScheme: "https", + UpstreamHost: "a.example.com", + UpstreamPath: "", + AllowedGroupIDs: []string{"grp-eng"}, + }, + { + ID: "openai-b", + Models: []string{"gpt-4o-mini"}, + UpstreamScheme: "https", + UpstreamHost: "b.example.com", + UpstreamPath: "/v1/chat", + AllowedGroupIDs: []string{"grp-eng"}, + }, + }}) + + in := newInputWithModelAndURL("gpt-4o-mini", "/v1/chat/completions") + in.UserGroups = []string{"grp-eng"} + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision) + + resolved, _ := metaValue(t, out.Metadata, middleware.KeyLLMResolvedProviderID) + assert.Equal(t, "openai-b", resolved, + "longest-prefix path tiebreak still wins among group-authorised candidates") +} + +// TestRouter_AuthorisingGroups_EmitsIntersection pins that the router +// emits llm.authorising_groups containing only the intersection of the +// caller's UserGroups with the resolved route's AllowedGroupIDs — not +// every group the peer happens to be in. +func TestRouter_AuthorisingGroups_EmitsIntersection(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{{ + ID: "openai-eng", + Models: []string{"gpt-4o-mini"}, + UpstreamScheme: "https", + UpstreamHost: "eng-openai.example.com", + AllowedGroupIDs: []string{"grp-eng", "grp-shared"}, + }}}) + + in := newInputWithModelAndGroups("gpt-4o-mini", + []string{"grp-eng", "grp-it", "grp-shared", "grp-oncall"}) + + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + require.Equal(t, middleware.DecisionAllow, out.Decision) + + csv, ok := metaValue(t, out.Metadata, middleware.KeyLLMAuthorisingGroups) + require.True(t, ok, "router must emit llm.authorising_groups on a match") + assert.Equal(t, "grp-eng,grp-shared", csv, + "only groups in BOTH UserGroups AND AllowedGroupIDs may appear; result must be sorted and unique") +} + +// TestRouter_EmptyModelsClaimsAnyModel pins that a route with no +// configured Models matches every model — used by gateway-style +// providers (LiteLLM, custom OpenAI-compatible endpoints) where the +// operator can't enumerate the upstream's model catalog in NetBird. +func TestRouter_EmptyModelsClaimsAnyModel(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{{ + ID: "litellm", + Models: nil, // catch-all + UpstreamScheme: "https", + UpstreamHost: "litellm.example.com", + AllowedGroupIDs: []string{defaultTestGroup}, + }}}) + + out, err := mw.Invoke(context.Background(), newInputWithModel("gpt-5.5")) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, middleware.DecisionAllow, out.Decision, + "a route with empty Models must claim any model so gateway-style providers can route open-ended sets") + resolved, _ := metaValue(t, out.Metadata, middleware.KeyLLMResolvedProviderID) + assert.Equal(t, "litellm", resolved) +} diff --git a/proxy/internal/middleware/builtin/llm_router/path_routed_test.go b/proxy/internal/middleware/builtin/llm_router/path_routed_test.go new file mode 100644 index 000000000..7dbd6b936 --- /dev/null +++ b/proxy/internal/middleware/builtin/llm_router/path_routed_test.go @@ -0,0 +1,159 @@ +package llm_router + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/middleware" +) + +// pathRoutedInput builds an Input mimicking the post-llm_request_parser state +// for a path-routed (Vertex/Bedrock) request: a request URL plus the model and +// (optionally) provider/vendor metadata the parser emits. +func pathRoutedInput(url, provider, model string) *middleware.Input { + md := []middleware.KV{{Key: middleware.KeyLLMModel, Value: model}} + if provider != "" { + md = append(md, middleware.KV{Key: middleware.KeyLLMProvider, Value: provider}) + } + return &middleware.Input{ + Slot: middleware.SlotOnRequest, + URL: url, + Metadata: md, + UserGroups: []string{defaultTestGroup}, + } +} + +func vertexRoute() ProviderRoute { + return ProviderRoute{ + ID: "vertex-prod", Vertex: true, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "europe-west1-aiplatform.googleapis.com", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer x", + } +} + +// A Vertex publisher with no parser surface (google/gemini emits no +// llm.provider) must be denied, not forwarded unmetered. +func TestRouter_VertexUnmeterablePublisherDenied(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{vertexRoute()}}) + in := pathRoutedInput( + "/v1/projects/p/locations/global/publishers/google/models/gemini-2.5-pro:generateContent", + "", // google -> request parser emits NO llm.provider + "gemini-2.5-pro", + ) + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionDeny, out.Decision, "unmeterable Vertex publisher must deny") + assert.Equal(t, 403, out.DenyStatus, "unmeterable deny is a 403") + require.NotNil(t, out.DenyReason) + assert.Equal(t, denyCodeUnmeterable, out.DenyReason.Code, "deny code must flag the unmeterable publisher") +} + +// A Vertex publisher with a parser surface (anthropic) is allowed. +func TestRouter_VertexMeterablePublisherAllowed(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{vertexRoute()}}) + in := pathRoutedInput( + "/v1/projects/p/locations/global/publishers/anthropic/models/claude-sonnet-4-5:rawPredict", + "anthropic", + "claude-sonnet-4-5", + ) + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "meterable Vertex publisher must allow") +} + +// A path-routed provider with an explicit Models list must reject models not in +// the list (the provider credential can't be used for unauthorised models). +func TestRouter_PathRoutedModelAllowlistEnforced(t *testing.T) { + route := ProviderRoute{ + ID: "bedrock-prod", Bedrock: true, + Models: []string{"anthropic.claude-sonnet-4-5"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "bedrock-runtime.eu-central-1.amazonaws.com", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer x", + } + mw := New(Config{Providers: []ProviderRoute{route}}) + + allowed := pathRoutedInput( + "/model/eu.anthropic.claude-sonnet-4-5-20250929-v1:0/invoke", + "bedrock", "anthropic.claude-sonnet-4-5", + ) + out, err := mw.Invoke(context.Background(), allowed) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "model in the allowlist must be served") + + denied := pathRoutedInput( + "/model/amazon.nova-pro-v1:0/invoke", + "bedrock", "amazon.nova-pro", + ) + out, err = mw.Invoke(context.Background(), denied) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionDeny, out.Decision, "model outside the allowlist must deny") + require.NotNil(t, out.DenyReason) + assert.Equal(t, denyCodeNotRoutable, out.DenyReason.Code, "unlisted model denies as not-routable") +} + +// A "/bedrock" gateway-namespace prefix routes the same as the native path and +// records the prefix on the rewrite so the proxy strips it before forwarding. +func TestRouter_BedrockNamespacePrefixStripped(t *testing.T) { + route := ProviderRoute{ + ID: "bedrock-prod", Bedrock: true, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "bedrock-runtime.eu-central-1.amazonaws.com", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer x", + } + mw := New(Config{Providers: []ProviderRoute{route}}) + + prefixed := pathRoutedInput( + "/bedrock/model/eu.anthropic.claude-sonnet-4-5-20250929-v1:0/invoke-with-response-stream", + "bedrock", "anthropic.claude-sonnet-4-5", + ) + out, err := mw.Invoke(context.Background(), prefixed) + require.NoError(t, err) + require.Equal(t, middleware.DecisionAllow, out.Decision, "prefixed Bedrock path must route") + require.NotNil(t, out.Mutations) + require.NotNil(t, out.Mutations.RewriteUpstream) + assert.Equal(t, "/bedrock", out.Mutations.RewriteUpstream.StripPathPrefix, + "namespace prefix must be recorded so the proxy strips it before forwarding") + + native := pathRoutedInput( + "/model/eu.anthropic.claude-sonnet-4-5-20250929-v1:0/invoke", + "bedrock", "anthropic.claude-sonnet-4-5", + ) + out, err = mw.Invoke(context.Background(), native) + require.NoError(t, err) + require.Equal(t, middleware.DecisionAllow, out.Decision, "native Bedrock path must route") + require.NotNil(t, out.Mutations.RewriteUpstream) + assert.Empty(t, out.Mutations.RewriteUpstream.StripPathPrefix, + "native path carries no namespace prefix to strip") +} + +// A path-routed provider with no configured Models is catch-all: any model the +// credential can reach is served (preserves the zero-config behaviour). +func TestRouter_PathRoutedCatchAllServesAnyModel(t *testing.T) { + route := ProviderRoute{ + ID: "bedrock-catchall", Bedrock: true, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "bedrock-runtime.eu-central-1.amazonaws.com", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer x", + } + mw := New(Config{Providers: []ProviderRoute{route}}) + in := pathRoutedInput( + "/model/amazon.nova-pro-v1:0/invoke", + "bedrock", "amazon.nova-pro", + ) + out, err := mw.Invoke(context.Background(), in) + require.NoError(t, err) + assert.Equal(t, middleware.DecisionAllow, out.Decision, "catch-all path-routed provider serves any model") +} diff --git a/proxy/internal/middleware/chain.go b/proxy/internal/middleware/chain.go new file mode 100644 index 000000000..45d32cdb0 --- /dev/null +++ b/proxy/internal/middleware/chain.go @@ -0,0 +1,320 @@ +package middleware + +import ( + "context" + "net/http" + "sync" +) + +// boundMiddleware pairs a validated spec with the resolved middleware +// instance the chain will invoke. +type boundMiddleware struct { + spec Spec + mw Middleware +} + +// Chain is the ordered set of middlewares that run for a specific +// target. Chains are immutable once built; Manager produces a new +// Chain on every Rebuild. +// +// Ordering: middlewares are kept in registration order. RunRequest +// iterates the SlotOnRequest middlewares in order; RunResponse +// iterates the SlotOnResponse middlewares in reverse order +// (middleware-style LIFO so the last to see the request is the first +// to see the response); RunTerminal iterates the SlotTerminal +// middlewares in registration order, after every on_response slot has +// emitted, so the metadata bag they observe is complete. +// +// Close drains in-flight invocations and tears down each middleware. +// Callers swapping a chain via Manager invoke Close on the old chain +// after the swap so live requests finish on the previous instance. +type Chain struct { + targetID string + all []boundMiddleware + onRequest []int + onResponse []int + terminal []int + dispatcher *Dispatcher + inflight sync.WaitGroup +} + +// NewChain assembles a Chain from the bound middlewares. The slice +// order is the registration order; the chain captures index slices +// per slot so iteration does not re-scan the slot field per call. +func NewChain(targetID string, bound []boundMiddleware, d *Dispatcher) *Chain { + c := &Chain{ + targetID: targetID, + all: bound, + dispatcher: d, + } + for i, bm := range bound { + switch bm.spec.Slot { + case SlotOnRequest: + c.onRequest = append(c.onRequest, i) + case SlotOnResponse: + c.onResponse = append(c.onResponse, i) + case SlotTerminal: + c.terminal = append(c.terminal, i) + } + } + return c +} + +// Close waits for outstanding invocations against this chain to +// finish (bounded by ctx) and releases the middleware instances bound +// to it. Safe to call once the chain has been removed from the +// routing snapshot. Subsequent Run* calls are still safe (return +// without invoking) but Close itself is one-shot. +func (c *Chain) Close(ctx context.Context) error { + if c == nil { + return nil + } + if ctx == nil { + ctx = context.Background() + } + done := make(chan struct{}) + go func() { + c.inflight.Wait() + close(done) + }() + select { + case <-done: + case <-ctx.Done(): + // Drain timed out: requests may still be running against these + // middleware instances, so tearing them down now risks a + // use-after-close. Leave them (a bounded leak) and surface the + // timeout; the runaway backstop in the Manager already alerts. + return ctx.Err() + } + for _, bm := range c.all { + if bm.mw == nil { + continue + } + if err := bm.mw.Close(); err != nil { + c.dispatcher.logger.Debugf("middleware %s close: %v", bm.spec.ID, err) + } + } + return nil +} + +// Empty reports whether the chain has no middlewares. +func (c *Chain) Empty() bool { + return c == nil || len(c.all) == 0 +} + +// TargetID returns the key used to find this chain. +func (c *Chain) TargetID() string { + if c == nil { + return "" + } + return c.targetID +} + +// IDs returns the ordered list of middleware IDs bound to this chain. +func (c *Chain) IDs() []string { + if c == nil { + return nil + } + out := make([]string, len(c.all)) + for i, bm := range c.all { + out[i] = bm.spec.ID + } + return out +} + +// RunRequest iterates the on_request slot in registration order. Deny +// short-circuits the remaining middlewares and returns the deny +// output. The caller owns applying mutations to the real request and +// merging the metadata returned in `merged` into the captured-data +// bag passed to subsequent slots. +// +// Each middleware sees the metadata emitted by earlier middlewares in +// the same slot — this is how llm_guardrail reads +// llm.request_prompt_raw from llm_request_parser without a side +// channel, and how cost_meter reads tokens emitted by +// llm_response_parser on the response leg. +// +// If any middleware emits a non-nil Mutations.RewriteUpstream while +// satisfying the mutation gates (CanMutate && MutationsSupported), the +// latest such value is returned to the caller. Last-write-wins so the +// last middleware in the slot can override an earlier rewrite. +func (c *Chain) RunRequest(ctx context.Context, r *http.Request, in *Input, acc *Accumulator) (denied *Output, merged []KV, rewrite *UpstreamRewrite, err error) { + if c.Empty() || len(c.onRequest) == 0 { + return nil, nil, nil, nil + } + c.inflight.Add(1) + defer c.inflight.Done() + running := append([]KV(nil), in.Metadata...) + for _, idx := range c.onRequest { + bm := c.all[idx] + call := cloneInputFor(in, SlotOnRequest) + call.Metadata = append([]KV(nil), running...) + out, invErr := c.dispatcher.Invoke(ctx, bm.spec, bm.mw, call) + if invErr != nil && out == nil { + continue + } + if out == nil { + continue + } + + accepted, rejected := acc.Emit(bm.spec.ID, bm.spec.MetadataKeys, out.Metadata) + for _, rej := range rejected { + c.dispatcher.metrics.IncMetadataRejected(ctx, bm.spec.ID, rej.Reason) + } + merged = append(merged, accepted...) + running = append(running, accepted...) + + if out.Decision == DecisionDeny { + c.dispatcher.metrics.IncRequest(ctx, bm.spec.ID, c.targetID, "deny") + return out, merged, rewrite, nil + } + c.dispatcher.metrics.IncRequest(ctx, bm.spec.ID, c.targetID, "allow") + + if rw := mutationRewrite(bm.spec, out.Mutations); rw != nil { + rewrite = rw + } + if r != nil && bm.spec.CanMutate && out.Mutations != nil { + applyMutations(ctx, c.dispatcher, bm.spec, r, out.Mutations) + } + } + return nil, merged, rewrite, nil +} + +// RunResponse iterates the on_response slot in reverse registration +// order, matching the middleware "last in, first out" convention so +// the last middleware to see the request is the first to see the +// response. Middlewares cannot deny; they emit metadata. +// +// As with RunRequest, each middleware sees the metadata emitted by +// earlier middlewares in this slot — accumulated in the order the +// middlewares run (LIFO of registration). cost_meter relies on this +// to read llm.input_tokens / llm.output_tokens that +// llm_response_parser emitted just before it. +func (c *Chain) RunResponse(ctx context.Context, in *Input, acc *Accumulator) (merged []KV) { + if c.Empty() || len(c.onResponse) == 0 { + return nil + } + c.inflight.Add(1) + defer c.inflight.Done() + running := append([]KV(nil), in.Metadata...) + for i := len(c.onResponse) - 1; i >= 0; i-- { + bm := c.all[c.onResponse[i]] + call := cloneInputFor(in, SlotOnResponse) + call.Metadata = append([]KV(nil), running...) + out, _ := c.dispatcher.Invoke(ctx, bm.spec, bm.mw, call) + if out == nil { + continue + } + accepted, rejected := acc.Emit(bm.spec.ID, bm.spec.MetadataKeys, out.Metadata) + for _, rej := range rejected { + c.dispatcher.metrics.IncMetadataRejected(ctx, bm.spec.ID, rej.Reason) + } + merged = append(merged, accepted...) + running = append(running, accepted...) + c.dispatcher.metrics.IncRequest(ctx, bm.spec.ID, c.targetID, "passthrough") + } + return merged +} + +// RunTerminal iterates the terminal slot in registration order, after +// every on_response middleware has emitted. Terminal middlewares +// observe the full metadata bag carried in `in.Metadata` plus any +// emissions from terminal middlewares that ran before them; they +// cannot deny and cannot mutate. +func (c *Chain) RunTerminal(ctx context.Context, in *Input, acc *Accumulator) (merged []KV) { + if c.Empty() || len(c.terminal) == 0 { + return nil + } + c.inflight.Add(1) + defer c.inflight.Done() + running := append([]KV(nil), in.Metadata...) + for _, idx := range c.terminal { + bm := c.all[idx] + call := cloneInputFor(in, SlotTerminal) + call.Metadata = append([]KV(nil), running...) + out, _ := c.dispatcher.Invoke(ctx, bm.spec, bm.mw, call) + if out == nil { + continue + } + accepted, rejected := acc.Emit(bm.spec.ID, bm.spec.MetadataKeys, out.Metadata) + for _, rej := range rejected { + c.dispatcher.metrics.IncMetadataRejected(ctx, bm.spec.ID, rej.Reason) + } + merged = append(merged, accepted...) + running = append(running, accepted...) + c.dispatcher.metrics.IncRequest(ctx, bm.spec.ID, c.targetID, "terminal") + } + return merged +} + +// mutationRewrite returns the upstream rewrite carried in m when the +// spec's mutation gates allow it. The rewrite itself is not applied +// here; the caller (reverse proxy) decides whether to honour it. +func mutationRewrite(spec Spec, m *Mutations) *UpstreamRewrite { + if m == nil || m.RewriteUpstream == nil { + return nil + } + if !spec.CanMutate || !spec.MutationsSupported { + return nil + } + return m.RewriteUpstream +} + +func applyMutations(ctx context.Context, d *Dispatcher, spec Spec, r *http.Request, m *Mutations) { + if m == nil { + return + } + add, remove, blocked := FilterHeaderMutations(m) + for _, h := range blocked { + d.metrics.IncHeaderMutationBlocked(ctx, spec.ID, h) + } + for _, name := range remove { + r.Header.Del(name) + } + for _, kv := range add { + r.Header.Add(kv.Key, kv.Value) + } + if len(m.BodyReplace) == 0 { + return + } + if err := ValidateBodyReplace(r, m.BodyReplace, true); err != nil { + d.logger.Warnf("middleware %s body replace rejected: %v", spec.ID, err) + return + } + ApplyBodyReplace(r, m.BodyReplace) +} + +// cloneInputFor deep-copies the mutation-prone fields of Input so +// each middleware receives an isolated view. +func cloneInputFor(in *Input, slot Slot) *Input { + if in == nil { + return nil + } + out := *in + out.Slot = slot + out.Headers = cloneKVs(in.Headers) + out.RespHeaders = cloneKVs(in.RespHeaders) + out.Metadata = cloneKVs(in.Metadata) + if len(in.UserGroups) > 0 { + out.UserGroups = append([]string(nil), in.UserGroups...) + } + if len(in.UserGroupNames) > 0 { + out.UserGroupNames = append([]string(nil), in.UserGroupNames...) + } + if len(in.Body) > 0 { + out.Body = append([]byte(nil), in.Body...) + } + if len(in.RespBody) > 0 { + out.RespBody = append([]byte(nil), in.RespBody...) + } + return &out +} + +func cloneKVs(in []KV) []KV { + if len(in) == 0 { + return nil + } + out := make([]KV, len(in)) + copy(out, in) + return out +} diff --git a/proxy/internal/middleware/chain_test.go b/proxy/internal/middleware/chain_test.go new file mode 100644 index 000000000..929ccee08 --- /dev/null +++ b/proxy/internal/middleware/chain_test.go @@ -0,0 +1,370 @@ +package middleware + +import ( + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeMiddleware is a minimal Middleware for chain composition tests. +// It records the metadata the dispatcher hands to it and emits a +// caller-supplied Output. Tests use the recorded snapshot to assert +// that earlier-in-slot emissions are visible to later middlewares. +type fakeMiddleware struct { + id string + slot Slot + keys []string + emit []KV + decision Decision + mutationsSupported bool + canMutate bool + mutations *Mutations + + // seen captures the in.Metadata snapshot the dispatcher passed to + // Invoke, so tests can assert ordering and visibility. + seen []KV +} + +func (f *fakeMiddleware) ID() string { return f.id } +func (f *fakeMiddleware) Version() string { return "test" } +func (f *fakeMiddleware) Slot() Slot { return f.slot } +func (f *fakeMiddleware) AcceptedContentTypes() []string { return nil } +func (f *fakeMiddleware) MetadataKeys() []string { return f.keys } +func (f *fakeMiddleware) MutationsSupported() bool { return f.mutationsSupported } +func (f *fakeMiddleware) Close() error { return nil } + +func (f *fakeMiddleware) Invoke(_ context.Context, in *Input) (*Output, error) { + f.seen = append([]KV(nil), in.Metadata...) + out := &Output{Decision: f.decision, Metadata: append([]KV(nil), f.emit...)} + if f.mutations != nil { + m := *f.mutations + out.Mutations = &m + } + return out, nil +} + +// chainFor builds a Chain over the given middlewares with a noop +// dispatcher. +func chainFor(t *testing.T, mws ...*fakeMiddleware) *Chain { + t.Helper() + bound := make([]boundMiddleware, len(mws)) + for i, mw := range mws { + bound[i] = boundMiddleware{ + spec: Spec{ + ID: mw.id, + Slot: mw.slot, + Enabled: true, + MetadataKeys: mw.keys, + CanMutate: mw.canMutate, + MutationsSupported: mw.mutationsSupported, + }, + mw: mw, + } + } + disp := NewDispatcher(nil, nil) + return NewChain("t-1", bound, disp) +} + +// TestChain_RunRequest_ThreadsMetadataAcrossMiddlewares locks that +// each on_request middleware sees metadata emitted by earlier +// middlewares in the same slot. Regression cover for the original +// chain.go where every iteration cloned from the same source `in` and +// later middlewares (e.g. llm_guardrail) couldn't read what the first +// (e.g. llm_request_parser) had just emitted. +func TestChain_RunRequest_ThreadsMetadataAcrossMiddlewares(t *testing.T) { + first := &fakeMiddleware{ + id: "first", + slot: SlotOnRequest, + keys: []string{"foo.k"}, + emit: []KV{{Key: "foo.k", Value: "v"}}, + } + second := &fakeMiddleware{ + id: "second", + slot: SlotOnRequest, + keys: []string{"bar.k"}, + emit: []KV{{Key: "bar.k", Value: "z"}}, + } + c := chainFor(t, first, second) + acc := NewAccumulator(0) + + denied, merged, rewrite, err := c.RunRequest(context.Background(), nil, &Input{}, acc) + require.NoError(t, err) + assert.Nil(t, denied, "no deny without DecisionDeny") + assert.Nil(t, rewrite, "no rewrite without Mutations.RewriteUpstream") + + require.Len(t, second.seen, 1, "the second middleware must observe one prior emission") + assert.Equal(t, "foo.k", second.seen[0].Key, "second middleware must see the first middleware's key") + assert.Equal(t, "v", second.seen[0].Value, "second middleware must see the first middleware's value") + + require.Len(t, merged, 2, "merged slice contains both middleware emissions") +} + +// TestChain_RunResponse_ThreadsMetadataAcrossMiddlewares does the +// same for the response slot. The response slot iterates in reverse +// registration order, so the middleware registered LAST runs first. +// This test asserts that a middleware running later (in reverse +// order) sees the metadata emitted by the one that ran before it. +func TestChain_RunResponse_ThreadsMetadataAcrossMiddlewares(t *testing.T) { + // Registration order: [outer, inner]. + // Reverse iteration runs inner first, outer second. + // outer must see inner's emission. + outer := &fakeMiddleware{ + id: "outer", + slot: SlotOnResponse, + keys: []string{"outer.k"}, + emit: []KV{{Key: "outer.k", Value: "o"}}, + } + inner := &fakeMiddleware{ + id: "inner", + slot: SlotOnResponse, + keys: []string{"inner.k"}, + emit: []KV{{Key: "inner.k", Value: "i"}}, + } + c := chainFor(t, outer, inner) + acc := NewAccumulator(0) + + merged := c.RunResponse(context.Background(), &Input{}, acc) + + require.Len(t, outer.seen, 1, "outer must observe inner's emission") + assert.Equal(t, "inner.k", outer.seen[0].Key) + require.Len(t, merged, 2, "merged slice contains both response emissions") +} + +// TestChain_RunResponse_CostMeterScenario simulates the synth-service +// chain shape (response_parser registered AFTER cost_meter so reverse +// iter runs response_parser first). The cost_meter analogue must see +// the tokens response_parser just emitted — this is the exact +// regression that produced cost.skipped=missing_tokens in the live +// access logs. +func TestChain_RunResponse_CostMeterScenario(t *testing.T) { + // Synthesizer registers cost_meter first, response_parser second. + costMeter := &fakeMiddleware{ + id: "cost_meter", + slot: SlotOnResponse, + keys: []string{"cost.usd_total", "cost.skipped"}, + } + respParser := &fakeMiddleware{ + id: "llm_response_parser", + slot: SlotOnResponse, + keys: []string{"llm.input_tokens", "llm.output_tokens"}, + emit: []KV{ + {Key: "llm.input_tokens", Value: "13"}, + {Key: "llm.output_tokens", Value: "259"}, + }, + } + c := chainFor(t, costMeter, respParser) + acc := NewAccumulator(0) + + _ = c.RunResponse(context.Background(), &Input{}, acc) + + require.Len(t, costMeter.seen, 2, "cost_meter must observe both token keys emitted by response_parser") + keys := []string{costMeter.seen[0].Key, costMeter.seen[1].Key} + assert.ElementsMatch(t, []string{"llm.input_tokens", "llm.output_tokens"}, keys, + "cost_meter must see the exact keys response_parser emitted") + values := []string{costMeter.seen[0].Value, costMeter.seen[1].Value} + assert.ElementsMatch(t, []string{"13", "259"}, values, "cost_meter must see the exact token counts") + for _, kv := range costMeter.seen { + _, err := strconv.Atoi(kv.Value) + assert.NoError(t, err, "values handed to cost_meter must be numeric (regression for missing_tokens)") + } +} + +// TestChain_RunResponse_DetachedContextStillRecords guards the metering +// fix in reverseproxy.go. The response/terminal phase runs after the body +// is forwarded, so a streaming client has usually disconnected by then, +// cancelling its request context. The dispatcher derives each middleware's +// context from the one passed here and short-circuits to fail-mode the +// instant it's Done, which silently drops token/cost metering. The reverse +// proxy now detaches that phase with context.WithoutCancel; this proves a +// context detached from an already-cancelled parent still lets a response +// middleware emit. (The cancelled-parent direction is intentionally not +// asserted: the dispatcher's select over ctx.Done vs the result channel is +// racy when both are ready, which is exactly why the bug was intermittent.) +func TestChain_RunResponse_DetachedContextStillRecords(t *testing.T) { + resp := &fakeMiddleware{ + id: "recorder", + slot: SlotOnResponse, + keys: []string{"llm.input_tokens"}, + emit: []KV{{Key: "llm.input_tokens", Value: "42"}}, + decision: DecisionPassthrough, + } + c := chainFor(t, resp) + + clientCtx, cancel := context.WithCancel(context.Background()) + cancel() // client disconnected after the stream completed + require.Error(t, clientCtx.Err(), "client context must be cancelled for the test to be meaningful") + + detached := context.WithoutCancel(clientCtx) + require.NoError(t, detached.Err(), "detached context must not inherit the client's cancellation") + + acc := NewAccumulator(MaxRequestMetadataBytes) + merged := c.RunResponse(detached, &Input{Slot: SlotOnResponse}, acc) + + var got string + for _, kv := range merged { + if kv.Key == "llm.input_tokens" { + got = kv.Value + } + } + assert.Equal(t, "42", got, "response middleware must still emit token metadata under the detached context") +} + +// TestChain_RunRequest_LatestRewriteWins asserts that when two +// on_request middlewares both emit an UpstreamRewrite, the chain +// returns the value from the later middleware. +func TestChain_RunRequest_LatestRewriteWins(t *testing.T) { + first := &fakeMiddleware{ + id: "first", + slot: SlotOnRequest, + mutationsSupported: true, + canMutate: true, + mutations: &Mutations{RewriteUpstream: &UpstreamRewrite{Scheme: "https", Host: "first.test"}}, + } + second := &fakeMiddleware{ + id: "second", + slot: SlotOnRequest, + mutationsSupported: true, + canMutate: true, + mutations: &Mutations{RewriteUpstream: &UpstreamRewrite{Scheme: "https", Host: "second.test"}}, + } + c := chainFor(t, first, second) + acc := NewAccumulator(0) + + denied, _, rewrite, err := c.RunRequest(context.Background(), nil, &Input{}, acc) + require.NoError(t, err) + assert.Nil(t, denied, "neither middleware denies") + require.NotNil(t, rewrite, "chain must surface the rewrite emitted by the on_request slot") + assert.Equal(t, "https", rewrite.Scheme, "rewrite scheme must come from the later middleware") + assert.Equal(t, "second.test", rewrite.Host, "rewrite host must come from the later middleware (last-write-wins)") +} + +// TestChain_RunRequest_NoRewrite_NilReturn asserts the chain returns a +// nil rewrite when no middleware emits one. +func TestChain_RunRequest_NoRewrite_NilReturn(t *testing.T) { + first := &fakeMiddleware{id: "first", slot: SlotOnRequest} + second := &fakeMiddleware{id: "second", slot: SlotOnRequest} + c := chainFor(t, first, second) + acc := NewAccumulator(0) + + denied, _, rewrite, err := c.RunRequest(context.Background(), nil, &Input{}, acc) + require.NoError(t, err) + assert.Nil(t, denied, "neither middleware denies") + assert.Nil(t, rewrite, "chain must return nil rewrite when no middleware emits one") +} + +// TestChain_ApplyMutations_RewriteGatedOnCanMutate asserts that a +// middleware emitting an UpstreamRewrite with CanMutate=false has its +// rewrite filtered out by the chain. The dispatcher's filterOutput +// already clears Mutations when the gates fail; the chain's defensive +// gate inside mutationRewrite mirrors that contract so a stale +// Mutations field cannot leak through. +func TestChain_ApplyMutations_RewriteGatedOnCanMutate(t *testing.T) { + mw := &fakeMiddleware{ + id: "first", + slot: SlotOnRequest, + mutationsSupported: true, + canMutate: false, + mutations: &Mutations{RewriteUpstream: &UpstreamRewrite{Scheme: "https", Host: "denied.test"}}, + } + c := chainFor(t, mw) + acc := NewAccumulator(0) + + denied, _, rewrite, err := c.RunRequest(context.Background(), nil, &Input{}, acc) + require.NoError(t, err) + assert.Nil(t, denied, "middleware does not deny") + assert.Nil(t, rewrite, "rewrite must be filtered when CanMutate=false") +} + +// TestChain_RunRequest_PropagatesUserGroups asserts the chain forwards +// Input.UserGroups verbatim through cloneInputFor so policy-aware +// middlewares (e.g. llm_policy_check) can authorise without an extra +// management round-trip. +func TestChain_RunRequest_PropagatesUserGroups(t *testing.T) { + groupCapture := &userGroupCaptureMiddleware{ + id: "group-capture", + slot: SlotOnRequest, + } + c := chainFor(t, groupCapture.fake()) + groupCapture.bind(c) + acc := NewAccumulator(0) + + in := &Input{UserGroups: []string{"g1"}} + denied, _, _, err := c.RunRequest(context.Background(), nil, in, acc) + require.NoError(t, err) + assert.Nil(t, denied, "no deny without DecisionDeny") + + require.Len(t, groupCapture.seenGroups, 1, "middleware must observe the caller's UserGroups") + assert.Equal(t, "g1", groupCapture.seenGroups[0], "UserGroups must reach the middleware verbatim") +} + +// userGroupCaptureMiddleware is a fakeMiddleware variant that records +// Input.UserGroups during Invoke. It exists so the cloneInputFor +// behaviour for the new field can be asserted without leaking into +// every other chain test. +type userGroupCaptureMiddleware struct { + id string + slot Slot + seenGroups []string + fakeMW *fakeMiddleware +} + +func (u *userGroupCaptureMiddleware) fake() *fakeMiddleware { + u.fakeMW = &fakeMiddleware{id: u.id, slot: u.slot} + return u.fakeMW +} + +func (u *userGroupCaptureMiddleware) bind(c *Chain) { + for i, bm := range c.all { + if bm.spec.ID != u.id { + continue + } + c.all[i].mw = userGroupRecorder{ + fakeMiddleware: u.fakeMW, + parent: u, + } + } +} + +type userGroupRecorder struct { + *fakeMiddleware + parent *userGroupCaptureMiddleware +} + +func (r userGroupRecorder) Invoke(ctx context.Context, in *Input) (*Output, error) { + r.parent.seenGroups = append([]string(nil), in.UserGroups...) + return r.fakeMiddleware.Invoke(ctx, in) +} + +// TestChain_RunTerminal_SeesAccumulatedMetadata locks that terminal +// middlewares observe the full bag (the caller-supplied in.Metadata +// plus any prior terminal emissions). +func TestChain_RunTerminal_SeesAccumulatedMetadata(t *testing.T) { + first := &fakeMiddleware{ + id: "term-1", + slot: SlotTerminal, + keys: []string{"term.first"}, + emit: []KV{{Key: "term.first", Value: "1"}}, + } + second := &fakeMiddleware{ + id: "term-2", + slot: SlotTerminal, + keys: []string{"term.second"}, + } + c := chainFor(t, first, second) + acc := NewAccumulator(0) + + in := &Input{Metadata: []KV{{Key: "ext.k", Value: "ext"}}} + merged := c.RunTerminal(context.Background(), in, acc) + + require.Len(t, second.seen, 2, "second terminal must see ext bag + first terminal's emission") + got := map[string]string{} + for _, kv := range second.seen { + got[kv.Key] = kv.Value + } + assert.Equal(t, "ext", got["ext.k"], "external bag carries through") + assert.Equal(t, "1", got["term.first"], "first terminal's emission visible to second terminal") + assert.Len(t, merged, 1, "only first terminal emitted; second emitted nothing") +} diff --git a/proxy/internal/middleware/decision.go b/proxy/internal/middleware/decision.go new file mode 100644 index 000000000..0970bdea4 --- /dev/null +++ b/proxy/internal/middleware/decision.go @@ -0,0 +1,81 @@ +package middleware + +import ( + "encoding/json" + "net/http" + "regexp" +) + +var codeRegex = regexp.MustCompile(`^[a-z][a-z0-9._-]{0,63}$`) + +// denyResponse is the on-wire shape rendered by RenderDenyResponse. +// Keeping this as a typed struct ensures we never leak +// middleware-supplied bytes outside known fields. +type denyResponse struct { + Code string `json:"code"` + Message string `json:"message,omitempty"` + Details map[string]string `json:"details,omitempty"` + Middleware string `json:"middleware,omitempty"` +} + +// RenderDenyResponse writes a structured JSON deny body. Status is +// clamped to [400, 499] excluding 401 (to avoid conflicts with the +// proxy's auth flow). All middleware-supplied strings are redacted and +// truncated. On any validation failure the function writes a generic +// 403. +func RenderDenyResponse(w http.ResponseWriter, middlewareID string, reason *DenyReason, defaultStatus int) { + status := clampDenyStatus(defaultStatus) + + if reason == nil || !codeRegex.MatchString(reason.Code) { + writeGenericDeny(w, middlewareID, status) + return + } + + resp := denyResponse{ + Code: reason.Code, + Message: truncate(Scan(reason.Message), 256), + Middleware: truncate(Scan(middlewareID), 64), + } + if n := len(reason.Details); n > 0 { + resp.Details = make(map[string]string, min(n, 8)) + for k, v := range reason.Details { + if len(resp.Details) >= 8 { + break + } + safeKey := truncate(Scan(k), 64) + if safeKey == "" { + continue + } + resp.Details[safeKey] = truncate(Scan(v), 256) + } + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(resp); err != nil { + return + } +} + +func writeGenericDeny(w http.ResponseWriter, middlewareID string, status int) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(denyResponse{Code: "middleware.error", Middleware: truncate(Scan(middlewareID), 64)}) +} + +func clampDenyStatus(s int) int { + if s < 400 || s >= 500 { + return http.StatusForbidden + } + if s == http.StatusUnauthorized { + return http.StatusForbidden + } + return s +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] +} diff --git a/proxy/internal/middleware/dispatcher.go b/proxy/internal/middleware/dispatcher.go new file mode 100644 index 000000000..316604651 --- /dev/null +++ b/proxy/internal/middleware/dispatcher.go @@ -0,0 +1,189 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "reflect" + "runtime" + "time" + + log "github.com/sirupsen/logrus" +) + +// Dispatcher reliability kinds reported via +// proxy.middleware.errors_total{kind=...}. +const ( + ErrorKindPanic = "panic" + ErrorKindTimeout = "timeout" + ErrorKindInvokeError = "invoke_error" +) + +// Dispatcher drives a single middleware invocation with panic +// recovery, deadline, and output filtering. Safe for concurrent use. +type Dispatcher struct { + metrics *Metrics + logger *log.Logger +} + +// NewDispatcher returns a dispatcher that emits on the provided +// metrics bundle and logger. A nil metrics bundle falls back to a noop +// instrument set; a nil logger falls back to the standard logger. +func NewDispatcher(metrics *Metrics, logger *log.Logger) *Dispatcher { + if metrics == nil { + metrics, _ = NewMetrics(nil) + } + if logger == nil { + logger = log.StandardLogger() + } + return &Dispatcher{metrics: metrics, logger: logger} +} + +// Invoke runs a single middleware under the reliability wrappers: +// deadline, panic recovery (type + truncated stack only), fail-mode, +// metric emission, and output filtering. The returned output is always +// safe to apply. +func (d *Dispatcher) Invoke(ctx context.Context, spec Spec, mw Middleware, in *Input) (*Output, error) { + if mw == nil { + return nil, fmt.Errorf("middleware %s: instance unavailable", spec.ID) + } + + timeout := clampTimeout(spec.Timeout) + callCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + d.metrics.IncInvocation(ctx, spec.ID) + start := time.Now() + + type result struct { + out *Output + err error + } + ch := make(chan result, 1) + + go func() { + defer func() { + if r := recover(); r != nil { + stack := make([]byte, 4<<10) + n := runtime.Stack(stack, false) + requestID := "" + if in != nil { + requestID = in.RequestID + } + d.logger.Warnf("middleware %s panic: request_id=%s type=%s stack=%s", + spec.ID, requestID, reflect.TypeOf(r).String(), stack[:n]) + ch <- result{err: panicError{msg: fmt.Sprintf("middleware %s panic: %s", spec.ID, reflect.TypeOf(r).String())}} + } + }() + out, err := mw.Invoke(callCtx, in) + ch <- result{out: out, err: err} + }() + + var ( + out *Output + invErr error + kind string + ) + + select { + case <-callCtx.Done(): + invErr = callCtx.Err() + kind = ErrorKindTimeout + case res := <-ch: + out = res.out + invErr = res.err + if invErr != nil { + kind = d.classifyError(invErr) + } + } + + d.metrics.ObserveDuration(ctx, spec.ID, time.Since(start).Milliseconds()) + + if invErr != nil { + d.metrics.IncError(ctx, spec.ID, kind) + return d.failMode(spec, kind), invErr + } + + return d.filterOutput(spec, out), nil +} + +func (d *Dispatcher) classifyError(err error) string { + if err == nil { + return "" + } + if errors.Is(err, context.DeadlineExceeded) { + return ErrorKindTimeout + } + var pe panicError + if errors.As(err, &pe) { + return ErrorKindPanic + } + return ErrorKindInvokeError +} + +// panicError marks an error as coming from the recover branch so the +// classifier can tag it without string inspection. +type panicError struct{ msg string } + +func (p panicError) Error() string { return p.msg } + +// failMode converts an error into a synthesised output per the +// middleware's fail-mode. An mw..error_kind metadata entry is +// attached so operators can alert on error rate even when the +// decision is fail-open. Slot constraints still apply: response and +// terminal slots clamp deny back to passthrough in filterOutput. +func (d *Dispatcher) failMode(spec Spec, kind string) *Output { + meta := []KV{{Key: fmt.Sprintf(KeyFrameworkErrorKindFmt, spec.ID), Value: kind}} + if spec.FailMode == FailClosed && spec.Slot == SlotOnRequest { + return &Output{ + Decision: DecisionDeny, + DenyStatus: 500, + DenyReason: &DenyReason{Code: "middleware.error"}, + Metadata: meta, + } + } + return &Output{Decision: DecisionAllow, Metadata: meta} +} + +// filterOutput applies the output-filter pipeline (slot-aware decision +// clamp, mutations gate) so downstream consumers never see +// middleware-supplied values that violate the contract. Metadata is +// passed through; the Accumulator is the single owner of allowlist + +// caps + redaction (called by Chain). +func (d *Dispatcher) filterOutput(spec Spec, out *Output) *Output { + if out == nil { + return &Output{Decision: DecisionAllow} + } + if spec.Slot != SlotOnRequest && out.Decision == DecisionDeny { + out.Decision = DecisionPassthrough + out.DenyStatus = 0 + out.DenyReason = nil + } + if out.Decision == DecisionDeny { + if out.DenyStatus == 0 { + out.DenyStatus = 403 + } else { + out.DenyStatus = clampDenyStatus(out.DenyStatus) + } + } + if !spec.CanMutate || !spec.MutationsSupported { + out.Mutations = nil + } + if spec.Slot == SlotTerminal { + out.Mutations = nil + } + return out +} + +func clampTimeout(d time.Duration) time.Duration { + if d <= 0 { + return DefaultTimeout + } + if d < MinTimeout { + return MinTimeout + } + if d > MaxTimeout { + return MaxTimeout + } + return d +} diff --git a/proxy/internal/middleware/headerpolicy.go b/proxy/internal/middleware/headerpolicy.go new file mode 100644 index 000000000..d041ad1e1 --- /dev/null +++ b/proxy/internal/middleware/headerpolicy.go @@ -0,0 +1,99 @@ +package middleware + +import "strings" + +var denyHeaders = []string{ + "Authorization", + "Connection", + "Cookie", + "Set-Cookie", + "Forwarded", + "Keep-Alive", + "Proxy-Authorization", + "Proxy-Authenticate", + "Proxy-Connection", + "TE", + "Upgrade", + "Via", + "X-Real-IP", + "X-Request-ID", + "Host", + "Content-Length", + "Transfer-Encoding", + "Trailer", +} + +var denyHeaderPrefixes = []string{ + "X-Authenticated-", + "X-Forwarded-", + "X-Remote-", + "X-NetBird-", +} + +// IsHeaderMutable reports whether a middleware is allowed to mutate +// the named header. The check is case-insensitive and honours both +// exact matches and the compiled-in prefix denylist. +func IsHeaderMutable(name string) bool { + if name == "" { + return false + } + if !isHeaderFieldName(name) { + return false + } + for _, d := range denyHeaders { + if strings.EqualFold(d, name) { + return false + } + } + for _, p := range denyHeaderPrefixes { + if len(name) >= len(p) && strings.EqualFold(name[:len(p)], p) { + return false + } + } + return true +} + +// isHeaderFieldName reports whether name is a valid RFC 7230 header +// field-name (a non-empty token of tchar octets). Rejects names with +// spaces, control characters, or separators that could enable header +// injection or smuggling when applied to the outbound request. +func isHeaderFieldName(name string) bool { + for i := 0; i < len(name); i++ { + c := name[i] + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') { + continue + } + switch c { + case '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~': + continue + default: + return false + } + } + return true +} + +// FilterHeaderMutations returns the subsets of HeadersAdd and +// HeadersRemove that are safe to apply, plus the list of blocked +// header names so the dispatcher can increment the blocked-header +// metric. +func FilterHeaderMutations(m *Mutations) (filteredAdd []KV, filteredRemove []string, blocked []string) { + if m == nil { + return nil, nil, nil + } + for _, kv := range m.HeadersAdd { + if IsHeaderMutable(kv.Key) { + filteredAdd = append(filteredAdd, kv) + continue + } + blocked = append(blocked, kv.Key) + } + for _, name := range m.HeadersRemove { + if IsHeaderMutable(name) { + filteredRemove = append(filteredRemove, name) + continue + } + blocked = append(blocked, name) + } + return filteredAdd, filteredRemove, blocked +} diff --git a/proxy/internal/middleware/keys.go b/proxy/internal/middleware/keys.go new file mode 100644 index 000000000..9c584ad82 --- /dev/null +++ b/proxy/internal/middleware/keys.go @@ -0,0 +1,86 @@ +package middleware + +// Metadata key namespace constants shared across the built-in +// middlewares. Each domain owns a prefix; middlewares declare their +// per-key allowlist drawn from these constants. Agents implementing +// the G2 middlewares import this file so the dashboard's expanded-row +// viewer and the access-log writer see a stable key surface. +// +// Key shape rules (enforced by the metadata accumulator): +// - Lowercase ASCII letters, digits, dot, underscore, hyphen. +// - At least one dot separating namespace from leaf. +// - Max length: MaxMetadataKeyBytes. +const ( + // LLM request-side metadata (emitted by llm_request_parser). + KeyLLMProvider = "llm.provider" + KeyLLMModel = "llm.model" + KeyLLMStream = "llm.stream" + KeyLLMRequestPromptRaw = "llm.request_prompt_raw" + KeyLLMCaptureTruncated = "llm.capture_truncated" + // KeyLLMSessionID groups requests of the same conversation / coding + // session, read from the per-provider session marker in the request + // body. Empty for clients that don't send one. + KeyLLMSessionID = "llm.session_id" + + // LLM response-side metadata (emitted by llm_response_parser). + //nolint:gosec // metadata key name, not a credential + KeyLLMInputTokens = "llm.input_tokens" + //nolint:gosec // metadata key name, not a credential + KeyLLMOutputTokens = "llm.output_tokens" + //nolint:gosec // metadata key name, not a credential + KeyLLMTotalTokens = "llm.total_tokens" + // LLM cached-input bucket. For OpenAI it's the SUBSET of input + // tokens that hit the prompt cache (prompt_tokens_details. + // cached_tokens) — billed at the cached_input_per_1k rate when + // configured. For Anthropic it's cache_read_input_tokens, which + // is ADDITIVE to llm.input_tokens — billed at cache_read_per_1k. + // cost_meter switches formula on llm.provider. + //nolint:gosec // metadata key name, not a credential + KeyLLMCachedInputTokens = "llm.cached_input_tokens" + // LLM cache-creation bucket (Anthropic only). ADDITIVE to + // llm.input_tokens; billed at cache_creation_per_1k. + //nolint:gosec // metadata key name, not a credential + KeyLLMCacheCreationTokens = "llm.cache_creation_tokens" + KeyLLMResponseCompletion = "llm.response_completion" + + // Guardrail outcomes (emitted by llm_guardrail). The guardrail + // also re-emits llm.request_prompt as a redacted variant of the + // raw prompt and drops llm.request_prompt_raw from the bag. + KeyLLMRequestPrompt = "llm.request_prompt" + KeyLLMPolicyDecision = "llm_policy.decision" + KeyLLMPolicyReason = "llm_policy.reason" + + // LLM router routing decision (emitted by llm_router). The router + // stamps the resolved provider id so downstream middlewares and + // the access-log emitter can attribute the request without + // re-parsing the body. + KeyLLMResolvedProviderID = "llm.resolved_provider_id" + + // LLM authorising groups for this request (emitted by llm_router + // on the allow path). Carries the comma-separated intersection of + // the caller's UserGroups with the resolved route's + // AllowedGroupIDs — i.e. the groups that actually authorise this + // specific request, NOT every group the peer happens to be in. + // Identity-stamping middlewares use this for per-request tag + // attribution so unrelated group memberships don't leak into + // downstream gateways' spend logs. + KeyLLMAuthorisingGroups = "llm.authorising_groups" + + // LLM policy attribution (emitted by llm_limit_check on the allow + // path). Names the policy that paid for this request and the + // dimension counters the post-flight llm_limit_record middleware + // must tick. Empty when no applicable policy has any caps + // configured (catch-all-allow attribution). + KeyLLMSelectedPolicyID = "llm.selected_policy_id" + KeyLLMAttributionGroupID = "llm.attribution_group_id" + KeyLLMAttributionWindowS = "llm.attribution_window_seconds" + + // Cost metering (emitted by cost_meter). + KeyCostUSDTotal = "cost.usd_total" + KeyCostSkipped = "cost.skipped" + + // Framework-emitted error markers. Use the mw..* prefix to + // distinguish framework-injected entries from middleware-emitted + // metadata. + KeyFrameworkErrorKindFmt = "mw.%s.error_kind" +) diff --git a/proxy/internal/middleware/manager.go b/proxy/internal/middleware/manager.go new file mode 100644 index 000000000..9b22edeff --- /dev/null +++ b/proxy/internal/middleware/manager.go @@ -0,0 +1,412 @@ +package middleware + +import ( + "context" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/middleware/bodytap" +) + +// chainCloseTimeout bounds how long closeChainsAsync waits for an +// individual chain to drain before forcing teardown. Set to 2x +// MaxTimeout so a middleware blocked on the dispatcher's per-Invoke +// deadline always wins; anything running longer is a runaway and gets +// force-closed. +const chainCloseTimeout = 2 * MaxTimeout + +// PathTargetBinding is the minimal per-path binding the server passes +// to Rebuild. It carries the stable keys Manager uses for snapshot +// lookups plus the validated middleware spec list for that path. +type PathTargetBinding struct { + ServiceID string + PathID string + Specs []Spec +} + +// LiveServiceCheck reports whether the given service ID is still +// present in the proxy's live mapping cache. The Manager calls it +// during InvalidateMiddleware so a chain whose service has been +// removed since the last Rebuild is not resurrected from the binding +// cache, closing the auth-revocation race. +type LiveServiceCheck func(serviceID string) bool + +// chainTable holds the immutable per-target chain snapshot. It is +// cloned into a new instance on every Rebuild and swapped in via +// atomic.Pointer. The reverse index byMiddleware lets +// InvalidateMiddleware find the chain keys that reference a given +// middleware without scanning the whole table. +type chainTable struct { + byTarget map[string]*Chain + byMiddleware map[string]map[string]struct{} +} + +func newChainTable() *chainTable { + return &chainTable{ + byTarget: make(map[string]*Chain), + byMiddleware: make(map[string]map[string]struct{}), + } +} + +func (c *chainTable) clone() *chainTable { + out := newChainTable() + for k, v := range c.byTarget { + out.byTarget[k] = v + } + for id, keys := range c.byMiddleware { + set := make(map[string]struct{}, len(keys)) + for k := range keys { + set[k] = struct{}{} + } + out.byMiddleware[id] = set + } + return out +} + +func (c *chainTable) addChain(key string, ch *Chain) { + c.byTarget[key] = ch + if ch == nil { + return + } + for _, bm := range ch.all { + set, ok := c.byMiddleware[bm.spec.ID] + if !ok { + set = make(map[string]struct{}) + c.byMiddleware[bm.spec.ID] = set + } + set[key] = struct{}{} + } +} + +func (c *chainTable) removeChain(key string) (*Chain, []string) { + ch, ok := c.byTarget[key] + if !ok { + return nil, nil + } + delete(c.byTarget, key) + if ch == nil { + return nil, nil + } + ids := make([]string, 0, len(ch.all)) + for _, bm := range ch.all { + ids = append(ids, bm.spec.ID) + set, ok := c.byMiddleware[bm.spec.ID] + if !ok { + continue + } + delete(set, key) + if len(set) == 0 { + delete(c.byMiddleware, bm.spec.ID) + } + } + return ch, ids +} + +// Manager owns the per-target middleware chains, the global capture +// budget, and the shared dispatcher. Readers (ChainFor) are lock-free; +// writers (Rebuild, Invalidate*) serialise on writeMu so two +// concurrent mapping updates do not lose writes. +type Manager struct { + writeMu sync.Mutex + chains atomic.Pointer[chainTable] + budget bodytap.Budget + metrics *Metrics + logger *log.Logger + dispatcher *Dispatcher + resolver *Resolver + lastBindings map[string]PathTargetBinding + liveServiceCheck atomic.Pointer[LiveServiceCheck] +} + +// NewManager constructs a Manager with the given capture budget size. +// A zero or negative budget falls back to bodytap.DefaultCaptureBudgetBytes. +func NewManager(budgetBytes int64, metrics *Metrics, logger *log.Logger) *Manager { + if metrics == nil { + metrics, _ = NewMetrics(nil) + } + if logger == nil { + logger = log.StandardLogger() + } + if budgetBytes <= 0 { + budgetBytes = bodytap.DefaultCaptureBudgetBytes + } + m := &Manager{ + budget: bodytap.NewBudget(budgetBytes), + metrics: metrics, + logger: logger, + dispatcher: NewDispatcher(metrics, logger), + lastBindings: make(map[string]PathTargetBinding), + } + m.chains.Store(newChainTable()) + return m +} + +// SetResolver installs the resolver used by Rebuild. Safe to call +// once at boot before any Rebuild; not safe to swap concurrently. +func (m *Manager) SetResolver(r *Resolver) { + m.resolver = r +} + +// SetLiveServiceCheck installs a callback the Manager uses to confirm +// a service ID still maps to a live mapping before resurrecting its +// chain from the binding cache during InvalidateMiddleware. A nil fn +// disables the check. +func (m *Manager) SetLiveServiceCheck(fn LiveServiceCheck) { + if fn == nil { + m.liveServiceCheck.Store(nil) + return + } + m.liveServiceCheck.Store(&fn) +} + +// Budget returns the shared capture budget. +func (m *Manager) Budget() bodytap.Budget { + return m.budget +} + +// Metrics returns the shared metrics bundle. +func (m *Manager) Metrics() *Metrics { + return m.metrics +} + +// Dispatcher returns the shared dispatcher (primarily for testing). +func (m *Manager) Dispatcher() *Dispatcher { + return m.dispatcher +} + +// Rebuild replaces every chain keyed by serviceID with the provided +// bindings. Entries for other services are preserved. Replaced chains +// are closed asynchronously after the atomic swap so in-flight +// requests against the previous chain finish before middleware +// resources are released. +func (m *Manager) Rebuild(serviceID string, bindings []PathTargetBinding) error { + m.writeMu.Lock() + defer m.writeMu.Unlock() + + cur := m.chains.Load() + next := cur.clone() + + prefix := serviceID + "|" + var retired []*Chain + for k := range cur.byTarget { + if !strings.HasPrefix(k, prefix) { + continue + } + ch, _ := next.removeChain(k) + if ch != nil { + retired = append(retired, ch) + } + delete(m.lastBindings, k) + } + + for _, b := range bindings { + if b.ServiceID != serviceID { + return fmt.Errorf("binding service %q does not match rebuild service %q", b.ServiceID, serviceID) + } + key := chainKey(b.ServiceID, b.PathID) + m.lastBindings[key] = cloneBinding(b) + chain := m.buildChain(b) + if chain == nil || chain.Empty() { + delete(m.lastBindings, key) + continue + } + next.addChain(key, chain) + } + + m.chains.Store(next) + m.closeChainsAsync(retired) + return nil +} + +// Invalidate drops every chain for the given service ID. +func (m *Manager) Invalidate(serviceID string) { + m.writeMu.Lock() + defer m.writeMu.Unlock() + cur := m.chains.Load() + next := cur.clone() + prefix := serviceID + "|" + var retired []*Chain + for k := range cur.byTarget { + if !strings.HasPrefix(k, prefix) { + continue + } + ch, _ := next.removeChain(k) + if ch != nil { + retired = append(retired, ch) + } + delete(m.lastBindings, k) + } + for k := range m.lastBindings { + if strings.HasPrefix(k, prefix) { + delete(m.lastBindings, k) + } + } + m.chains.Store(next) + m.closeChainsAsync(retired) +} + +// InvalidateMiddleware rebuilds only the chains that reference id. +func (m *Manager) InvalidateMiddleware(id string) { + if id == "" { + return + } + m.writeMu.Lock() + defer m.writeMu.Unlock() + + cur := m.chains.Load() + keys, ok := cur.byMiddleware[id] + if !ok || len(keys) == 0 { + return + } + + affected := make([]string, 0, len(keys)) + for k := range keys { + affected = append(affected, k) + } + + next := cur.clone() + var retired []*Chain + check := m.loadLiveServiceCheck() + for _, k := range affected { + ch, _ := next.removeChain(k) + if ch != nil { + retired = append(retired, ch) + } + b, ok := m.lastBindings[k] + if !ok { + delete(m.lastBindings, k) + continue + } + if check != nil && !check(b.ServiceID) { + m.logger.Debugf("middleware %s: skipping rebuild for %s; service no longer live", id, k) + delete(m.lastBindings, k) + continue + } + chain := m.buildChain(b) + if chain == nil || chain.Empty() { + delete(m.lastBindings, k) + continue + } + next.addChain(k, chain) + } + + m.chains.Store(next) + m.closeChainsAsync(retired) +} + +func (m *Manager) loadLiveServiceCheck() LiveServiceCheck { + p := m.liveServiceCheck.Load() + if p == nil { + return nil + } + return *p +} + +// InvalidateAll drops every chain. +func (m *Manager) InvalidateAll() { + m.writeMu.Lock() + defer m.writeMu.Unlock() + cur := m.chains.Load() + retired := make([]*Chain, 0, len(cur.byTarget)) + for _, c := range cur.byTarget { + retired = append(retired, c) + } + m.chains.Store(newChainTable()) + for k := range m.lastBindings { + delete(m.lastBindings, k) + } + m.closeChainsAsync(retired) +} + +func (m *Manager) closeChainsAsync(retired []*Chain) { + if len(retired) == 0 { + return + } + chains := make([]*Chain, len(retired)) + copy(chains, retired) + go func() { + for _, c := range chains { + ctx, cancel := context.WithTimeout(context.Background(), chainCloseTimeout) + start := time.Now() + if err := c.Close(ctx); err != nil { + if m.metrics != nil { + m.metrics.IncError(context.Background(), c.TargetID(), "chain_close_timeout") + } + m.logger.Warnf("middleware chain %s close exceeded %s after %s: %v", + c.TargetID(), chainCloseTimeout, time.Since(start), err) + } + cancel() + } + }() +} + +// ChainFor returns the chain for serviceID/pathID or nil if none is +// registered. Lock-free. +func (m *Manager) ChainFor(serviceID, pathID string) *Chain { + tbl := m.chains.Load() + if tbl == nil { + return nil + } + c, ok := tbl.byTarget[chainKey(serviceID, pathID)] + if !ok { + return nil + } + return c +} + +// buildChain resolves each enabled spec and returns the assembled +// chain. Returns a nil chain when no middlewares are bound; resolver +// errors per middleware are logged and counted but do not abort the +// chain. +func (m *Manager) buildChain(b PathTargetBinding) *Chain { + if len(b.Specs) == 0 || m.resolver == nil { + return nil + } + + bound := make([]boundMiddleware, 0, len(b.Specs)) + for _, spec := range b.Specs { + if !spec.Enabled { + continue + } + mw, merged, err := m.resolver.Resolve(spec) + if err != nil { + m.logger.Warnf("middleware %s resolve on target %s/%s: %v", spec.ID, b.ServiceID, b.PathID, err) + m.metrics.IncError(context.Background(), spec.ID, "resolve_error") + continue + } + if mw == nil { + continue + } + bound = append(bound, boundMiddleware{spec: merged, mw: mw}) + } + if len(bound) == 0 { + return nil + } + return NewChain(chainKey(b.ServiceID, b.PathID), bound, m.dispatcher) +} + +// cloneBinding returns a deep copy of b suitable for caching across +// mapping updates. +func cloneBinding(b PathTargetBinding) PathTargetBinding { + out := PathTargetBinding{ + ServiceID: b.ServiceID, + PathID: b.PathID, + } + if len(b.Specs) == 0 { + return out + } + out.Specs = make([]Spec, len(b.Specs)) + for i, s := range b.Specs { + out.Specs[i] = s.Clone() + } + return out +} + +func chainKey(serviceID, pathID string) string { + return serviceID + "|" + pathID +} diff --git a/proxy/internal/middleware/metadata.go b/proxy/internal/middleware/metadata.go new file mode 100644 index 000000000..576c379ec --- /dev/null +++ b/proxy/internal/middleware/metadata.go @@ -0,0 +1,99 @@ +package middleware + +import "regexp" + +// keyRegex constrains metadata keys to the cross-domain shape +// described in keys.go. At least one dot, lowercase ASCII / digits / +// dot / underscore / hyphen only, length within MaxMetadataKeyBytes. +var keyRegex = regexp.MustCompile(`^[a-z][a-z0-9_-]*(\.[a-z0-9][a-z0-9_-]*)+$`) + +// MetadataRejection describes a single rejected key/value so the +// dispatcher can emit per-reason counter increments. +type MetadataRejection struct { + Key string + Reason string +} + +// Rejection reasons reported by Accumulator.Emit. +const ( + MetadataReasonBadKey = "bad_key" + MetadataReasonNotAllowlisted = "not_allowlisted" + MetadataReasonKeyTooLong = "key_too_long" + MetadataReasonValueTooLong = "value_too_long" + MetadataReasonMiddlewareCap = "middleware_cap" + MetadataReasonRequestCap = "request_cap" +) + +// Accumulator enforces per-middleware and per-request metadata caps. +// Not safe for concurrent use; callers hold one inside a single chain +// execution. +type Accumulator struct { + perMiddlewareUsed map[string]int + totalUsed int + maxPerRequest int +} + +// NewAccumulator returns an accumulator configured for the per-request +// total cap. A maxPerRequest of zero means use MaxRequestMetadataBytes. +func NewAccumulator(maxPerRequest int) *Accumulator { + if maxPerRequest <= 0 { + maxPerRequest = MaxRequestMetadataBytes + } + return &Accumulator{ + perMiddlewareUsed: make(map[string]int), + maxPerRequest: maxPerRequest, + } +} + +// Emit validates the candidate metadata against the middleware's +// allowlist and the global caps, redacts each accepted value, and +// returns the accepted entries plus any rejections for metric emission. +func (a *Accumulator) Emit(middlewareID string, allow []string, out []KV) ([]KV, []MetadataRejection) { + if len(out) == 0 { + return nil, nil + } + allowSet := make(map[string]struct{}, len(allow)) + for _, k := range allow { + allowSet[k] = struct{}{} + } + + accepted := make([]KV, 0, len(out)) + var rejected []MetadataRejection + + for _, kv := range out { + if len(kv.Key) == 0 || len(kv.Key) > MaxMetadataKeyBytes { + rejected = append(rejected, MetadataRejection{Key: kv.Key, Reason: MetadataReasonKeyTooLong}) + continue + } + if !keyRegex.MatchString(kv.Key) { + rejected = append(rejected, MetadataRejection{Key: kv.Key, Reason: MetadataReasonBadKey}) + continue + } + if _, ok := allowSet[kv.Key]; !ok { + rejected = append(rejected, MetadataRejection{Key: kv.Key, Reason: MetadataReasonNotAllowlisted}) + continue + } + if len(kv.Value) > MaxMetadataValueBytes { + rejected = append(rejected, MetadataRejection{Key: kv.Key, Reason: MetadataReasonValueTooLong}) + continue + } + + redacted := Scan(kv.Value) + cost := len(kv.Key) + len(redacted) + + if a.perMiddlewareUsed[middlewareID]+cost > MaxMiddlewareMetadataBytes { + rejected = append(rejected, MetadataRejection{Key: kv.Key, Reason: MetadataReasonMiddlewareCap}) + continue + } + if a.totalUsed+cost > a.maxPerRequest { + rejected = append(rejected, MetadataRejection{Key: kv.Key, Reason: MetadataReasonRequestCap}) + continue + } + + a.perMiddlewareUsed[middlewareID] += cost + a.totalUsed += cost + accepted = append(accepted, KV{Key: kv.Key, Value: redacted}) + } + + return accepted, rejected +} diff --git a/proxy/internal/middleware/metrics.go b/proxy/internal/middleware/metrics.go new file mode 100644 index 000000000..73745a86b --- /dev/null +++ b/proxy/internal/middleware/metrics.go @@ -0,0 +1,171 @@ +package middleware + +import ( + "context" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/metric/noop" +) + +// Metrics is the bundle of OTel instruments emitted by the middleware +// dispatcher. The constructor falls back to a noop meter when given +// nil so tests can skip metrics wiring entirely. +type Metrics struct { + requestsTotal metric.Int64Counter + durationMs metric.Int64Histogram + invocationsTotal metric.Int64Counter + errorsTotal metric.Int64Counter + metadataRejectedTotal metric.Int64Counter + headerMutationBlocked metric.Int64Counter + captureBypassTotal metric.Int64Counter +} + +// NewMetrics registers the proxy.middleware.* instruments on the +// given meter. A nil meter is treated as the global no-op provider. +func NewMetrics(meter metric.Meter) (*Metrics, error) { + if meter == nil { + meter = noop.NewMeterProvider().Meter("proxy.middleware.noop") + } + + m := &Metrics{} + var err error + + m.requestsTotal, err = meter.Int64Counter( + "proxy.middleware.requests_total", + metric.WithUnit("1"), + metric.WithDescription("Middleware invocations grouped by outcome"), + ) + if err != nil { + return nil, err + } + + m.durationMs, err = meter.Int64Histogram( + "proxy.middleware.duration_ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Middleware Invoke latency"), + ) + if err != nil { + return nil, err + } + + m.invocationsTotal, err = meter.Int64Counter( + "proxy.middleware.invocations_total", + metric.WithUnit("1"), + metric.WithDescription("Middleware Invoke heartbeat counter"), + ) + if err != nil { + return nil, err + } + + m.errorsTotal, err = meter.Int64Counter( + "proxy.middleware.errors_total", + metric.WithUnit("1"), + metric.WithDescription("Middleware errors grouped by kind"), + ) + if err != nil { + return nil, err + } + + m.metadataRejectedTotal, err = meter.Int64Counter( + "proxy.middleware.metadata_rejected_total", + metric.WithUnit("1"), + metric.WithDescription("Middleware metadata entries rejected by the allowlist/caps"), + ) + if err != nil { + return nil, err + } + + m.headerMutationBlocked, err = meter.Int64Counter( + "proxy.middleware.header_mutation_blocked_total", + metric.WithUnit("1"), + metric.WithDescription("Middleware header mutations dropped by the denylist"), + ) + if err != nil { + return nil, err + } + + m.captureBypassTotal, err = meter.Int64Counter( + "proxy.middleware.capture_bypass_total", + metric.WithUnit("1"), + metric.WithDescription("Capture bypasses grouped by reason"), + ) + if err != nil { + return nil, err + } + + return m, nil +} + +// IncRequest increments proxy.middleware.requests_total with the +// middleware, target, and outcome labels. +func (m *Metrics) IncRequest(ctx context.Context, middlewareID, targetID, outcome string) { + if m == nil { + return + } + m.requestsTotal.Add(ctx, 1, metric.WithAttributes( + attribute.String("middleware", middlewareID), + attribute.String("target_id", targetID), + attribute.String("outcome", outcome), + )) +} + +// ObserveDuration records the middleware Invoke latency in milliseconds. +func (m *Metrics) ObserveDuration(ctx context.Context, middlewareID string, ms int64) { + if m == nil { + return + } + m.durationMs.Record(ctx, ms, metric.WithAttributes(attribute.String("middleware", middlewareID))) +} + +// IncInvocation increments the heartbeat counter regardless of outcome. +func (m *Metrics) IncInvocation(ctx context.Context, middlewareID string) { + if m == nil { + return + } + m.invocationsTotal.Add(ctx, 1, metric.WithAttributes(attribute.String("middleware", middlewareID))) +} + +// IncError increments the error counter with the given failure kind label. +func (m *Metrics) IncError(ctx context.Context, middlewareID, kind string) { + if m == nil { + return + } + m.errorsTotal.Add(ctx, 1, metric.WithAttributes( + attribute.String("middleware", middlewareID), + attribute.String("kind", kind), + )) +} + +// IncMetadataRejected increments the rejected-metadata counter for a reason. +func (m *Metrics) IncMetadataRejected(ctx context.Context, middlewareID, reason string) { + if m == nil { + return + } + m.metadataRejectedTotal.Add(ctx, 1, metric.WithAttributes( + attribute.String("middleware", middlewareID), + attribute.String("reason", reason), + )) +} + +// IncHeaderMutationBlocked increments the blocked-header counter. +func (m *Metrics) IncHeaderMutationBlocked(ctx context.Context, middlewareID, header string) { + if m == nil { + return + } + m.headerMutationBlocked.Add(ctx, 1, metric.WithAttributes( + attribute.String("middleware", middlewareID), + attribute.String("header", header), + )) +} + +// IncCaptureBypass increments the capture-bypass counter for a reason. +func (m *Metrics) IncCaptureBypass(ctx context.Context, targetID, reason string) { + if m == nil { + return + } + m.captureBypassTotal.Add(ctx, 1, metric.WithAttributes( + attribute.String("target_id", targetID), + attribute.String("reason", reason), + )) +} diff --git a/proxy/internal/middleware/middleware.go b/proxy/internal/middleware/middleware.go new file mode 100644 index 000000000..16d398d2c --- /dev/null +++ b/proxy/internal/middleware/middleware.go @@ -0,0 +1,47 @@ +package middleware + +import "context" + +// Middleware is the surface exposed by each concrete implementation. +// The Manager invokes it through the Dispatcher, passing a cloned +// Input. Each middleware lives in exactly one Slot. +// +// Close releases any resources owned by the middleware instance +// (background goroutines, file handles). It is invoked when the chain +// holding the middleware is replaced or torn down. Implementations +// must be idempotent and safe to call after construction even when +// Invoke was never called. +type Middleware interface { + ID() string + Version() string + Slot() Slot + + // AcceptedContentTypes lists the request/response content types + // the middleware needs the body for. Empty slice means the + // middleware does not inspect the body. + AcceptedContentTypes() []string + + // MetadataKeys is the closed set of metadata keys this middleware + // may emit. The accumulator drops anything outside this allowlist. + MetadataKeys() []string + + // MutationsSupported reports whether the middleware may emit + // header / body mutations. A spec with CanMutate=true is honoured + // only when the implementation also supports mutations. + MutationsSupported() bool + + Invoke(ctx context.Context, in *Input) (*Output, error) + + Close() error +} + +// Factory builds a configured Middleware instance from raw config +// bytes shipped on the wire. Each registered middleware ID has a +// single factory in the registry. Factory.New returns an error when +// the config is malformed or violates a per-middleware invariant; the +// chain build path logs the error, increments the resolve_error metric, +// and skips the middleware. +type Factory interface { + ID() string + New(rawConfig []byte) (Middleware, error) +} diff --git a/proxy/internal/middleware/redaction.go b/proxy/internal/middleware/redaction.go new file mode 100644 index 000000000..ebbe90c61 --- /dev/null +++ b/proxy/internal/middleware/redaction.go @@ -0,0 +1,79 @@ +package middleware + +import ( + "regexp" + "strings" +) + +// Redaction scope: Scan handles the narrow, high-signal set of +// secrets we are comfortable masking with a regex. The intent is +// "make accidental leaks impossible to miss at a glance", not "be a +// DLP product". Contributors adding more patterns should weigh false +// positives carefully — a metadata value that over-redacts benign +// strings is strictly worse than one that misses a rare format. +var ( + jwtRegex = regexp.MustCompile(`eyJ[A-Za-z0-9_-]{5,}\.[A-Za-z0-9_-]{5,}\.[A-Za-z0-9_-]{5,}`) + pemRegex = regexp.MustCompile(`-----BEGIN [A-Z ]+-----[\s\S]*?-----END [A-Z ]+-----`) + awsKeyRegex = regexp.MustCompile(`AKIA[0-9A-Z]{16}`) + bearerRegex = regexp.MustCompile(`(?i)\b(?:bearer|token|api[_-]?key|authorization)[\s:=]+([A-Za-z0-9_\-\.]{40,})`) + ccCandidateRgx = regexp.MustCompile(`\b(?:\d[ -]?){13,19}\b`) +) + +// Scan redacts high-signal secret patterns from value. Matches are +// replaced with `[REDACTED:]`. Non-matching input is returned +// unchanged. +func Scan(value string) string { + if value == "" { + return value + } + result := value + result = pemRegex.ReplaceAllString(result, "[REDACTED:pem]") + result = jwtRegex.ReplaceAllString(result, "[REDACTED:jwt]") + result = awsKeyRegex.ReplaceAllString(result, "[REDACTED:aws_key]") + result = bearerRegex.ReplaceAllStringFunc(result, func(match string) string { + sub := bearerRegex.FindStringSubmatch(match) + if len(sub) < 2 { + return "[REDACTED:bearer]" + } + return strings.Replace(match, sub[1], "[REDACTED:bearer]", 1) + }) + result = ccCandidateRgx.ReplaceAllStringFunc(result, func(match string) string { + digits := stripNonDigits(match) + if len(digits) < 13 || len(digits) > 19 { + return match + } + if !luhn(digits) { + return match + } + return "[REDACTED:cc]" + }) + return result +} + +func stripNonDigits(s string) string { + var b strings.Builder + b.Grow(len(s)) + for _, r := range s { + if r >= '0' && r <= '9' { + b.WriteRune(r) + } + } + return b.String() +} + +func luhn(digits string) bool { + sum := 0 + alt := false + for i := len(digits) - 1; i >= 0; i-- { + n := int(digits[i] - '0') + if alt { + n *= 2 + if n > 9 { + n -= 9 + } + } + sum += n + alt = !alt + } + return sum%10 == 0 +} diff --git a/proxy/internal/middleware/registry.go b/proxy/internal/middleware/registry.go new file mode 100644 index 000000000..c46552162 --- /dev/null +++ b/proxy/internal/middleware/registry.go @@ -0,0 +1,121 @@ +package middleware + +import ( + "fmt" + "sync" +) + +// Registry maps middleware IDs to their factories. The proxy installs +// a single Registry at boot; concrete middlewares register themselves +// from init() functions inside their own packages so the boot wiring +// only needs an anonymous import. +// +// Registry is safe for concurrent reads after boot. Register / Unregister +// take the write lock; Get and IDs take the read lock. +type Registry struct { + mu sync.RWMutex + factories map[string]Factory +} + +// NewRegistry returns an empty registry. +func NewRegistry() *Registry { + return &Registry{factories: make(map[string]Factory)} +} + +// Register installs the factory under its ID. Returns an error when an +// ID is already registered — collisions are programmer errors and must +// be visible at boot rather than silently last-write-wins. +func (r *Registry) Register(f Factory) error { + if f == nil { + return fmt.Errorf("middleware registry: nil factory") + } + id := f.ID() + if id == "" { + return fmt.Errorf("middleware registry: factory has empty id") + } + r.mu.Lock() + defer r.mu.Unlock() + if _, exists := r.factories[id]; exists { + return fmt.Errorf("middleware registry: %q already registered", id) + } + r.factories[id] = f + return nil +} + +// MustRegister panics on error. Intended for init() registration so +// duplicate IDs surface at startup. +func (r *Registry) MustRegister(f Factory) { + if err := r.Register(f); err != nil { + panic(err) + } +} + +// Get returns the factory for id, or nil when no factory is +// registered. +func (r *Registry) Get(id string) Factory { + r.mu.RLock() + defer r.mu.RUnlock() + return r.factories[id] +} + +// IDs returns the registered IDs in unspecified order. Used by the +// management translator to reject specs that reference unknown IDs at +// apply time. +func (r *Registry) IDs() []string { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]string, 0, len(r.factories)) + for id := range r.factories { + out = append(out, id) + } + return out +} + +// IsKnown reports whether id has a registered factory. +func (r *Registry) IsKnown(id string) bool { + return r.Get(id) != nil +} + +// Resolver wraps a Registry and produces a configured Middleware +// instance from a Spec. The Manager uses this during chain build. +type Resolver struct { + registry *Registry +} + +// NewResolver returns a resolver backed by the registry. +func NewResolver(registry *Registry) *Resolver { + if registry == nil { + registry = NewRegistry() + } + return &Resolver{registry: registry} +} + +// Resolve builds a Middleware instance and merges runtime-only fields +// (version, accepted content types, metadata key allowlist, mutation +// support) onto the spec. +// +// Return semantics: +// - (mw, mergedSpec, nil): instance built, include in chain. +// - (nil, spec, nil): id not registered; silently skip. +// - (nil, spec, err): factory rejected the config (logged + counted +// by Manager, other middlewares still bind). +func (r *Resolver) Resolve(spec Spec) (Middleware, Spec, error) { + f := r.registry.Get(spec.ID) + if f == nil { + return nil, spec, nil + } + mw, err := f.New(spec.RawConfig) + if err != nil { + return nil, spec, fmt.Errorf("middleware %s factory: %w", spec.ID, err) + } + if mw.Slot() != spec.Slot { + _ = mw.Close() + return nil, spec, fmt.Errorf("middleware %s slot mismatch: spec=%d impl=%d", spec.ID, spec.Slot, mw.Slot()) + } + merged := spec + merged.Version = mw.Version() + merged.MetadataKeys = append([]string(nil), mw.MetadataKeys()...) + merged.AcceptedContentTypes = append([]string(nil), mw.AcceptedContentTypes()...) + merged.MutationsSupported = mw.MutationsSupported() + return mw, merged, nil +} diff --git a/proxy/internal/middleware/spec.go b/proxy/internal/middleware/spec.go new file mode 100644 index 000000000..a154ceca2 --- /dev/null +++ b/proxy/internal/middleware/spec.go @@ -0,0 +1,44 @@ +package middleware + +import "time" + +// Spec is the apply-time, validated representation of a per-target +// middleware configuration merged with the runtime-only fields +// compiled into the middleware implementation. +// +// The wire shape is RawConfig (JSON bytes) instead of the older +// params map[string]string. Each middleware unmarshals RawConfig into +// its own typed config struct, surfacing structural validation errors +// at construction rather than per-invocation lookups. +type Spec struct { + ID string + Slot Slot + Version string + Enabled bool + FailMode FailMode + Timeout time.Duration + RawConfig []byte + CanMutate bool + + // Runtime-only fields populated from the registered middleware at + // chain build time; not sourced from proto. + MetadataKeys []string + AcceptedContentTypes []string + MutationsSupported bool +} + +// Clone returns a deep copy of the spec safe to cache across mapping +// updates. +func (s Spec) Clone() Spec { + out := s + if len(s.RawConfig) > 0 { + out.RawConfig = append([]byte(nil), s.RawConfig...) + } + if len(s.MetadataKeys) > 0 { + out.MetadataKeys = append([]string(nil), s.MetadataKeys...) + } + if len(s.AcceptedContentTypes) > 0 { + out.AcceptedContentTypes = append([]string(nil), s.AcceptedContentTypes...) + } + return out +} diff --git a/proxy/internal/middleware/types.go b/proxy/internal/middleware/types.go new file mode 100644 index 000000000..1b49e6159 --- /dev/null +++ b/proxy/internal/middleware/types.go @@ -0,0 +1,253 @@ +// Package middleware defines the per-target middleware chain that runs +// inside the reverse proxy hot path. It is the only chain wired into +// the request path. +// +// Concepts: +// - Slot: the position a middleware occupies in the chain. A +// middleware lives in exactly one slot — separate concerns become +// separate middlewares. +// - Decision: the on_request slot can DENY; on_response and terminal +// slots can only PASSTHROUGH. The dispatcher clamps decisions that +// violate this contract. +// - Metadata: the only side-channel between middlewares. Each +// middleware declares an allowlist of keys it may emit; the merger +// enforces caps and namespace rules. +package middleware + +import "time" + +// Slot identifies where in the request lifecycle a middleware runs. +// A middleware declares a single slot. Splitting per-purpose work +// (request parsing vs response parsing vs cost metering) into separate +// slot-keyed middlewares is the explicit architectural choice for the +// agent-network use case; no middleware participates in more than one +// slot. +type Slot int + +const ( + // SlotOnRequest runs before the upstream call. Middlewares in this + // slot may DENY the request, mutate headers/body (when permitted), + // and emit metadata derived from the request envelope. + SlotOnRequest Slot = 1 + // SlotOnResponse runs after the upstream returns. Middlewares in + // this slot observe the response, emit metadata, and may mutate + // response headers when permitted. They cannot DENY. + SlotOnResponse Slot = 2 + // SlotTerminal runs after every SlotOnResponse middleware has + // emitted. Terminal middlewares observe the full metadata bag and + // ship it to external sinks (access log, metrics export). They + // cannot DENY and cannot mutate the response. + SlotTerminal Slot = 3 +) + +// FailMode controls how the dispatcher reacts when a middleware +// returns an error, times out, or panics. Observer middlewares default +// to FailOpen; policy middlewares should default to FailClosed. +type FailMode int + +const ( + // FailOpen allows the request to proceed when a middleware fails. + FailOpen FailMode = 0 + // FailClosed denies the request when a middleware fails. Only + // meaningful for SlotOnRequest middlewares. + FailClosed FailMode = 1 +) + +// Decision captures the outcome of a middleware invocation as observed +// by the dispatcher. Response-phase middlewares always return +// DecisionPassthrough; the dispatcher clamps any other value. +type Decision int + +const ( + // DecisionAllow lets the request proceed. + DecisionAllow Decision = 0 + // DecisionDeny stops the chain and returns a rendered deny + // response. Only honoured in SlotOnRequest. + DecisionDeny Decision = 1 + // DecisionPassthrough is the response-phase neutral outcome. + DecisionPassthrough Decision = 2 +) + +// Resource limits enforced by the proxy at config apply time and by +// the dispatcher at runtime. Per-target values supplied by management +// are clamped to these bounds. +const ( + // MaxBodyCapBytes is the proxy-wide upper bound for per-direction + // body capture. Sized to hold a full LLM streaming response (token + // usage rides the trailing SSE event, so the captured prefix must + // reach the end of the stream); a single response is bounded by the + // model's max output tokens, so this is a real ceiling, not a + // treadmill. Request capture stays well under this — oversized + // requests use the tolerant routing scan instead of buffering. + MaxBodyCapBytes int64 = 8 << 20 + // MinTimeout is the proxy-wide lower bound for per-middleware + // Invoke timeouts. + MinTimeout = 10 * time.Millisecond + // MaxTimeout is the proxy-wide upper bound for per-middleware + // Invoke timeouts. + MaxTimeout = 5 * time.Second + // DefaultTimeout is used when the per-target timeout is zero or + // unset. + DefaultTimeout = 500 * time.Millisecond + + // MaxMiddlewareMetadataBytes is the per-middleware metadata total + // cap. + MaxMiddlewareMetadataBytes = 16 << 10 + // MaxRequestMetadataBytes is the per-request metadata total cap + // across all middlewares in the chain. Earlier middlewares win + // when the budget is exhausted. + MaxRequestMetadataBytes = 32 << 10 + // MaxMetadataKeyBytes is the maximum length of a metadata key. + MaxMetadataKeyBytes = 96 + // MaxMetadataValueBytes is the maximum length of a metadata value. + MaxMetadataValueBytes = 4 << 10 + // MaxMiddlewaresPerChain caps the number of middleware entries + // accepted per chain at the proxy translator and the management + // REST API. Mirrors the chain invocation cap so a misconfigured + // mapping cannot push the chain clone cost beyond a known bound. + MaxMiddlewaresPerChain = 16 +) + +// KV is the canonical header/metadata representation used across the +// middleware boundary. We use a slice of KV instead of http.Header +// because it preserves key order, is cheap to deep-copy per +// invocation, and is directly representable in a future protobuf +// envelope. +type KV struct { + Key string + Value string +} + +// Input is the immutable envelope handed to each middleware. The +// dispatcher deep-copies Headers, Body, Metadata, RespHeaders, and +// RespBody before each invocation so middlewares cannot mutate the +// shared in-flight copies; mutations must flow through Output.Mutations. +type Input struct { + Slot Slot + RequestID string + TargetID string + Method string + URL string + Headers []KV + Body []byte + BodyTruncated bool + OriginalBodySize int64 + + Status int + RespHeaders []KV + RespBody []byte + RespBodyTruncated bool + OriginalRespSize int64 + + ServiceID string + AccountID string + UserID string + // UserEmail is the calling user's email address when the auth path + // resolves a user record. Empty for non-OIDC schemes (PIN/Password/ + // Header) and for legacy session JWTs minted before the email claim + // was introduced. Identity-stamping middlewares (e.g. + // llm_identity_inject) prefer this over UserID for upstream gateways + // that key budgets / attribution on a human-readable identifier. + UserEmail string + AuthMethod string + SourceIP string + // UserGroups captures the calling peer's group memberships at + // request time, surfaced from the proxy's auth flow so policy-aware + // middlewares can authorise without an extra management round-trip. + UserGroups []string + // UserGroupNames carries the human-readable display names paired + // positionally with UserGroups (UserGroupNames[i] is the name of + // UserGroups[i]). Identity-stamping middlewares prefer names for + // upstream tags so attribution dashboards stay readable. Slice may + // be shorter than UserGroups for tokens minted before names were + // resolvable; consumers should fall back to ids for missing + // positions. + UserGroupNames []string + Metadata []KV + + // AgentNetwork is true when the target is a synthesised + // agent-network service. Carried on the input so the access-log + // terminal middleware can stamp the proto field without re-deriving + // from the service ID. + AgentNetwork bool +} + +// DenyReason is the structured payload a middleware returns alongside +// a DecisionDeny. The proxy renders it through a fixed JSON template +// so middlewares cannot emit arbitrary bytes to the wire. +type DenyReason struct { + Code string + Message string + Details map[string]string +} + +// Output is the value each middleware returns to the dispatcher. The +// dispatcher applies the output filter (clamp, mutations gate) before +// any side effect reaches the shared request. +type Output struct { + Decision Decision + DenyStatus int + DenyReason *DenyReason + Metadata []KV + Mutations *Mutations +} + +// Mutations describes the deltas a middleware wants applied to the +// in-flight request. The dispatcher filters HeadersAdd/HeadersRemove +// through the compiled-in denylist and runs BodyReplace through the +// body policy before anything is applied. RewriteUpstream redirects +// the outbound target (scheme + host) for the request; the chain +// returns the latest non-nil rewrite to the reverse proxy. +type Mutations struct { + HeadersAdd []KV + HeadersRemove []string + BodyReplace []byte + RewriteUpstream *UpstreamRewrite +} + +// UpstreamRewrite redirects the request's outbound target. Only +// scheme+host are honoured; path, query, and body are untouched. The +// reverse proxy reads the rewrite (when non-nil) instead of the +// PathTarget URL configured by the synth, so a single shared synth +// service can fan out to many upstreams selected per request. +// +// AuthHeader and StripHeaders carry the upstream auth substitution +// the router needs. They bypass the framework's HeadersAdd / +// HeadersRemove denylist (which blocks Authorization, Cookie, etc. +// from middleware mutation) on the grounds that the proxy itself is +// the entity rewriting auth here, not an arbitrary middleware. The +// reverse proxy applies them directly to the upstream request after +// the chain's regular mutation phase, so a malicious or misconfigured +// middleware can still emit RewriteUpstream but only the proxy's +// trusted upstream-build path actually unpacks AuthHeader. +type UpstreamRewrite struct { + Scheme string + Host string + // Path, when non-empty, replaces the path component of the + // proxy's effective upstream URL. The rewrite path is then joined + // with the agent's request path by httputil.ProxyRequest.SetURL — + // e.g. rewrite Path="/v1/{account}/{gateway}/compat" + agent + // request "/chat/completions" → outbound + // "/v1/{account}/{gateway}/compat/chat/completions". Used by + // llm_router to honor the operator-configured upstream path on + // gateways like Cloudflare AI Gateway whose URL contains + // account / gateway segments that the agent's app doesn't know + // about. Empty Path leaves the original target's path + // untouched (the historical behavior). + Path string + // StripPathPrefix, when non-empty, is removed from the front of the agent's + // request path before it is joined onto the upstream URL. Used for + // gateway-namespace prefixes (e.g. a client addressing Bedrock as + // "/bedrock/model/{id}/invoke") that must not reach the real upstream, whose + // native path is "/model/{id}/invoke". Empty leaves the request path intact. + StripPathPrefix string + AuthHeader *AuthHeader + StripHeaders []string +} + +// AuthHeader is a single name/value pair the proxy injects on the +// upstream request after stripping the client's auth headers. +type AuthHeader struct { + Name string + Value string +} diff --git a/proxy/internal/proxy/agent_network_chain_realstack_test.go b/proxy/internal/proxy/agent_network_chain_realstack_test.go new file mode 100644 index 000000000..bc611fc98 --- /dev/null +++ b/proxy/internal/proxy/agent_network_chain_realstack_test.go @@ -0,0 +1,321 @@ +package proxy_test + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "net/url" + "runtime" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/management/internals/modules/agentnetwork" + agentNetworkTypes "github.com/netbirdio/netbird/management/internals/modules/agentnetwork/types" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/bodytap" + mwbuiltin "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" + // Side-effect imports register every builtin middleware factory. + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/cost_meter" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_guardrail" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_identity_inject" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_limit_check" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_limit_record" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_request_parser" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_response_parser" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_router" + "github.com/netbirdio/netbird/proxy/internal/proxy" + nbproxytypes "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/shared/management/proto" + + log "github.com/sirupsen/logrus" +) + +// TestReverseProxy_AgentNetworkRequest_FullChain is the self-contained Go +// replacement for the bash 50 + 51 legs. It drives a real agent-network +// request through proxy.ReverseProxy.ServeHTTP with the actual middleware +// chain the synthesizer produces, against an in-process management gRPC and a +// httptest fake upstream — no tilt, no docker, no real LLM provider, no +// WireGuard tunnel. The test guarantees: +// +// 1. The reverse proxy's response-leg input construction copies UserGroups +// onto respInput so llm_limit_record sends a non-empty group_ids field +// on RecordLLMUsage. This is the exact bug class that motivated the +// reverseproxy.go fix — its regression would land the request OK but +// leave consumption at zero, defeating any group-targeted budget rule. +// 2. With settings.RedactPii=true the parsers ship redacted text on both +// llm.request_prompt_raw and llm.response_completion — proving the +// end-to-end wiring (synth → proto → spec → parser config) carries the +// toggle through to runtime emission. +// 3. The full chain (request + response + recorder) runs against a real +// management stack and the consumption row for the bound group dim +// increments. +// +// If any of those three guarantees regresses, this single test fails. +func TestReverseProxy_AgentNetworkRequest_FullChain(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("sqlite store not supported on Windows") + } + + const ( + testAccountID = "acct-fullchain-1" + testAdminUser = "user-admin-1" + adminGroupID = "grp-admins" + providerID = "prov-openai-test" + cluster = "test.proxy.local" + subdomain = "fullchain" + ) + testLogger := log.New() + testLogger.SetLevel(log.PanicLevel) // keep test output clean + + ctx := context.Background() + + // ---- 1. Fake upstream that returns OpenAI-shaped JSON with PII in the + // completion. The reverse proxy's chain will redact this when the synth + // stamps redact_pii=true on the response parser config. + completion := "Sample record: Alice Johnson alice.johnson@example.com SSN 123-45-6789 phone (202) 555-0147 also Bob 202/555/0108" + upstreamBody := []byte(`{"id":"x","model":"gpt-5.4","choices":[{"message":{"role":"assistant","content":"` + completion + `"}}],"usage":{"prompt_tokens":12,"completion_tokens":40,"total_tokens":52}}`) + var upstreamHits atomic.Int64 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamHits.Add(1) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(upstreamBody) + })) + t.Cleanup(upstream.Close) + upstreamHost := strings.TrimPrefix(upstream.URL, "http://") + + // ---- 2. In-process management gRPC server (bufconn) backed by a real + // sqlite store + real agentnetwork.Manager. The proxy's middlewares talk + // to this client. + st, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err, "real sqlite test store must come up") + t.Cleanup(cleanup) + + anMgr := agentnetwork.NewManager(st, nil, nil, nil) + server := &mgmtgrpc.ProxyServiceServer{} + server.SetAgentNetworkLimitsService(anMgr) + + lis := bufconn.Listen(1024 * 1024) + srv := grpc.NewServer() + proto.RegisterProxyServiceServer(srv, server) + go func() { _ = srv.Serve(lis) }() + t.Cleanup(srv.Stop) + + conn, err := grpc.NewClient("passthrough:///bufnet", + grpc.WithContextDialer(func(_ context.Context, _ string) (net.Conn, error) { return lis.Dial() }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = conn.Close() }) + mgmtClient := proto.NewProxyServiceClient(conn) + + // ---- 3. Seed account state: settings (redact + capture on), provider + // whose upstream URL points at our fake server, policy (catch-all-allow + // over the Admins group → window=0 path), and a generous budget rule + // targeting Admins so the curl succeeds and we can prove the counter + // increments on the response leg. + require.NoError(t, st.SaveAgentNetworkSettings(ctx, &agentNetworkTypes.Settings{ + AccountID: testAccountID, + Cluster: cluster, + Subdomain: subdomain, + EnablePromptCollection: true, + EnableLogCollection: true, + RedactPii: true, + })) + require.NoError(t, st.SaveAgentNetworkProvider(ctx, &agentNetworkTypes.Provider{ + ID: providerID, + AccountID: testAccountID, + ProviderID: "openai_api", + Name: "openai-fullchain-test", + UpstreamURL: upstream.URL, // router rewrites to this + APIKey: "sk-test", + Enabled: true, + Models: []agentNetworkTypes.ProviderModel{{ID: "gpt-5.4"}}, + SessionPrivateKey: "priv", + SessionPublicKey: "pub", + })) + require.NoError(t, st.SaveAgentNetworkPolicy(ctx, &agentNetworkTypes.Policy{ + ID: "ainpol-fullchain", + AccountID: testAccountID, + Name: "admins-openai", + Enabled: true, + SourceGroups: []string{adminGroupID}, + DestinationProviderIDs: []string{providerID}, + // No token / budget caps → effectiveWindowSeconds=0 → exercises the + // catch-all-allow path that the GC-2 record-on-window=0 fix targets. + })) + require.NoError(t, st.SaveAgentNetworkBudgetRule(ctx, &agentNetworkTypes.AccountBudgetRule{ + ID: "ainbud-admins-fullchain", + AccountID: testAccountID, + Name: "admins-monthly", + Enabled: true, + TargetGroups: []string{adminGroupID}, + Limits: agentNetworkTypes.PolicyLimits{ + TokenLimit: agentNetworkTypes.PolicyTokenLimit{Enabled: true, GroupCap: 1_000_000, UserCap: 1_000_000, WindowSeconds: 60}, + }, + })) + + // ---- 4. Synth the service. This produces the exact middleware chain + // configuration the production reconcile path ships to the proxy. + services, err := agentnetwork.SynthesizeServices(ctx, st, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1, "exactly one synth service expected") + synthSvc := services[0] + require.NotEmpty(t, synthSvc.Targets, "synth target must exist") + + // ---- 5. Wire the middleware framework — same registry the proxy uses + // in production, configured with our bufconn-backed management client. + mwbuiltin.Configure(ctx, t.TempDir(), nil, testLogger, mgmtClient) + registry := mwbuiltin.DefaultRegistry() + mwMetrics, err := middleware.NewMetrics(nil) + require.NoError(t, err) + mwMgr := middleware.NewManager(0, mwMetrics, testLogger) + mwMgr.SetResolver(middleware.NewResolver(registry)) + + // Convert the synth's rpservice.MiddlewareConfig list into proxy + // middleware.Spec values. Mirrors the proto→Spec translation server.go + // does at runtime; kept inline here so the test isn't coupled to the + // proxy server's private translateMiddlewareConfig helper. + specs := make([]middleware.Spec, 0, len(synthSvc.Targets[0].Options.Middlewares)) + for _, mw := range synthSvc.Targets[0].Options.Middlewares { + var slot middleware.Slot + switch mw.Slot { + case rpservice.MiddlewareSlotOnRequest: + slot = middleware.SlotOnRequest + case rpservice.MiddlewareSlotOnResponse: + slot = middleware.SlotOnResponse + case rpservice.MiddlewareSlotTerminal: + slot = middleware.SlotTerminal + default: + t.Fatalf("unknown middleware slot %q on %s", mw.Slot, mw.ID) + } + specs = append(specs, middleware.Spec{ + ID: mw.ID, + Slot: slot, + Enabled: mw.Enabled, + FailMode: middleware.FailOpen, + Timeout: middleware.DefaultTimeout, + RawConfig: append([]byte(nil), mw.ConfigJSON...), + CanMutate: mw.CanMutate, + }) + } + + serviceIDStr := synthSvc.ID + require.NoError(t, mwMgr.Rebuild(serviceIDStr, []middleware.PathTargetBinding{{ + ServiceID: serviceIDStr, + PathID: "/", + Specs: specs, + }})) + + // ---- 6. Build the reverse proxy, with a mapping whose target URL goes + // straight to the fake upstream (the router middleware rewriting upstream + // from the synth's noop placeholder isn't needed when we own the mapping + // in-process — point the target at the fake URL directly so the body + // arrives at the upstream the synth would have routed to). + upstreamURL, err := url.Parse(upstream.URL) + require.NoError(t, err) + + rp := proxy.NewReverseProxy(http.DefaultTransport, "auto", nil, testLogger, proxy.WithMiddlewareManager(mwMgr)) + rp.AddMapping(proxy.Mapping{ + ID: nbproxytypes.ServiceID(serviceIDStr), + AccountID: nbproxytypes.AccountID(testAccountID), + Host: synthSvc.Domain, + Paths: map[string]*proxy.PathTarget{ + "/": { + URL: upstreamURL, + DirectUpstream: true, + AgentNetwork: true, + Middlewares: specs, + CaptureConfig: &bodytap.Config{ + MaxRequestBytes: 1 << 20, + MaxResponseBytes: 1 << 20, + ContentTypes: []string{"application/json", "text/event-stream"}, + }, + }, + }, + }) + + // ---- 7. Send a request with the auth-stamped CapturedData (mimicking + // what the tunnel-peer auth middleware does at the edge of the proxy). + reqBody := `{"model":"gpt-5.4","client_metadata":{"session_id":"sess-fullchain-1"},"messages":[{"role":"user","content":"contact alice.johnson@example.com SSN 987-65-4321 phone (202)555-0156"}]}` + req := httptest.NewRequest("POST", "https://"+synthSvc.Domain+"/v1/chat/completions", strings.NewReader(reqBody)) + req.Host = synthSvc.Domain + req.Header.Set("Content-Type", "application/json") + + cd := proxy.NewCapturedData("test-request-1") + cd.SetServiceID(nbproxytypes.ServiceID(serviceIDStr)) + cd.SetAccountID(nbproxytypes.AccountID(testAccountID)) + cd.SetUserID(testAdminUser) + cd.SetUserGroups([]string{adminGroupID}) + cd.SetAuthMethod("tunnel_peer") + req = req.WithContext(proxy.WithCapturedData(req.Context(), cd)) + + w := httptest.NewRecorder() + rp.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, "upstream call must succeed end-to-end; body=%s", w.Body.String()) + assert.GreaterOrEqual(t, upstreamHits.Load(), int64(1), "fake upstream must have been hit") + + // ---- 8. Assertions — the three guarantees this test exists for. + + // 8a. The reverseproxy.go respInput construction carried UserGroups + // into the response-leg middleware chain, so llm_limit_record sent a + // non-empty group_ids on RecordLLMUsage. Verifying via the management + // store directly bypasses the manager's permission gate (which is nil + // in this test) — we want to confirm the row landed, not who saw it. + require.Eventually(t, func() bool { + rows, lerr := st.ListAgentNetworkConsumption(ctx, store.LockingStrengthNone, testAccountID) + if lerr != nil { + return false + } + for _, r := range rows { + if r.DimensionKind == agentNetworkTypes.DimensionGroup && + r.DimensionID == adminGroupID && + r.WindowSeconds == 60 && + r.TokensInput+r.TokensOutput > 0 { + return true + } + } + return false + }, 5*time.Second, 50*time.Millisecond, + "Admins group consumption row must increment via the response leg — if this fails the proxy's respInput dropped UserGroups again or the parser/recorder wiring is broken") + + // 8b. Both the captured prompt and the captured completion are + // redacted — proves the synth threads redact_pii=true into BOTH parser + // configs and the parsers honour it at emission time. + md := cd.GetMetadata() + promptRaw := md["llm.request_prompt_raw"] + completionMeta := md["llm.response_completion"] + + // 8a-bis. The session id from client_metadata.session_id flows through + // the request parser into the captured metadata, so the access-log / + // usage rows can group this request with the rest of its conversation. + assert.Equal(t, "sess-fullchain-1", md["llm.session_id"], + "session id must be extracted from client_metadata.session_id and carried through the chain") + + assert.NotEmpty(t, promptRaw, "llm.request_prompt_raw must be present in captured metadata") + assert.Contains(t, promptRaw, "[REDACTED:", "captured raw prompt must carry redaction markers") + assert.NotContains(t, promptRaw, "alice.johnson@example.com", "raw email must NOT survive in prompt_raw") + assert.NotContains(t, promptRaw, "987-65-4321", "raw SSN must NOT survive in prompt_raw") + assert.NotContains(t, promptRaw, "(202)555-0156", "raw paren-no-space phone must NOT survive in prompt_raw") + + assert.NotEmpty(t, completionMeta, "llm.response_completion must be present in captured metadata") + assert.Contains(t, completionMeta, "[REDACTED:", "captured completion must carry redaction markers") + assert.NotContains(t, completionMeta, "alice.johnson@example.com", "raw email must NOT survive in completion") + assert.NotContains(t, completionMeta, "123-45-6789", "raw SSN must NOT survive in completion") + assert.NotContains(t, completionMeta, "(202) 555-0147", "raw paren+space phone must NOT survive in completion") + assert.NotContains(t, completionMeta, "202/555/0108", "raw slash phone must NOT survive in completion") + + _ = upstreamHost // kept for future header-inspection assertions if needed +} diff --git a/proxy/internal/proxy/context.go b/proxy/internal/proxy/context.go index e05ec78aa..09bb9a8d1 100644 --- a/proxy/internal/proxy/context.go +++ b/proxy/internal/proxy/context.go @@ -58,9 +58,11 @@ type CapturedData struct { // the JWT's group_names claim or from ValidateSession/Tunnel // responses. Slice may be shorter than userGroups for tokens minted // before names were resolvable. - userGroupNames []string - authMethod string - metadata map[string]string + userGroupNames []string + authMethod string + metadata map[string]string + agentNetwork bool + suppressAccessLog bool } // NewCapturedData creates a CapturedData with the given request ID. @@ -178,6 +180,41 @@ func (c *CapturedData) SetUserGroups(groups []string) { c.userGroups = append(c.userGroups[:0], groups...) } +// SetAgentNetwork records whether the request hit a synthesised +// agent-network target. The terminal access-log middleware stamps the +// flag onto the proto so management can distinguish synthetic traffic. +func (c *CapturedData) SetAgentNetwork(b bool) { + c.mu.Lock() + defer c.mu.Unlock() + c.agentNetwork = b +} + +// GetAgentNetwork reports whether the request matched a synthesised +// agent-network target. +func (c *CapturedData) GetAgentNetwork() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.agentNetwork +} + +// SetSuppressAccessLog records whether the per-request access-log emission +// must be skipped for this request. Stamped from the matched target's +// DisableAccessLog flag so the access-log middleware can short-circuit +// log delivery for opted-out agent-network targets. +func (c *CapturedData) SetSuppressAccessLog(b bool) { + c.mu.Lock() + defer c.mu.Unlock() + c.suppressAccessLog = b +} + +// GetSuppressAccessLog reports whether access-log emission has been +// suppressed for this request. +func (c *CapturedData) GetSuppressAccessLog() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.suppressAccessLog +} + // GetUserGroups returns a copy of the authenticated user's group // memberships. func (c *CapturedData) GetUserGroups() []string { diff --git a/proxy/internal/proxy/reverseproxy.go b/proxy/internal/proxy/reverseproxy.go index da0bf6552..2c0304ecd 100644 --- a/proxy/internal/proxy/reverseproxy.go +++ b/proxy/internal/proxy/reverseproxy.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "encoding/json" "errors" "fmt" "net" @@ -11,10 +12,13 @@ import ( "net/url" "strings" "sync" + "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/bodytap" "github.com/netbirdio/netbird/proxy/internal/roundtrip" "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/web" @@ -32,6 +36,25 @@ type ReverseProxy struct { mappingsMux sync.RWMutex mappings map[string]Mapping logger *log.Logger + // middlewareManager, when non-nil, drives per-target middleware + // dispatch. A nil manager (or an empty chain for the resolved + // target) keeps the reverse-proxy hot path on the no-capture fast + // path with no middleware overhead. + middlewareManager *middleware.Manager +} + +// Option configures optional ReverseProxy behavior. Options exist so the core +// constructor signature stays stable across additive features. +type Option func(*ReverseProxy) + +// WithMiddlewareManager attaches a middleware manager to the reverse +// proxy. When the manager is nil or returns an empty chain for the +// target, the request follows the fast path with no middleware +// overhead. +func WithMiddlewareManager(m *middleware.Manager) Option { + return func(p *ReverseProxy) { + p.middlewareManager = m + } } // NewReverseProxy configures a new NetBird ReverseProxy. @@ -40,29 +63,28 @@ type ReverseProxy struct { // between requested URLs and targets. // The internal mappings can be modified using the AddMapping // and RemoveMapping functions. -func NewReverseProxy(transport http.RoundTripper, forwardedProto string, trustedProxies []netip.Prefix, logger *log.Logger) *ReverseProxy { +func NewReverseProxy(transport http.RoundTripper, forwardedProto string, trustedProxies []netip.Prefix, logger *log.Logger, opts ...Option) *ReverseProxy { if logger == nil { logger = log.StandardLogger() } - return &ReverseProxy{ + p := &ReverseProxy{ transport: transport, forwardedProto: forwardedProto, trustedProxies: trustedProxies, mappings: make(map[string]Mapping), logger: logger, } + for _, opt := range opts { + opt(p) + } + return p } func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { result, exists := p.findTargetForRequest(r) if !exists { - if cd := CapturedDataFromContext(r.Context()); cd != nil { - cd.SetOrigin(OriginNoRoute) - } - requestID := getRequestID(r) - web.ServeErrorPage(w, r, http.StatusNotFound, "Service Not Found", - "The requested service could not be found. Please check the URL, try refreshing, or check if the peer is running. If that doesn't work, see our documentation for help.", - requestID, web.ErrorStatus{Proxy: true, Destination: false}) + p.serveRouteError(w, r, http.StatusNotFound, "Service Not Found", + "The requested service could not be found. Please check the URL, try refreshing, or check if the peer is running. If that doesn't work, see our documentation for help.") return } @@ -72,38 +94,23 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // with 421 (Misdirected Request) so the caller sees an explicit // error instead of silently doubling tunnel traffic. if p.isSelfTargetLoop(r, result.target.URL) { - if cd := CapturedDataFromContext(r.Context()); cd != nil { - cd.SetOrigin(OriginNoRoute) - } - requestID := getRequestID(r) - web.ServeErrorPage(w, r, http.StatusMisdirectedRequest, "Loop Detected", - "This peer is the target of the requested service. Reach the backend directly instead of dialing the public service URL from the same machine.", - requestID, web.ErrorStatus{Proxy: true, Destination: false}) + p.serveRouteError(w, r, http.StatusMisdirectedRequest, "Loop Detected", + "This peer is the target of the requested service. Reach the backend directly instead of dialing the public service URL from the same machine.") return } - ctx := r.Context() - // Set the account ID in the context for the roundtripper to use. - ctx = roundtrip.WithAccountID(ctx, result.accountID) + pt := result.target + ctx := p.buildTargetContext(r.Context(), result) // Populate captured data if it exists (allows middleware to read after handler completes). // This solves the problem of passing data UP the middleware chain: we put a mutable struct // pointer in the context, and mutate the struct here so outer middleware can read it. - if capturedData := CapturedDataFromContext(ctx); capturedData != nil { + capturedData := CapturedDataFromContext(ctx) + if capturedData != nil { capturedData.SetServiceID(result.serviceID) capturedData.SetAccountID(result.accountID) - } - - pt := result.target - - if pt.SkipTLSVerify { - ctx = roundtrip.WithSkipTLSVerify(ctx) - } - if pt.RequestTimeout > 0 { - ctx = types.WithDialTimeout(ctx, pt.RequestTimeout) - } - if pt.DirectUpstream { - ctx = roundtrip.WithDirectUpstream(ctx) + capturedData.SetAgentNetwork(result.target != nil && result.target.AgentNetwork) + capturedData.SetSuppressAccessLog(result.target != nil && result.target.DisableAccessLog) } rewriteMatchedPath := result.matchedPath @@ -111,6 +118,45 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { rewriteMatchedPath = "" } + chain := p.resolveChain(result) + if chain == nil || chain.Empty() { + p.serveDirect(w, r, ctx, result, rewriteMatchedPath) + return + } + p.serveWithChain(w, r, ctx, result, chain, rewriteMatchedPath, capturedData) +} + +// serveRouteError marks the request as un-routed on any captured-data +// context and renders the proxy error page. +func (p *ReverseProxy) serveRouteError(w http.ResponseWriter, r *http.Request, status int, title, message string) { + if cd := CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(OriginNoRoute) + } + web.ServeErrorPage(w, r, status, title, message, getRequestID(r), + web.ErrorStatus{Proxy: true, Destination: false}) +} + +// buildTargetContext layers the per-target roundtrip flags (account id, +// TLS-verify skip, direct upstream, dial timeout) onto the request context. +func (p *ReverseProxy) buildTargetContext(ctx context.Context, result targetResult) context.Context { + pt := result.target + ctx = roundtrip.WithAccountID(ctx, result.accountID) + if pt.SkipTLSVerify { + ctx = roundtrip.WithSkipTLSVerify(ctx) + } + if pt.DirectUpstream { + ctx = roundtrip.WithDirectUpstream(ctx) + } + if pt.RequestTimeout > 0 { + ctx = types.WithDialTimeout(ctx, pt.RequestTimeout) + } + return ctx +} + +// serveDirect forwards the request without a middleware chain — the common +// path for plain reverse-proxy targets. +func (p *ReverseProxy) serveDirect(w http.ResponseWriter, r *http.Request, ctx context.Context, result targetResult, rewriteMatchedPath string) { + pt := result.target rp := &httputil.ReverseProxy{ Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders, result.stripAuthHeaders), Transport: p.transport, @@ -123,6 +169,344 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { rp.ServeHTTP(w, r.WithContext(ctx)) } +// serveWithChain runs the per-target middleware chain around the upstream +// request: request-leg capture and authorisation, then (on allow) the +// upstream forward with response/terminal observation deferred so it reads +// the captured response before the writer is released. +func (p *ReverseProxy) serveWithChain(w http.ResponseWriter, r *http.Request, ctx context.Context, result targetResult, chain *middleware.Chain, rewriteMatchedPath string, capturedData *CapturedData) { + middlewareIDs := chain.IDs() + p.logger.Debugf("middleware chain matched: service=%s path=%s middlewares=%v", result.serviceID, result.matchedPath, middlewareIDs) + + capturedBody, truncated, originalSize, releaseBudget := p.captureRequestForChain(ctx, r, result, capturedData) + defer releaseBudget() + + acc := middleware.NewAccumulator(middleware.MaxRequestMetadataBytes) + reqInput := buildRequestInput(r, result, capturedData, capturedBody, truncated, originalSize) + + denyOutput, requestMeta, upstreamRewrite, _ := chain.RunRequest(ctx, r, reqInput, acc) + if capturedData != nil { + for _, kv := range requestMeta { + capturedData.SetMetadata(kv.Key, kv.Value) + } + } + if denyOutput != nil { + p.serveDeny(w, denyOutput, result, middlewareIDs) + return + } + + respWriter, capturingWriter := p.newResponseWriter(ctx, w, result, capturedData) + if capturingWriter != nil { + defer capturingWriter.Release() + defer p.observeResponse(ctx, chain, acc, reqInput, requestMeta, capturingWriter, w, capturedData, result, middlewareIDs) + } + + p.forwardUpstream(respWriter, r, ctx, result, rewriteMatchedPath, upstreamRewrite) +} + +// captureRequestForChain copies the request body for inspection by the +// chain, records any capture bypass, and applies agent-network routing +// recovery for oversized bodies. The returned release frees the capture +// budget and must be deferred by the caller. +func (p *ReverseProxy) captureRequestForChain(ctx context.Context, r *http.Request, result targetResult, capturedData *CapturedData) ([]byte, bool, int64, func()) { + pt := result.target + capturedBody, truncated, originalSize, bypass, releaseBudget, captureErr := bodytap.CaptureRequest(r, pt.CaptureConfig, p.middlewareManager.Budget()) + if captureErr != nil { + p.logger.Debugf("middleware request body capture error: %v", captureErr) + } + if bypass != "" { + if capturedData != nil { + capturedData.SetMetadata("mw.capture.bypass_reason", bypass) + } + p.middlewareManager.Metrics().IncCaptureBypass(ctx, string(result.serviceID), bypass) + } + + // Routing recovery for oversized agent-network requests: when the body + // exceeded the capture cap (bypassed or truncated), the captured copy + // can't be parsed for the model, so llm_router would deny with + // model_not_routable. Scan the full stream for just the routing fields + // and hand the request parser a minimal stub so routing succeeds; the + // prompt stays uncaptured and the upstream still gets the full body. + if pt.AgentNetwork && (truncated || capturedBody == nil) { + if model, stream, ok := bodytap.ScanRoutingFields(r, bodytap.MaxRoutingScanBytes); ok { + capturedBody = buildRoutingStub(model, stream) + truncated = false + p.logger.Debugf("agent-network routing recovery: extracted model=%s stream=%t from oversized request body (service=%s)", model, stream, result.serviceID) + } + } + return capturedBody, truncated, originalSize, releaseBudget +} + +// serveDeny renders the chain's deny response. Policy/budget/routing/guardrail +// denials are expected runtime outcomes and can be high-volume under +// misconfigured or hostile clients; per-request detail stays at Debug and +// metrics/access logs carry the signal at scale. +func (p *ReverseProxy) serveDeny(w http.ResponseWriter, denyOutput *middleware.Output, result targetResult, middlewareIDs []string) { + middlewareID := "middleware" + if denyOutput.DenyReason != nil && denyOutput.DenyReason.Code != "" { + middlewareID = denyOutput.DenyReason.Code + } + p.logger.Debugf("middleware chain denied request: service=%s path=%s middlewares=%v reason=%s status=%d", + result.serviceID, result.matchedPath, middlewareIDs, middlewareID, denyOutput.DenyStatus) + middleware.RenderDenyResponse(w, middlewareID, denyOutput.DenyReason, denyOutput.DenyStatus) +} + +// newResponseWriter returns the writer the upstream forward should use. When +// response capture is enabled and not bypassed it wraps w in a capturing +// writer (also returned so the caller can release it and feed the response +// leg); otherwise the capturing writer is nil and w is used directly. +func (p *ReverseProxy) newResponseWriter(ctx context.Context, w http.ResponseWriter, result targetResult, capturedData *CapturedData) (http.ResponseWriter, *bodytap.CapturingResponseWriter) { + pt := result.target + if pt.CaptureConfig == nil || pt.CaptureConfig.MaxResponseBytes <= 0 { + return w, nil + } + capturingWriter := bodytap.NewCapturingResponseWriter(w, pt.CaptureConfig.MaxResponseBytes, p.middlewareManager.Budget()) + if capturingWriter.Bypassed() { + if capturedData != nil { + capturedData.SetMetadata("mw.capture.bypass_reason", capturingWriter.BypassReason()) + } + p.middlewareManager.Metrics().IncCaptureBypass(ctx, string(result.serviceID), capturingWriter.BypassReason()) + capturingWriter.Release() + return w, nil + } + return capturingWriter, capturingWriter +} + +// observeResponse runs the response and terminal middleware slots after the +// body has been forwarded. It is deferred by serveWithChain so it reads the +// captured response before the writer is released. +func (p *ReverseProxy) observeResponse(ctx context.Context, chain *middleware.Chain, acc *middleware.Accumulator, reqInput *middleware.Input, requestMeta []middleware.KV, capturingWriter *bodytap.CapturingResponseWriter, w http.ResponseWriter, capturedData *CapturedData, result targetResult, middlewareIDs []string) { + respInput := &middleware.Input{ + Slot: middleware.SlotOnResponse, + RequestID: reqInput.RequestID, + TargetID: reqInput.TargetID, + Method: reqInput.Method, + URL: reqInput.URL, + Headers: reqInput.Headers, + Status: capturingWriter.Status(), + RespHeaders: headerToKV(w.Header()), + RespBody: capturingWriter.Body(), + RespBodyTruncated: capturingWriter.Truncated(), + OriginalRespSize: capturingWriter.BytesWritten(), + ServiceID: reqInput.ServiceID, + AccountID: reqInput.AccountID, + UserID: reqInput.UserID, + // UserEmail / UserGroups / UserGroupNames must flow into the + // response leg too — llm_limit_record needs UserGroups to send + // group_ids on RecordLLMUsage so management's account-budget + // fan-out can match group-targeted rules; identity-stamping and + // any future response-side authorisation also depend on these. + UserEmail: reqInput.UserEmail, + UserGroups: reqInput.UserGroups, + UserGroupNames: reqInput.UserGroupNames, + AuthMethod: reqInput.AuthMethod, + SourceIP: reqInput.SourceIP, + Metadata: requestMeta, + AgentNetwork: reqInput.AgentNetwork, + } + // The response/terminal phase runs after the body is forwarded, so + // a streaming client (e.g. Codex) has usually disconnected by now, + // cancelling r.Context(). These middlewares only observe and record + // (token/cost metering, usage recording) and must still complete — + // otherwise the dispatcher short-circuits each to fail-mode and the + // usage is silently lost. Detach from client cancellation, keep ctx + // values, and bound the work. + obsCtx, obsCancel := context.WithTimeout(context.WithoutCancel(ctx), observabilityPhaseTimeout) + defer obsCancel() + + respMeta := chain.RunResponse(obsCtx, respInput, acc) + if capturedData != nil { + for _, kv := range respMeta { + capturedData.SetMetadata(kv.Key, kv.Value) + } + } + + // Terminal slot sees the merged metadata bag from request and + // response phases. + mergedMeta := append(append([]middleware.KV(nil), requestMeta...), respMeta...) + termInput := *respInput + termInput.Slot = middleware.SlotTerminal + termInput.Metadata = mergedMeta + termMeta := chain.RunTerminal(obsCtx, &termInput, acc) + if capturedData != nil { + for _, kv := range termMeta { + capturedData.SetMetadata(kv.Key, kv.Value) + } + } + + p.logger.Debugf("middleware chain ran: service=%s path=%s middlewares=%v status=%d req_meta=%d resp_meta=%d term_meta=%d", + result.serviceID, result.matchedPath, middlewareIDs, capturingWriter.Status(), len(requestMeta), len(respMeta), len(termMeta)) +} + +// forwardUpstream applies any middleware-emitted upstream rewrite and proxies +// the request to the effective upstream URL. +func (p *ReverseProxy) forwardUpstream(respWriter http.ResponseWriter, r *http.Request, ctx context.Context, result targetResult, rewriteMatchedPath string, upstreamRewrite *middleware.UpstreamRewrite) { + pt := result.target + effectiveURL := applyUpstreamRewrite(pt.URL, upstreamRewrite) + if upstreamRewrite != nil { + r.Host = effectiveURL.Host + applyUpstreamHeaders(r, upstreamRewrite) + stripUpstreamPathPrefix(r, upstreamRewrite.StripPathPrefix) + } + + rp := &httputil.ReverseProxy{ + Rewrite: p.rewriteFunc(effectiveURL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders, result.stripAuthHeaders), + Transport: p.transport, + FlushInterval: -1, + ErrorHandler: p.proxyErrorHandler, + } + if result.rewriteRedirects { + rp.ModifyResponse = p.rewriteLocationFunc(effectiveURL, rewriteMatchedPath, r) //nolint:bodyclose + } + rp.ServeHTTP(respWriter, r.WithContext(ctx)) +} + +// buildRoutingStub returns a minimal JSON request body carrying only the +// model and stream fields. It feeds the LLM request parser when the real +// body was too large to capture: the parser emits llm.model / llm.stream +// so llm_router can route, while ExtractPrompt on the stub yields nothing +// — no prompt is captured for oversized requests. +func buildRoutingStub(model string, stream bool) []byte { + b, err := json.Marshal(map[string]any{"model": model, "stream": stream}) + if err != nil { + return nil + } + return b +} + +// applyUpstreamRewrite returns the effective upstream URL after +// applying a middleware-emitted rewrite. When rewrite is nil or +// incomplete, the original target is returned unchanged. The original +// URL is never mutated; a clone is returned when a rewrite applies. +// +// Rewrite Path semantics: when non-empty, replaces the cloned URL's +// path entirely. httputil.ProxyRequest.SetURL then joins target.Path +// with the agent's request path, so an operator-configured upstream +// path like "/v1/{account}/{gateway}/compat" gets prepended to +// "/chat/completions" yielding the full Cloudflare-shaped path. +// Empty rewrite.Path preserves the original target's path (the +// historical, non-agent-network behavior). +func applyUpstreamRewrite(orig *url.URL, rewrite *middleware.UpstreamRewrite) *url.URL { + if rewrite == nil || orig == nil { + return orig + } + if rewrite.Scheme == "" || rewrite.Host == "" { + return orig + } + cloned := *orig + cloned.Scheme = rewrite.Scheme + cloned.Host = rewrite.Host + if rewrite.Path != "" { + cloned.Path = rewrite.Path + cloned.RawPath = "" + } + return &cloned +} + +// stripUpstreamPathPrefix removes a gateway-namespace prefix (e.g. "/bedrock") +// from the request path before it is forwarded, so the upstream receives its +// native path. The chain has already run by this point, so metering/logging +// keep the original client path; only the outbound path is rewritten. RawPath +// is cleared so the escaped form is recomputed from the trimmed Path. +func stripUpstreamPathPrefix(r *http.Request, prefix string) { + if r == nil || r.URL == nil || prefix == "" { + return + } + if !strings.HasPrefix(r.URL.Path, prefix+"/") && r.URL.Path != prefix { + return + } + r.URL.Path = strings.TrimPrefix(r.URL.Path, prefix) + if r.URL.Path == "" { + r.URL.Path = "/" + } + r.URL.RawPath = "" +} + +// applyUpstreamHeaders strips the headers the rewrite asks for and +// injects the resolved auth header on the in-flight request. It is +// the proxy-trusted counterpart to chain.applyMutations: regular +// middleware HeadersAdd/HeadersRemove pass through the framework +// denylist (which blocks Authorization, Cookie, etc.), but the +// router middleware needs to replace Authorization on the upstream +// request as a first-class operation. AuthHeader/StripHeaders ride +// on UpstreamRewrite so only the proxy's upstream-build path +// unpacks them — middlewares can't smuggle these in via the +// regular mutation surface. +func applyUpstreamHeaders(r *http.Request, rewrite *middleware.UpstreamRewrite) { + if r == nil || rewrite == nil { + return + } + for _, name := range rewrite.StripHeaders { + if name == "" { + continue + } + r.Header.Del(name) + } + if rewrite.AuthHeader != nil && rewrite.AuthHeader.Name != "" { + r.Header.Set(rewrite.AuthHeader.Name, rewrite.AuthHeader.Value) + } +} + +// resolveChain returns the middleware chain registered for the +// resolved target, or nil when middleware is disabled for the proxy +// or the target. +func (p *ReverseProxy) resolveChain(result targetResult) *middleware.Chain { + if p.middlewareManager == nil { + return nil + } + return p.middlewareManager.ChainFor(string(result.serviceID), result.matchedPath) +} + +// buildRequestInput gathers the per-request fields the middleware +// chain needs. Body and captured metadata are passed in; the rest are +// copied from the request and CapturedData. +func buildRequestInput(r *http.Request, result targetResult, cd *CapturedData, body []byte, truncated bool, originalSize int64) *middleware.Input { + in := &middleware.Input{ + Slot: middleware.SlotOnRequest, + TargetID: result.matchedPath, + Method: r.Method, + URL: r.URL.String(), + Headers: headerToKV(r.Header), + Body: body, + BodyTruncated: truncated, + OriginalBodySize: originalSize, + ServiceID: string(result.serviceID), + AccountID: string(result.accountID), + AgentNetwork: result.target != nil && result.target.AgentNetwork, + } + if cd != nil { + in.RequestID = cd.GetRequestID() + in.UserID = cd.GetUserID() + in.UserEmail = cd.GetUserEmail() + in.UserGroups = cd.GetUserGroups() + in.UserGroupNames = cd.GetUserGroupNames() + in.AuthMethod = cd.GetAuthMethod() + if ip := cd.GetClientIP(); ip.IsValid() { + in.SourceIP = ip.String() + } + } + return in +} + +// headerToKV flattens an http.Header into the KV slice shape expected +// by the middleware envelope, preserving value order under the same +// key. +func headerToKV(h http.Header) []middleware.KV { + if len(h) == 0 { + return nil + } + total := 0 + for _, v := range h { + total += len(v) + } + out := make([]middleware.KV, 0, total) + for k, vs := range h { + for _, v := range vs { + out = append(out, middleware.KV{Key: k, Value: v}) + } + } + return out +} + // isSelfTargetLoop reports whether an overlay-origin request is about to // be forwarded back to the very peer that initiated it. The detection // is intentionally narrow: it only fires when the request arrived on @@ -486,6 +870,14 @@ const ( // comma or any non-printable byte are dropped at stamp time so the // list is unambiguously splittable by consumers. headerNetBirdGroups = "X-NetBird-Groups" + + // observabilityPhaseTimeout bounds the detached response/terminal + // metering phase. It runs after the client connection (and its context) + // may be gone, so it can't borrow the request deadline; this ceiling + // keeps a slow management round-trip (RecordLLMUsage) from pinning the + // handler goroutine indefinitely while still allowing each middleware + // its own per-invoke timeout. + observabilityPhaseTimeout = 30 * time.Second ) // isHeaderValueSafe reports whether v is a valid RFC 7230 field-value: diff --git a/proxy/internal/proxy/reverseproxy_test.go b/proxy/internal/proxy/reverseproxy_test.go index a8244fa56..9bd427056 100644 --- a/proxy/internal/proxy/reverseproxy_test.go +++ b/proxy/internal/proxy/reverseproxy_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/middleware" "github.com/netbirdio/netbird/proxy/internal/roundtrip" "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/web" @@ -1407,3 +1408,45 @@ func TestStampNetBirdIdentity_CapturedDataPresentButEmpty(t *testing.T) { assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups), "X-NetBird-Groups must be stripped when CapturedData has no groups") } + +// TestBuildRequestInput_PropagatesIdentityAndGroups locks the final wiring link +// between auth and the middleware chain: CapturedData identity (user, groups, +// auth method, client IP) and the target's AgentNetwork flag must land on the +// middleware Input the chain runs against. If UserGroups stops flowing here, +// llm_router denies every request with no_authorised_provider. +func TestBuildRequestInput_PropagatesIdentityAndGroups(t *testing.T) { + cd := NewCapturedData("req-123") + cd.SetUserID("user-1") + cd.SetUserEmail("user@example.com") + cd.SetUserGroups([]string{"grp-admins", "grp-users"}) + cd.SetUserGroupNames([]string{"Admins", "Users"}) + cd.SetAuthMethod("oidc") + cd.SetClientIP(netip.MustParseAddr("100.90.1.14")) + + r := httptest.NewRequest(http.MethodPost, "http://agent.example.com/v1/chat/completions", nil) + r.Header.Set("Content-Type", "application/json") + + result := targetResult{ + target: &PathTarget{AgentNetwork: true}, + matchedPath: "/", + serviceID: types.ServiceID("svc-1"), + accountID: types.AccountID("acct-1"), + } + + body := []byte(`{"model":"gpt-5.4"}`) + in := buildRequestInput(r, result, cd, body, false, int64(len(body))) + + require.NotNil(t, in, "buildRequestInput must return an envelope") + assert.Equal(t, middleware.SlotOnRequest, in.Slot, "request input runs in the on-request slot") + assert.Equal(t, "svc-1", in.ServiceID, "service id must propagate") + assert.Equal(t, "acct-1", in.AccountID, "account id must propagate") + assert.Equal(t, "user-1", in.UserID, "user id must propagate") + assert.Equal(t, "user@example.com", in.UserEmail, "user email must propagate") + assert.Equal(t, []string{"grp-admins", "grp-users"}, in.UserGroups, + "CapturedData groups MUST reach the middleware Input — llm_router authorises against this") + assert.Equal(t, []string{"Admins", "Users"}, in.UserGroupNames, "group names must propagate") + assert.Equal(t, "oidc", in.AuthMethod, "auth method must propagate") + assert.Equal(t, "100.90.1.14", in.SourceIP, "client IP must propagate") + assert.True(t, in.AgentNetwork, "agent-network target flag must reach the Input") + assert.Equal(t, body, in.Body, "captured body must reach the Input") +} diff --git a/proxy/internal/proxy/servicemapping.go b/proxy/internal/proxy/servicemapping.go index 46b4d2e8d..64fccc42a 100644 --- a/proxy/internal/proxy/servicemapping.go +++ b/proxy/internal/proxy/servicemapping.go @@ -8,6 +8,8 @@ import ( "strings" "time" + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/bodytap" "github.com/netbirdio/netbird/proxy/internal/types" ) @@ -32,6 +34,20 @@ type PathTarget struct { // over the embedded NetBird WireGuard client when forwarding requests // to this target. Default false → embedded client (existing behaviour). DirectUpstream bool + // Middlewares is the validated per-target middleware chain. Nil or empty + // for non-agent-network targets, keeping them on the no-middleware fast path. + Middlewares []middleware.Spec + // CaptureConfig holds the per-target body-capture limits used by the + // middleware chain. Nil for targets without body-inspecting middlewares. + CaptureConfig *bodytap.Config + // AgentNetwork marks this target as a synthesised agent-network target so + // the proxy can tag access-log entries and gate agent-network behaviour. + AgentNetwork bool + // DisableAccessLog suppresses the per-request access-log emission for this + // target. Defaults false so non-agent-network targets continue to log + // unchanged. The agent-network synthesizer sets this true only when the + // account's EnableLogCollection toggle is off. + DisableAccessLog bool } // Mapping describes how a domain is routed by the HTTP reverse proxy. diff --git a/proxy/internal/proxy/strip_prefix_test.go b/proxy/internal/proxy/strip_prefix_test.go new file mode 100644 index 000000000..4ff364f6a --- /dev/null +++ b/proxy/internal/proxy/strip_prefix_test.go @@ -0,0 +1,30 @@ +package proxy + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStripUpstreamPathPrefix(t *testing.T) { + cases := []struct { + name string + path string + prefix string + want string + }{ + {"strips matching namespace prefix", "/bedrock/model/x/invoke", "/bedrock", "/model/x/invoke"}, + {"no-op when prefix absent", "/model/x/invoke", "/bedrock", "/model/x/invoke"}, + {"no-op on empty prefix", "/bedrock/model/x/invoke", "", "/bedrock/model/x/invoke"}, + {"no-op on non-segment match", "/bedrockfoo/model/x", "/bedrock", "/bedrockfoo/model/x"}, + {"bare prefix collapses to root", "/bedrock", "/bedrock", "/"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r := httptest.NewRequest("POST", tc.path, nil) + stripUpstreamPathPrefix(r, tc.prefix) + assert.Equal(t, tc.want, r.URL.Path, "stripped path for %q", tc.path) + }) + } +} diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index 13d386da2..cb2e7f930 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -8,6 +8,8 @@ import ( "net" "net/http" "net/netip" + "os" + "strings" "sync" "time" @@ -347,8 +349,20 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account "public_key": publicKey.String(), }).Info("proxy peer authenticated successfully with management") + // Embedded client log level: warn by default (quiet in production); set + // NB_PROXY_CLIENT_LOG_LEVEL (e.g. "trace") to surface the embedded NetBird + // client's relay / signal / handshake detail for local debugging. + clientLogLevel := log.WarnLevel.String() + if v := strings.TrimSpace(os.Getenv("NB_PROXY_CLIENT_LOG_LEVEL")); v != "" { + if lvl, err := log.ParseLevel(v); err == nil { + clientLogLevel = lvl.String() + } else { + n.logger.Warnf("invalid NB_PROXY_CLIENT_LOG_LEVEL %q, using %q: %v", v, clientLogLevel, err) + } + } + n.initLogOnce.Do(func() { - if err := util.InitLog(log.WarnLevel.String(), util.LogConsole); err != nil { + if err := util.InitLog(clientLogLevel, util.LogConsole); err != nil { n.logger.WithField("account_id", accountID).Warnf("failed to initialize embedded client logging: %v", err) } }) @@ -356,11 +370,11 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account // Create embedded NetBird client with the generated private key. // The peer has already been created via CreateProxyPeer RPC with the public key. wgPort := int(n.clientCfg.WGPort) - client, err := embed.New(embed.Options{ + embedOpts := embed.Options{ DeviceName: deviceNamePrefix + n.proxyID, ManagementURL: n.clientCfg.MgmtAddr, PrivateKey: privateKey.String(), - LogLevel: log.WarnLevel.String(), + LogLevel: clientLogLevel, BlockInbound: n.clientCfg.BlockInbound, // The embedded proxy peer must never be a stepping stone into // the proxy host's LAN: it only exists to reach NetBird mesh @@ -371,7 +385,9 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account WireguardPort: &wgPort, PreSharedKey: n.clientCfg.PreSharedKey, Performance: n.clientCfg.Performance, - }) + } + logEmbedOptions(n.logger, accountID, serviceID, publicKey.String(), embedOpts) + client, err := embed.New(embedOpts) if err != nil { return nil, fmt.Errorf("create netbird client: %w", err) } @@ -847,3 +863,53 @@ func DirectUpstreamFromContext(ctx context.Context) bool { v, _ := ctx.Value(directUpstreamContextKey{}).(bool) return v } + +// logEmbedOptions emits a single structured INFO line summarising every +// operationally meaningful flag handed to embed.New for this per-account +// client. Secrets (PrivateKey, PreSharedKey) are reduced to a "present" +// boolean — never logged verbatim. Use this when an embedded peer +// silently misbehaves: most failure modes (inbound drops, wrong +// management URL, v6 unexpectedly on, userspace flipped, port clash) +// are obvious from these flags before any traffic flows. +func logEmbedOptions(logger *log.Logger, accountID types.AccountID, serviceID types.ServiceID, publicKey string, opts embed.Options) { + wgPort := 0 + if opts.WireguardPort != nil { + wgPort = *opts.WireguardPort + } + mtu := uint16(0) + if opts.MTU != nil { + mtu = *opts.MTU + } + perfBuffers := uint32(0) + if opts.Performance.PreallocatedBuffersPerPool != nil { + perfBuffers = *opts.Performance.PreallocatedBuffersPerPool + } + perfBatch := uint32(0) + if opts.Performance.MaxBatchSize != nil { + perfBatch = *opts.Performance.MaxBatchSize + } + logger.WithFields(log.Fields{ + "account_id": accountID, + "service_id": serviceID, + "public_key": publicKey, + "device_name": opts.DeviceName, + "management_url": opts.ManagementURL, + "log_level": opts.LogLevel, + "wg_port": wgPort, + "mtu": mtu, + "block_inbound": opts.BlockInbound, + "block_lan_access": opts.BlockLANAccess, + "disable_ipv6": opts.DisableIPv6, + "disable_client_routes": opts.DisableClientRoutes, + "no_userspace": opts.NoUserspace, + "config_path_set": opts.ConfigPath != "", + "state_path_set": opts.StatePath != "", + "private_key_present": opts.PrivateKey != "", + "presharedkey_present": opts.PreSharedKey != "", + "setup_key_present": opts.SetupKey != "", + "jwt_token_present": opts.JWTToken != "", + "dns_labels": opts.DNSLabels, + "perf_buffers_per_pool": perfBuffers, + "perf_max_batch_size": perfBatch, + }).Info("starting embedded netbird client for account") +} diff --git a/proxy/internal/tcp/accept.go b/proxy/internal/tcp/accept.go new file mode 100644 index 000000000..a63560a9e --- /dev/null +++ b/proxy/internal/tcp/accept.go @@ -0,0 +1,85 @@ +package tcp + +import ( + "context" + "errors" + "net" + "strings" + "time" +) + +// gvisorInvalidEndpointMsg is the canonical text gVisor netstack returns +// when Accept() is called on a listener whose underlying endpoint has +// been destroyed (peer rekey, embedded-client reset, account churn). +// There is no exported sentinel from gvisor.dev/gvisor/pkg/tcpip that +// survives gonet's *net.OpError wrapping in a way errors.Is can match, +// so we fall back to a string check. Stable across the gVisor versions +// netbird pins. +const gvisorInvalidEndpointMsg = "endpoint is in invalid state" + +// IsClosedListenerErr reports whether err signals that an accept loop +// should exit because the underlying listener can no longer serve +// connections. It recognises: +// +// - net.ErrClosed for stdlib listeners (Listener.Close was called). +// - gVisor's "endpoint is in invalid state" for netstack-backed +// listeners whose endpoint was destroyed out from under them +// (typically when a per-account WireGuard netstack is reset without +// also tearing the listener entry down). +// +// Without the gVisor branch an accept loop on a netstack listener spins +// CPU-hot forever after the endpoint dies, because Accept never blocks +// again and the error neither matches net.ErrClosed nor cancels ctx. +func IsClosedListenerErr(err error) bool { + if err == nil { + return false + } + if errors.Is(err, net.ErrClosed) { + return true + } + return strings.Contains(err.Error(), gvisorInvalidEndpointMsg) +} + +// AcceptBackoff implements the exponential backoff used by +// net/http.Server.Serve for transient Accept errors. Without it a loop +// hitting a sticky unknown error burns a full CPU core. The zero value +// is ready to use; call Reset after a successful Accept. +type AcceptBackoff struct { + delay time.Duration +} + +// minAcceptDelay / maxAcceptDelay mirror the stdlib defaults +// (net/http.Server.Serve) and keep us well below 1 log line per second +// per orphaned listener. +const ( + minAcceptDelay = 5 * time.Millisecond + maxAcceptDelay = time.Second +) + +// Backoff waits the next exponential delay (5ms doubling up to 1s) and +// returns true when the wait completed. Returns false if ctx fired +// during the wait — callers should treat that as "exit the loop". +func (b *AcceptBackoff) Backoff(ctx context.Context) bool { + b.advance() + select { + case <-ctx.Done(): + return false + case <-time.After(b.delay): + return true + } +} + +// Reset clears the accumulated delay so the next failure starts at the +// minimum delay again. Call after a successful Accept. +func (b *AcceptBackoff) Reset() { b.delay = 0 } + +func (b *AcceptBackoff) advance() { + if b.delay == 0 { + b.delay = minAcceptDelay + } else { + b.delay *= 2 + } + if b.delay > maxAcceptDelay { + b.delay = maxAcceptDelay + } +} diff --git a/proxy/internal/tcp/accept_test.go b/proxy/internal/tcp/accept_test.go new file mode 100644 index 000000000..b2824d38a --- /dev/null +++ b/proxy/internal/tcp/accept_test.go @@ -0,0 +1,142 @@ +package tcp + +import ( + "context" + "errors" + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestIsClosedListenerErr_NetErrClosed verifies the stdlib path: a +// closed *net.Listener returns net.ErrClosed wrapped in *net.OpError, +// and IsClosedListenerErr must unwrap it. +func TestIsClosedListenerErr_NetErrClosed(t *testing.T) { + wrapped := &net.OpError{Op: "accept", Net: "tcp", Err: net.ErrClosed} + assert.True(t, IsClosedListenerErr(wrapped), + "net.OpError wrapping net.ErrClosed must be recognised as closed") +} + +// TestIsClosedListenerErr_GVisorInvalidEndpoint is the load-bearing +// regression guard. A gVisor netstack listener whose endpoint has been +// destroyed returns this exact text. Without recognising it the accept +// loop spins forever and burns a CPU core. +func TestIsClosedListenerErr_GVisorInvalidEndpoint(t *testing.T) { + err := fmt.Errorf("accept tcp 10.10.1.254:80: endpoint is in invalid state") + assert.True(t, IsClosedListenerErr(err), + "gVisor 'endpoint is in invalid state' must be recognised as closed") +} + +// TestIsClosedListenerErr_OtherError confirms we don't over-match — +// transient errors must keep returning false so the backoff path runs. +func TestIsClosedListenerErr_OtherError(t *testing.T) { + cases := []error{ + errors.New("temporary failure"), + errors.New("accept tcp 10.10.1.254:80: too many open files"), + nil, + } + for _, c := range cases { + assert.False(t, IsClosedListenerErr(c), + "unexpected match on %v — must not be treated as closed", c) + } +} + +// TestAcceptBackoff_ProgressionAndCap asserts the doubling schedule: +// 5ms, 10ms, 20ms, 40ms, ... capped at 1s. The test runs against a +// real timer but uses tight bounds so a slow CI machine still passes. +func TestAcceptBackoff_ProgressionAndCap(t *testing.T) { + var b AcceptBackoff + expected := []time.Duration{ + 5 * time.Millisecond, + 10 * time.Millisecond, + 20 * time.Millisecond, + 40 * time.Millisecond, + } + for i, want := range expected { + start := time.Now() + ok := b.Backoff(context.Background()) + elapsed := time.Since(start) + require.True(t, ok, "Backoff %d must complete; ctx is alive", i) + assert.GreaterOrEqual(t, elapsed, want, + "backoff %d (%v) must wait at least the configured delay", i, want) + assert.Less(t, elapsed, want*4, + "backoff %d (%v) must not overshoot by more than 4x — caps misbehaving", i, want) + } + + // Burn enough rounds to reach the cap, then assert subsequent + // rounds stay at exactly maxAcceptDelay (1s) — the timer should + // never exceed it. + for range 6 { + b.Backoff(context.Background()) + } + assert.Equal(t, maxAcceptDelay, b.delay, + "after enough doublings the delay must clamp to maxAcceptDelay") +} + +// TestAcceptBackoff_Reset confirms that a successful Accept resets the +// schedule — a busy-then-quiet listener mustn't stay on a 1s timer +// after recovery. +func TestAcceptBackoff_Reset(t *testing.T) { + var b AcceptBackoff + for range 5 { + b.Backoff(context.Background()) + } + require.NotEqual(t, time.Duration(0), b.delay, "precondition: delay must have accumulated") + + b.Reset() + assert.Equal(t, time.Duration(0), b.delay, "Reset must zero the delay") + + start := time.Now() + ok := b.Backoff(context.Background()) + elapsed := time.Since(start) + require.True(t, ok, "Backoff after Reset must complete") + assert.GreaterOrEqual(t, elapsed, minAcceptDelay, + "after Reset the next backoff must restart at minAcceptDelay") + assert.Less(t, elapsed, 50*time.Millisecond, + "after Reset the next backoff must NOT carry over the prior delay") +} + +// TestAcceptBackoff_CancelDuringWait proves the loop exits promptly +// when ctx fires mid-wait. Without this, a tear-down would still take +// up to 1 second per orphaned listener. +func TestAcceptBackoff_CancelDuringWait(t *testing.T) { + var b AcceptBackoff + // Drive the backoff up so the next call will wait ~1s — long + // enough that we can detect early cancellation. + for range 10 { + b.Backoff(context.Background()) + } + require.Equal(t, maxAcceptDelay, b.delay) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + + start := time.Now() + ok := b.Backoff(ctx) + elapsed := time.Since(start) + assert.False(t, ok, "Backoff must return false when ctx is cancelled mid-wait") + assert.Less(t, elapsed, 200*time.Millisecond, + "cancellation must short-circuit the timer; took %v", elapsed) +} + +// TestAcceptBackoff_CancelBeforeCall — when ctx is already done the +// loop exits without sleeping at all. +func TestAcceptBackoff_CancelBeforeCall(t *testing.T) { + var b AcceptBackoff + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + ok := b.Backoff(ctx) + elapsed := time.Since(start) + assert.False(t, ok, "Backoff must return false when ctx is already cancelled") + assert.Less(t, elapsed, 50*time.Millisecond, + "already-cancelled ctx must return immediately; took %v", elapsed) +} diff --git a/proxy/internal/tcp/router.go b/proxy/internal/tcp/router.go index 15c5022b0..307f2b4f3 100644 --- a/proxy/internal/tcp/router.go +++ b/proxy/internal/tcp/router.go @@ -297,18 +297,29 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error { } }() + var backoff AcceptBackoff for { conn, err := ln.Accept() if err != nil { - if ctx.Err() != nil || errors.Is(err, net.ErrClosed) { + if ctx.Err() != nil || IsClosedListenerErr(err) { + if ok := r.Drain(DefaultDrainTimeout); !ok { + r.logger.Warn("timed out waiting for connections to drain") + } + return nil + } + r.logger.Debugf("SNI router accept: %v; backing off", err) + if !backoff.Backoff(ctx) { + // Cancelled during backoff: still drain in-flight + // connections/relays before returning, matching the + // shutdown path above. if ok := r.Drain(DefaultDrainTimeout); !ok { r.logger.Warn("timed out waiting for connections to drain") } return nil } - r.logger.Debugf("SNI router accept: %v", err) continue } + backoff.Reset() r.logger.Debugf("SNI router accepted conn from %s on %s", conn.RemoteAddr(), conn.LocalAddr()) r.activeConns.Add(1) go func() { diff --git a/proxy/internal/tcp/router_test.go b/proxy/internal/tcp/router_test.go index ea1b418f5..8be617dff 100644 --- a/proxy/internal/tcp/router_test.go +++ b/proxy/internal/tcp/router_test.go @@ -1836,3 +1836,132 @@ func TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled(t *testing.T) { t.Fatal("TLS conn never reached the TLS channel") } } + +// scriptedAcceptListener is a net.Listener whose Accept() returns +// pre-scripted errors. Used by the accept-loop exit tests to simulate +// the failure mode that triggers the tight-loop bug: a netstack +// listener whose endpoint has been destroyed and now returns the gVisor +// "endpoint is in invalid state" error from every Accept call. +type scriptedAcceptListener struct { + errs chan error + closed chan struct{} +} + +func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener { + s := &scriptedAcceptListener{ + errs: make(chan error, len(errs)+1), + closed: make(chan struct{}), + } + for _, e := range errs { + s.errs <- e + } + return s +} + +func (s *scriptedAcceptListener) Accept() (net.Conn, error) { + select { + case <-s.closed: + return nil, net.ErrClosed + case err := <-s.errs: + return nil, err + } +} + +func (s *scriptedAcceptListener) Close() error { + select { + case <-s.closed: + default: + close(s.closed) + } + return nil +} + +func (s *scriptedAcceptListener) Addr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +// TestRouter_Serve_ExitsOnGVisorInvalidEndpoint is the regression guard +// for the tight-loop bug: when the underlying netstack endpoint is +// destroyed, Accept returns "endpoint is in invalid state" forever. The +// loop must recognise that signal and return, otherwise it pegs a CPU +// core and floods logs. +func TestRouter_Serve_ExitsOnGVisorInvalidEndpoint(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, nil, addr) + + gvisorErr := &net.OpError{ + Op: "accept", + Net: "tcp", + Addr: addr, + Err: errSentinel("endpoint is in invalid state"), + } + ln := newScriptedAcceptListener(gvisorErr) + defer ln.Close() + + done := make(chan error, 1) + go func() { + done <- router.Serve(context.Background(), ln) + }() + + select { + case err := <-done: + assert.NoError(t, err, "Serve must return cleanly on a recognised closed-listener error") + case <-time.After(2 * time.Second): + t.Fatal("Serve did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning") + } +} + +// TestRouter_Serve_BacksOffOnTransientError verifies the defence-in- +// depth path: when Accept returns an unknown transient error, the loop +// MUST not spin. It backs off, then exits cleanly once ctx is cancelled. +// "Bounded call count" stands in for "no CPU spin" — without backoff +// the goroutine would issue thousands of Accept calls in this window. +func TestRouter_Serve_BacksOffOnTransientError(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, nil, addr) + + const transientErrCount = 5 + errs := make([]error, transientErrCount) + for i := range errs { + errs[i] = errSentinel("transient: too many open files") + } + ln := newScriptedAcceptListener(errs...) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + start := time.Now() + go func() { + done <- router.Serve(ctx, ln) + }() + + // Cancel after enough time for the backoff to climb (5ms + 10ms + + // 20ms + 40ms = 75ms minimum), but short enough that a spinning + // loop would have made thousands of calls by now. + time.AfterFunc(150*time.Millisecond, cancel) + + select { + case err := <-done: + assert.NoError(t, err, "Serve must return cleanly on ctx cancellation") + case <-time.After(2 * time.Second): + t.Fatal("Serve did not exit on ctx cancellation — backoff or exit path broken") + } + + // Without backoff the loop would burn through all 5 scripted errors + // in microseconds and then block on the channel. With backoff the + // total wall time should be at least 5ms (the first backoff). + elapsed := time.Since(start) + assert.GreaterOrEqual(t, elapsed, minAcceptDelay, + "loop ran without backing off — would burn CPU in production") +} + +// errSentinel mirrors gVisor's tcpip error message exactly. We can't +// import the gVisor package without dragging in the whole netstack, so +// the test uses the canonical string the production error formatter +// emits — same shape IsClosedListenerErr matches in production. +type errSentinel string + +func (e errSentinel) Error() string { return string(e) } + diff --git a/proxy/middleware_register.go b/proxy/middleware_register.go new file mode 100644 index 000000000..736ee04c1 --- /dev/null +++ b/proxy/middleware_register.go @@ -0,0 +1,16 @@ +package proxy + +// Anonymous imports trigger init() in each built-in middleware +// sub-package so they self-register into mwbuiltin.DefaultRegistry() +// before initMiddlewareManager builds the resolver. Add a new line +// here when introducing another built-in middleware. +import ( + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/cost_meter" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_guardrail" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_identity_inject" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_limit_check" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_limit_record" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_request_parser" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_response_parser" + _ "github.com/netbirdio/netbird/proxy/internal/middleware/builtin/llm_router" +) diff --git a/proxy/middleware_translate.go b/proxy/middleware_translate.go new file mode 100644 index 000000000..c5d9fe016 --- /dev/null +++ b/proxy/middleware_translate.go @@ -0,0 +1,165 @@ +package proxy + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/proxy/internal/middleware/bodytap" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// translateMiddlewareCaptureConfig builds the per-target capture +// limits used by the middleware chain. Returns nil when the options +// are nil or no capture field is set. Negative caps are normalised to +// zero; oversized caps are clamped to middleware.MaxBodyCapBytes. +func translateMiddlewareCaptureConfig(targetID string, opts *proto.PathTargetOptions) *bodytap.Config { + if opts == nil { + return nil + } + reqCap := clampMiddlewareCaptureBytes(targetID, "request", opts.GetCaptureMaxRequestBytes()) + respCap := clampMiddlewareCaptureBytes(targetID, "response", opts.GetCaptureMaxResponseBytes()) + types := opts.GetCaptureContentTypes() + if reqCap == 0 && respCap == 0 && len(types) == 0 { + return nil + } + return &bodytap.Config{ + MaxRequestBytes: reqCap, + MaxResponseBytes: respCap, + ContentTypes: types, + } +} + +func clampMiddlewareCaptureBytes(targetID, direction string, v int64) int64 { + if v < 0 { + log.Debugf("target %s %s capture cap %d clamped to 0", targetID, direction, v) + return 0 + } + if v > middleware.MaxBodyCapBytes { + log.Debugf("target %s %s capture cap %d clamped to %d", targetID, direction, v, middleware.MaxBodyCapBytes) + return middleware.MaxBodyCapBytes + } + return v +} + +// translateMiddlewareConfigs converts the proto MiddlewareConfig list +// into validated middleware.Spec values. The list is truncated to +// middleware.MaxMiddlewaresPerChain when the caller exceeds the cap. +// Entries with empty IDs, unknown IDs (when registry is non-nil), or +// unspecified slots are skipped with a warn log. Timeouts are clamped +// to [MinTimeout, MaxTimeout] and zero substitutes for DefaultTimeout. +// Returns nil when the resulting slice is empty so callers can leave +// PathTarget.Middlewares unset. +func translateMiddlewareConfigs( + ctx context.Context, + targetID string, + in []*proto.MiddlewareConfig, + registry *middleware.Registry, +) []middleware.Spec { + _ = ctx + if len(in) == 0 { + return nil + } + if len(in) > middleware.MaxMiddlewaresPerChain { + log.Warnf("middleware list for target %q truncated: %d entries exceeds cap of %d", + targetID, len(in), middleware.MaxMiddlewaresPerChain) + in = in[:middleware.MaxMiddlewaresPerChain] + } + + out := make([]middleware.Spec, 0, len(in)) + for _, cfg := range in { + spec, ok := translateMiddlewareConfig(targetID, cfg, registry) + if !ok { + continue + } + out = append(out, spec) + } + if len(out) == 0 { + return nil + } + return out +} + +// translateMiddlewareConfig validates and converts a single +// MiddlewareConfig. The second return value is false when the entry +// must be dropped from the chain. +func translateMiddlewareConfig(targetID string, cfg *proto.MiddlewareConfig, registry *middleware.Registry) (middleware.Spec, bool) { + if cfg == nil { + return middleware.Spec{}, false + } + id := cfg.GetId() + if id == "" { + log.Warnf("middleware config for target %q dropped: empty middleware id", targetID) + return middleware.Spec{}, false + } + if registry != nil && !registry.IsKnown(id) { + log.Warnf("unknown middleware %q configured for target %s; dropping", id, targetID) + return middleware.Spec{}, false + } + slot, ok := protoToMiddlewareSlot(cfg.GetSlot()) + if !ok { + log.Warnf("middleware %q on target %q dropped: slot is unspecified", id, targetID) + return middleware.Spec{}, false + } + + var rawConfig []byte + if src := cfg.GetConfigJson(); len(src) > 0 { + rawConfig = append([]byte(nil), src...) + } + + return middleware.Spec{ + ID: id, + Slot: slot, + Enabled: cfg.GetEnabled(), + FailMode: protoToMiddlewareFailMode(cfg.GetFailMode()), + Timeout: clampMiddlewareTimeout(id, cfg.GetTimeout().AsDuration()), + RawConfig: rawConfig, + CanMutate: cfg.GetCanMutate(), + }, true +} + +// protoToMiddlewareSlot maps the proto slot enum onto the internal +// middleware.Slot. Returns ok=false for the UNSPECIFIED value so the +// translator can drop the entry. +func protoToMiddlewareSlot(s proto.MiddlewareSlot) (middleware.Slot, bool) { + switch s { + case proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST: + return middleware.SlotOnRequest, true + case proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE: + return middleware.SlotOnResponse, true + case proto.MiddlewareSlot_MIDDLEWARE_SLOT_TERMINAL: + return middleware.SlotTerminal, true + default: + return 0, false + } +} + +// protoToMiddlewareFailMode maps the proto FailMode enum onto the +// internal middleware.FailMode, defaulting to FailOpen for any value +// other than FAIL_CLOSED. +func protoToMiddlewareFailMode(m proto.MiddlewareConfig_FailMode) middleware.FailMode { + if m == proto.MiddlewareConfig_FAIL_CLOSED { + return middleware.FailClosed + } + return middleware.FailOpen +} + +// clampMiddlewareTimeout enforces the proxy-wide [MinTimeout, MaxTimeout] +// bounds and substitutes DefaultTimeout for zero inputs. A warn is logged +// only on an actual clamp, not when filling the default. +func clampMiddlewareTimeout(id string, d time.Duration) time.Duration { + if d <= 0 { + return middleware.DefaultTimeout + } + if d < middleware.MinTimeout { + log.Debugf("middleware %s timeout %s clamped to %s", id, d, middleware.MinTimeout) + return middleware.MinTimeout + } + if d > middleware.MaxTimeout { + log.Debugf("middleware %s timeout %s clamped to %s", id, d, middleware.MaxTimeout) + return middleware.MaxTimeout + } + return d +} diff --git a/proxy/middleware_translate_test.go b/proxy/middleware_translate_test.go new file mode 100644 index 000000000..1a956090c --- /dev/null +++ b/proxy/middleware_translate_test.go @@ -0,0 +1,246 @@ +package proxy + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/netbirdio/netbird/proxy/internal/middleware" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// stubFactory builds a stub Middleware so the registry's IsKnown check +// passes for the configured id. The translator never invokes the +// middleware, so the methods only need to satisfy the interface. +type stubFactory struct { + id string + slot middleware.Slot +} + +func (f stubFactory) ID() string { return f.id } +func (f stubFactory) New(_ []byte) (middleware.Middleware, error) { + return stubMiddleware(f), nil +} + +type stubMiddleware struct { + id string + slot middleware.Slot +} + +func (m stubMiddleware) ID() string { return m.id } +func (m stubMiddleware) Version() string { return "test" } +func (m stubMiddleware) Slot() middleware.Slot { return m.slot } +func (m stubMiddleware) AcceptedContentTypes() []string { return nil } +func (m stubMiddleware) MetadataKeys() []string { return nil } +func (m stubMiddleware) MutationsSupported() bool { return false } +func (m stubMiddleware) Close() error { return nil } +func (m stubMiddleware) Invoke(context.Context, *middleware.Input) (*middleware.Output, error) { + panic("stubMiddleware.Invoke must not be called in translator tests") +} + +// newTestRegistry returns a fresh registry pre-populated with the given +// middleware ids in the matching slot. +func newTestRegistry(t *testing.T, entries map[string]middleware.Slot) *middleware.Registry { + t.Helper() + r := middleware.NewRegistry() + for id, slot := range entries { + require.NoError(t, r.Register(stubFactory{id: id, slot: slot}), "stub registration must succeed") + } + return r +} + +func TestTranslateMiddlewareConfigs_EmptyInput(t *testing.T) { + assert.Nil(t, translateMiddlewareConfigs(context.Background(), "target-a", nil, nil), + "nil input should translate to nil") + assert.Nil(t, translateMiddlewareConfigs(context.Background(), "target-a", []*proto.MiddlewareConfig{}, nil), + "empty input should translate to nil") +} + +func TestTranslateMiddlewareConfigs_KnownIDs(t *testing.T) { + registry := newTestRegistry(t, map[string]middleware.Slot{ + "llm_request_parser": middleware.SlotOnRequest, + "llm_response_parser": middleware.SlotOnResponse, + }) + in := []*proto.MiddlewareConfig{ + { + Id: "llm_request_parser", + Enabled: true, + Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, + ConfigJson: []byte(`{"foo":"bar"}`), + FailMode: proto.MiddlewareConfig_FAIL_OPEN, + Timeout: durationpb.New(250 * time.Millisecond), + CanMutate: true, + }, + { + Id: "llm_response_parser", + Enabled: false, + Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE, + ConfigJson: nil, + FailMode: proto.MiddlewareConfig_FAIL_CLOSED, + Timeout: durationpb.New(50 * time.Millisecond), + }, + } + + out := translateMiddlewareConfigs(context.Background(), "target-a", in, registry) + require.Len(t, out, 2, "two known middlewares should produce two specs") + + assert.Equal(t, "llm_request_parser", out[0].ID, "first id should match") + assert.Equal(t, middleware.SlotOnRequest, out[0].Slot, "first slot should be on_request") + assert.True(t, out[0].Enabled, "first spec should be enabled") + assert.Equal(t, middleware.FailOpen, out[0].FailMode, "first spec should be fail-open") + assert.Equal(t, 250*time.Millisecond, out[0].Timeout, "first spec timeout should pass through") + assert.True(t, out[0].CanMutate, "first spec should permit mutations") + assert.Equal(t, []byte(`{"foo":"bar"}`), out[0].RawConfig, "first spec raw config should match") + + assert.Equal(t, "llm_response_parser", out[1].ID, "second id should match") + assert.Equal(t, middleware.SlotOnResponse, out[1].Slot, "second slot should be on_response") + assert.False(t, out[1].Enabled, "second spec should be disabled") + assert.Equal(t, middleware.FailClosed, out[1].FailMode, "second spec should be fail-closed") + assert.Equal(t, 50*time.Millisecond, out[1].Timeout, "second spec timeout should pass through") + assert.Nil(t, out[1].RawConfig, "second spec raw config should be nil") +} + +func TestTranslateMiddlewareConfigs_UnknownIDSkipped(t *testing.T) { + registry := newTestRegistry(t, map[string]middleware.Slot{ + "llm_request_parser": middleware.SlotOnRequest, + }) + in := []*proto.MiddlewareConfig{ + {Id: "llm_request_parser", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST}, + {Id: "not_registered", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST}, + } + out := translateMiddlewareConfigs(context.Background(), "target-unknown", in, registry) + require.Len(t, out, 1, "unknown id must be skipped") + assert.Equal(t, "llm_request_parser", out[0].ID, "remaining entry should be the known one") +} + +func TestTranslateMiddlewareConfigs_NilRegistrySkipsValidation(t *testing.T) { + in := []*proto.MiddlewareConfig{ + {Id: "anything_goes", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST}, + } + out := translateMiddlewareConfigs(context.Background(), "target-nilreg", in, nil) + require.Len(t, out, 1, "nil registry must accept any non-empty id") + assert.Equal(t, "anything_goes", out[0].ID, "id should pass through unchecked") +} + +func TestTranslateMiddlewareConfigs_TimeoutClamps(t *testing.T) { + registry := newTestRegistry(t, map[string]middleware.Slot{ + "llm_request_parser": middleware.SlotOnRequest, + }) + in := []*proto.MiddlewareConfig{ + {Id: "llm_request_parser", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, Timeout: nil}, + {Id: "llm_request_parser", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, Timeout: durationpb.New(time.Microsecond)}, + {Id: "llm_request_parser", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, Timeout: durationpb.New(time.Hour)}, + } + out := translateMiddlewareConfigs(context.Background(), "target-clamp", in, registry) + require.Len(t, out, 3, "clamping must keep all three entries") + assert.Equal(t, middleware.DefaultTimeout, out[0].Timeout, "zero timeout should default") + assert.Equal(t, middleware.MinTimeout, out[1].Timeout, "below-min timeout should clamp up") + assert.Equal(t, middleware.MaxTimeout, out[2].Timeout, "above-max timeout should clamp down") +} + +func TestTranslateMiddlewareConfigs_FailModeMapping(t *testing.T) { + registry := newTestRegistry(t, map[string]middleware.Slot{ + "llm_request_parser": middleware.SlotOnRequest, + }) + in := []*proto.MiddlewareConfig{ + {Id: "llm_request_parser", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST}, + {Id: "llm_request_parser", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, FailMode: proto.MiddlewareConfig_FAIL_CLOSED}, + } + out := translateMiddlewareConfigs(context.Background(), "target-failmode", in, registry) + require.Len(t, out, 2, "both entries should translate") + assert.Equal(t, middleware.FailOpen, out[0].FailMode, "default fail mode should be open") + assert.Equal(t, middleware.FailClosed, out[1].FailMode, "explicit fail closed should map") +} + +func TestTranslateMiddlewareConfigs_SlotMapping(t *testing.T) { + registry := newTestRegistry(t, map[string]middleware.Slot{ + "req": middleware.SlotOnRequest, + "resp": middleware.SlotOnResponse, + "term": middleware.SlotTerminal, + }) + in := []*proto.MiddlewareConfig{ + {Id: "req", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST}, + {Id: "resp", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE}, + {Id: "term", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_TERMINAL}, + {Id: "req", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_UNSPECIFIED}, + } + out := translateMiddlewareConfigs(context.Background(), "target-slot", in, registry) + require.Len(t, out, 3, "unspecified slot entry must be skipped") + assert.Equal(t, middleware.SlotOnRequest, out[0].Slot, "on_request slot mapping") + assert.Equal(t, middleware.SlotOnResponse, out[1].Slot, "on_response slot mapping") + assert.Equal(t, middleware.SlotTerminal, out[2].Slot, "terminal slot mapping") +} + +func TestTranslateMiddlewareConfigs_EmptyIDSkipped(t *testing.T) { + registry := newTestRegistry(t, map[string]middleware.Slot{ + "llm_request_parser": middleware.SlotOnRequest, + }) + in := []*proto.MiddlewareConfig{ + {Id: "", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST}, + {Id: "llm_request_parser", Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST}, + } + out := translateMiddlewareConfigs(context.Background(), "target-empty-id", in, registry) + require.Len(t, out, 1, "empty id must be dropped") + assert.Equal(t, "llm_request_parser", out[0].ID, "remaining entry should be valid") +} + +// TestTranslateMiddlewareConfigs_TruncatesAboveCap proves the translator +// truncates lists that exceed MaxMiddlewaresPerChain rather than dropping +// the whole slice, matching the documented G3 behaviour. +func TestTranslateMiddlewareConfigs_TruncatesAboveCap(t *testing.T) { + registry := newTestRegistry(t, map[string]middleware.Slot{ + "llm_request_parser": middleware.SlotOnRequest, + }) + overCap := middleware.MaxMiddlewaresPerChain + 1 + in := make([]*proto.MiddlewareConfig, 0, overCap) + for i := 0; i < overCap; i++ { + in = append(in, &proto.MiddlewareConfig{ + Id: "llm_request_parser", + Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, + }) + } + out := translateMiddlewareConfigs(context.Background(), "target-truncate", in, registry) + assert.Len(t, out, middleware.MaxMiddlewaresPerChain, "over-cap input must be truncated to MaxMiddlewaresPerChain") +} + +func TestTranslateMiddlewareConfigs_AllowsListAtCap(t *testing.T) { + registry := newTestRegistry(t, map[string]middleware.Slot{ + "llm_request_parser": middleware.SlotOnRequest, + }) + in := make([]*proto.MiddlewareConfig, 0, middleware.MaxMiddlewaresPerChain) + for i := 0; i < middleware.MaxMiddlewaresPerChain; i++ { + in = append(in, &proto.MiddlewareConfig{ + Id: "llm_request_parser", + Slot: proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, + }) + } + out := translateMiddlewareConfigs(context.Background(), "target-cap", in, registry) + assert.Len(t, out, middleware.MaxMiddlewaresPerChain, "list at the cap boundary must translate fully") +} + +func TestProtoToMiddlewareSlot(t *testing.T) { + cases := []struct { + name string + in proto.MiddlewareSlot + want middleware.Slot + wantOk bool + }{ + {"unspecified", proto.MiddlewareSlot_MIDDLEWARE_SLOT_UNSPECIFIED, 0, false}, + {"on_request", proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST, middleware.SlotOnRequest, true}, + {"on_response", proto.MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE, middleware.SlotOnResponse, true}, + {"terminal", proto.MiddlewareSlot_MIDDLEWARE_SLOT_TERMINAL, middleware.SlotTerminal, true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, ok := protoToMiddlewareSlot(tc.in) + assert.Equal(t, tc.wantOk, ok, "ok flag for %s", tc.name) + if tc.wantOk { + assert.Equal(t, tc.want, got, "slot mapping for %s", tc.name) + } + }) + } +} diff --git a/proxy/server.go b/proxy/server.go index 1d8a2451b..f28d580bd 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -55,6 +55,8 @@ import ( "github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/k8s" proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics" + "github.com/netbirdio/netbird/proxy/internal/middleware" + mwbuiltin "github.com/netbirdio/netbird/proxy/internal/middleware/builtin" "github.com/netbirdio/netbird/proxy/internal/netutil" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/restrict" @@ -77,29 +79,36 @@ type portRouter struct { type Server struct { ctx context.Context - mgmtClient proto.ProxyServiceClient - proxy *proxy.ReverseProxy - netbird *roundtrip.NetBird - acme *acme.Manager + mgmtClient proto.ProxyServiceClient + proxy *proxy.ReverseProxy + netbird *roundtrip.NetBird + acme *acme.Manager staticCertWatcher *certwatch.Watcher - auth *auth.Middleware - http *http.Server - https *http.Server - debug *http.Server - healthServer *health.Server - healthChecker *health.Checker - meter *proxymetrics.Metrics - accessLog *accesslog.Logger - mainRouter *nbtcp.Router - mainPort uint16 - udpMu sync.Mutex - udpRelays map[types.ServiceID]*udprelay.Relay - udpRelayWg sync.WaitGroup - portMu sync.RWMutex - portRouters map[uint16]*portRouter - svcPorts map[types.ServiceID][]uint16 - lastMappings map[types.ServiceID]*proto.ProxyMapping - portRouterWg sync.WaitGroup + auth *auth.Middleware + http *http.Server + https *http.Server + debug *http.Server + healthServer *health.Server + healthChecker *health.Checker + meter *proxymetrics.Metrics + accessLog *accesslog.Logger + // middlewareManager drives per-target middleware dispatch. Always + // constructed during boot; an empty registry produces empty chains and + // the reverse-proxy stays on the no-capture fast path. + middlewareManager *middleware.Manager + // middlewareRegistry is the source of registered middleware factories. + // Concrete middlewares register themselves through init(). + middlewareRegistry *middleware.Registry + mainRouter *nbtcp.Router + mainPort uint16 + udpMu sync.Mutex + udpRelays map[types.ServiceID]*udprelay.Relay + udpRelayWg sync.WaitGroup + portMu sync.RWMutex + portRouters map[uint16]*portRouter + svcPorts map[types.ServiceID][]uint16 + lastMappings map[types.ServiceID]*proto.ProxyMapping + portRouterWg sync.WaitGroup // hijackTracker tracks hijacked connections (e.g. WebSocket upgrades) // so they can be closed during graceful shutdown, since http.Server.Shutdown @@ -236,8 +245,20 @@ type Server struct { // in processMappings before the receive loop reconnects to resync. // Zero uses defaultMappingBatchWatchdog. MappingBatchWatchdog time.Duration + // MiddlewareDataDir is the base directory the middleware system uses to + // resolve file-backed configuration (e.g. the cost_meter pricing table). + // Empty means any middleware that requires a file fails at configure time. + MiddlewareDataDir string + // MiddlewareCaptureBudgetBytes overrides the proxy-wide in-flight capture + // budget passed to middleware.NewManager. Zero or negative values fall + // back to defaultMiddlewareCaptureBudgetBytes (256 MiB). + MiddlewareCaptureBudgetBytes int64 } +// defaultMiddlewareCaptureBudgetBytes is the proxy-wide in-flight capture cap +// passed to middleware.NewManager when MiddlewareCaptureBudgetBytes is unset. +const defaultMiddlewareCaptureBudgetBytes = 256 << 20 + // clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured. func (s *Server) clampIdleTimeout(d time.Duration) time.Duration { if s.MaxSessionIdleTimeout > 0 && d > s.MaxSessionIdleTimeout { @@ -343,6 +364,15 @@ func (s *Server) Start(ctx context.Context) error { return err } + // Management client must be initialised BEFORE the middleware manager — + // initMiddlewareManager passes s.mgmtClient into the builtin FactoryContext + // that the limit-check / limit-record middlewares pull from. Reversed + // order would silently disable enforcement (mgmt=nil → allow-without- + // attribution + no-record). + if err := s.initMiddlewareManager(ctx); err != nil { + return fmt.Errorf("init middleware manager: %w", err) + } + runCtx, runCancel := context.WithCancel(ctx) s.runCancel = runCancel @@ -562,7 +592,11 @@ func (s *Server) initNetBirdClient() { // proxy host's resolver instead of the tunnel's DNS. func (s *Server) initReverseProxy() { upstreamRT := roundtrip.NewMultiTransport(s.netbird, s.Logger) - s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(upstreamRT), s.ForwardedProto, s.TrustedProxies, s.Logger) + var rpOpts []proxy.Option + if s.middlewareManager != nil { + rpOpts = append(rpOpts, proxy.WithMiddlewareManager(s.middlewareManager)) + } + s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(upstreamRT), s.ForwardedProto, s.TrustedProxies, s.Logger, rpOpts...) } // initGeoLookup configures the GeoLite2 lookup used for country-based @@ -2047,9 +2081,94 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) m := s.protoToMapping(ctx, mapping) s.proxy.AddMapping(m) s.meter.AddMapping(m) + s.rebuildMiddlewareChains(svcID, m) return nil } +// initMiddlewareManager wires the middleware subsystem at boot. It configures +// the per-process FactoryContext concrete middlewares consult, installs the +// live-service check, and binds the resolver to the registry concrete +// middlewares register themselves into via init(). +func (s *Server) initMiddlewareManager(ctx context.Context) error { + if s.meter == nil { + return fmt.Errorf("middleware manager requires metrics bundle") + } + otelMeter := s.meter.Meter() + mwbuiltin.Configure(ctx, s.MiddlewareDataDir, otelMeter, s.Logger, s.mgmtClient) + + mwMetrics, err := middleware.NewMetrics(otelMeter) + if err != nil { + return fmt.Errorf("init middleware metrics: %w", err) + } + budgetBytes := s.MiddlewareCaptureBudgetBytes + if budgetBytes <= 0 { + budgetBytes = defaultMiddlewareCaptureBudgetBytes + } + + registry := mwbuiltin.DefaultRegistry() + mgr := middleware.NewManager(budgetBytes, mwMetrics, s.Logger) + mgr.SetResolver(middleware.NewResolver(registry)) + mgr.SetLiveServiceCheck(s.isLiveService) + + s.middlewareRegistry = registry + s.middlewareManager = mgr + ids := registry.IDs() + s.Logger.Infof("middleware system enabled: %d built-in middlewares registered %v, capture budget %d bytes", + len(ids), ids, budgetBytes) + return nil +} + +// rebuildMiddlewareChains converts m into per-path bindings and calls +// Manager.Rebuild. Short-circuits when the middleware manager is unset. +func (s *Server) rebuildMiddlewareChains(svcID types.ServiceID, m proxy.Mapping) { + if s.middlewareManager == nil { + return + } + bindings := buildMiddlewareBindings(svcID, m) + if err := s.middlewareManager.Rebuild(string(svcID), bindings); err != nil { + s.Logger.WithError(err).WithField("service_id", svcID).Error("failed to rebuild middleware chains") + } +} + +// isLiveService reports whether svcID is currently present in the live +// mapping cache. Used by the middleware manager to confirm a chain is still +// referenced before rebuilding it from cached bindings. +func (s *Server) isLiveService(svcID string) bool { + s.portMu.RLock() + defer s.portMu.RUnlock() + _, ok := s.lastMappings[types.ServiceID(svcID)] + return ok +} + +// invalidateMiddlewareChains drops every middleware chain registered for svcID. +func (s *Server) invalidateMiddlewareChains(svcID types.ServiceID) { + if s.middlewareManager == nil { + return + } + s.middlewareManager.Invalidate(string(svcID)) +} + +// buildMiddlewareBindings converts the path targets of m into the per-path +// binding list the middleware manager's Rebuild expects. Targets without any +// middleware specs are skipped. +func buildMiddlewareBindings(svcID types.ServiceID, m proxy.Mapping) []middleware.PathTargetBinding { + if len(m.Paths) == 0 { + return nil + } + bindings := make([]middleware.PathTargetBinding, 0, len(m.Paths)) + for pathID, pt := range m.Paths { + if pt == nil || len(pt.Middlewares) == 0 { + continue + } + bindings = append(bindings, middleware.PathTargetBinding{ + ServiceID: string(svcID), + PathID: pathID, + Specs: pt.Middlewares, + }) + } + return bindings +} + // removeMapping tears down routes/relays and the NetBird peer for a service. // Uses the stored mapping state when available to ensure all previously // configured routes are cleaned up. @@ -2085,6 +2204,8 @@ func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) { svcID := types.ServiceID(mapping.GetId()) host := mapping.GetDomain() + s.invalidateMiddlewareChains(svcID) + // HTTP/TLS cleanup (only relevant when a domain is set). if host != "" { d := domain.Domain(host) @@ -2192,6 +2313,12 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping pt.RequestTimeout = d.AsDuration() } pt.DirectUpstream = opts.GetDirectUpstream() + // Agent-network middleware specs + capture config + flag ride on + // the same per-target options. + pt.CaptureConfig = translateMiddlewareCaptureConfig(mapping.GetId(), opts) + pt.Middlewares = translateMiddlewareConfigs(ctx, mapping.GetId(), opts.GetMiddlewares(), s.middlewareRegistry) + pt.AgentNetwork = opts.GetAgentNetwork() + pt.DisableAccessLog = opts.GetDisableAccessLog() } pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout) paths[pathMapping.GetPath()] = pt diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 196a0c6b1..dffb7d7de 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -5069,6 +5069,1054 @@ components: type: string description: A human-readable error message. example: "couldn't parse JSON request" + AgentNetworkProvider: + type: object + properties: + id: + type: string + description: Provider ID + example: "ainp_d1m3kebd9pcs0c1pnu7g" + provider_id: + type: string + description: Catalog identifier for the upstream AI provider (e.g. openai_api, anthropic_api, azure_openai_api, bedrock_api, vertex_ai_api, mistral_api, custom). + example: "openai_api" + name: + type: string + description: Display name shown in the dashboard. + example: "OpenAI API" + upstream_url: + type: string + description: Full upstream URL (with scheme) that NetBird forwards traffic to. + example: "https://api.openai.com" + models: + type: array + description: Models exposed through this endpoint, with the operator's per-1k input/output prices. Empty means all catalog models are allowed at catalog prices. + items: + $ref: '#/components/schemas/AgentNetworkProviderModel' + extra_values: + type: object + description: | + Operator-typed values for catalog-declared extra headers. Keys are wire header names (e.g. `x-portkey-config`); values are the strings the proxy stamps on every upstream request to this provider. Catalog (AgentNetworkCatalogProvider.extra_headers) declares which keys are accepted; values not declared by the catalog are ignored at synth time. Empty / missing values mean no header stamped. + additionalProperties: + type: string + example: + x-portkey-config: "pc-prod-3f2a" + identity_header_user_id: + type: string + description: | + Wire header name the proxy stamps with the caller's display identity (user email or peer name) when the catalog entry's HeaderPair is `customizable`. Empty disables stamping for this dimension. Ignored when the catalog entry has a fixed HeaderPair (e.g. LiteLLM, Portkey). Used today by Bifrost: typical values are `x-bf-lh-netbird_user_id` (always-on log metadata) or `x-bf-dim-netbird_user_id` (Prometheus / OTEL — requires the label to be pre-declared in the gateway's `client.prometheus_labels` config). + example: "x-bf-dim-netbird_user_id" + identity_header_groups: + type: string + description: | + Wire header name the proxy stamps with the caller's NetBird groups as a comma-separated list (sorted) when the catalog entry's HeaderPair is `customizable`. Empty disables stamping for this dimension. Same per-catalog semantics as `identity_header_user_id`. + example: "x-bf-dim-netbird_groups" + enabled: + type: boolean + description: Whether the provider is enabled. + example: true + created_at: + type: string + format: date-time + description: Timestamp when the provider was created. + readOnly: true + example: "2026-04-26T10:30:00Z" + updated_at: + type: string + format: date-time + description: Timestamp when the provider was last updated. + readOnly: true + example: "2026-04-26T10:30:00Z" + required: + - id + - provider_id + - name + - upstream_url + - models + - enabled + - created_at + - updated_at + AgentNetworkProviderRequest: + type: object + properties: + provider_id: + type: string + description: Catalog identifier for the upstream AI provider (e.g. openai_api, anthropic_api, azure_openai_api, bedrock_api, vertex_ai_api, mistral_api, custom). + example: "openai_api" + name: + type: string + description: Display name for the provider. + example: "OpenAI API" + upstream_url: + type: string + description: Full upstream URL (with scheme) that NetBird forwards traffic to. + example: "https://api.openai.com" + bootstrap_cluster: + type: string + description: Proxy cluster used to bootstrap the per-account agent-network endpoint when the first provider is created. Ignored on subsequent creates and on updates because the cluster is pinned on the account-level Settings row. + example: "eu.proxy.netbird.io" + api_key: + type: string + description: Upstream provider API key. Sealed at rest on the management server and never returned in responses. Required on create; optional on update (omit to keep the existing key). + example: "sk-..." + models: + type: array + description: Models exposed through this endpoint, with the operator's per-1k input/output prices. Empty means all catalog models are allowed at catalog prices. + items: + $ref: '#/components/schemas/AgentNetworkProviderModel' + extra_values: + type: object + description: | + Operator-typed values for catalog-declared extra headers (see AgentNetworkProvider.extra_values). When present on a request, the whole map replaces the stored values. Empty strings drop the corresponding key. + additionalProperties: + type: string + example: + x-portkey-config: "pc-prod-3f2a" + identity_header_user_id: + type: string + description: | + Wire header name for the caller's display identity. See AgentNetworkProvider.identity_header_user_id. When omitted on a request, the stored value is left unchanged; pass an empty string explicitly to clear it (which disables stamping for this dimension). + example: "x-bf-dim-netbird_user_id" + identity_header_groups: + type: string + description: | + Wire header name for the caller's groups CSV. See AgentNetworkProvider.identity_header_groups. Same omit / empty semantics as `identity_header_user_id`. + example: "x-bf-dim-netbird_groups" + enabled: + type: boolean + description: Whether the provider is enabled. Defaults to true on create. + example: true + required: + - provider_id + - name + - upstream_url + AgentNetworkProviderModel: + type: object + description: A model exposed by the provider, with the operator's per-1k input/output prices in USD. + properties: + id: + type: string + description: Model identifier (e.g. "gpt-4o-mini"). + example: "gpt-4o-mini" + input_per_1k: + type: number + format: double + description: Cost per 1k input tokens, in USD. + example: 0.00015 + output_per_1k: + type: number + format: double + description: Cost per 1k output tokens, in USD. + example: 0.0006 + required: + - id + - input_per_1k + - output_per_1k + AgentNetworkCatalogModel: + type: object + properties: + id: + type: string + description: Catalog model identifier as exposed by the upstream provider. + example: "gpt-4o" + label: + type: string + description: Human-friendly model name for the dashboard. + example: "GPT-4o" + input_per_1k: + type: number + format: double + description: Input token price per 1k tokens, in USD. + example: 0.005 + output_per_1k: + type: number + format: double + description: Output token price per 1k tokens, in USD. + example: 0.015 + context_window: + type: integer + description: Maximum context window in tokens. + example: 128000 + required: + - id + - label + - input_per_1k + - output_per_1k + - context_window + AgentNetworkCatalogProvider: + type: object + properties: + id: + type: string + description: Catalog provider identifier (referenced by AgentNetworkProvider.provider_id). + example: "openai_api" + name: + type: string + description: Display name for the provider. + example: "OpenAI API" + description: + type: string + description: Short description shown in the provider picker. + example: "GPT, Responses API, and Embeddings" + default_host: + type: string + description: Default upstream host suggested when adding a provider of this type. + example: "api.openai.com" + auth_header_template: + type: string + description: Template the proxy uses to inject the API key (the literal string ${API_KEY} is replaced at request time). + example: "Bearer ${API_KEY}" + default_content_type: + type: string + description: Default Content-Type for upstream requests. + example: "application/json" + brand_color: + type: string + description: Hex brand color used to render the provider badge in the dashboard. + example: "#10A37F" + kind: + type: string + description: | + Presentation grouping for the provider Select on the dashboard. + "provider" — first-party vendor API (OpenAI, Anthropic, …); the upstream is the model itself. + "gateway" — routing/aggregation layer in front of multiple providers (LiteLLM, Portkey, …); typically pairs with NetBird identity stamping. + "custom" — generic OpenAI-compatible self-hosted endpoint catch-all. + enum: [provider, gateway, custom] + example: "provider" + extra_headers: + type: array + description: | + Catalog-declared list of optional per-provider routing/config headers the proxy stamps on every upstream request. Each entry surfaces an input on the dashboard's provider modal (one per item, labeled with `label`). Operators fill any subset; values land on the provider record's `extra_values` map keyed by `name`. Used by gateways like Portkey for `x-portkey-config: pc-...` (saved-config id resolving upstream provider + virtual key). + items: + $ref: '#/components/schemas/AgentNetworkCatalogExtraHeader' + identity_injection: + $ref: '#/components/schemas/AgentNetworkCatalogIdentityInjection' + models: + type: array + description: Catalog models available for this provider. + items: + $ref: '#/components/schemas/AgentNetworkCatalogModel' + required: + - id + - name + - description + - default_host + - auth_header_template + - default_content_type + - brand_color + - kind + - models + AgentNetworkCatalogIdentityInjection: + type: object + description: | + Catalog-declared identity-injection shape. Present when this provider supports stamping the caller's NetBird identity onto upstream requests. Exactly one of `header_pair` or `json_metadata` is set per provider entry. The dashboard reads the `customizable` flag on whichever shape is present to decide whether to surface the labels as editable inputs (true → editable with the catalog values shown as placeholders; false → fixed and read-only). + properties: + header_pair: + $ref: '#/components/schemas/AgentNetworkCatalogHeaderPairInjection' + json_metadata: + $ref: '#/components/schemas/AgentNetworkCatalogJSONMetadataInjection' + AgentNetworkCatalogHeaderPairInjection: + type: object + description: HeaderPair identity-injection shape — separate per-dimension headers (LiteLLM-style, Bifrost). + properties: + customizable: + type: boolean + description: When true, the wire header names are operator-overridable per provider record (Bifrost). When false, the catalog values are authoritative (LiteLLM and similar gateways with a fixed wire protocol). + example: true + end_user_id_header: + type: string + description: Wire header name for the caller's display identity. Default placeholder when `customizable` is true. + example: "x-bf-dim-netbird_user_id" + tags_header: + type: string + description: Wire header name for the caller's groups CSV. Default placeholder when `customizable` is true. + example: "x-bf-dim-netbird_groups" + required: + - customizable + - end_user_id_header + - tags_header + AgentNetworkCatalogJSONMetadataInjection: + type: object + description: JSONMetadata identity-injection shape — one wire header carrying a JSON object whose keys label each dimension (Portkey-style, Cloudflare AI Gateway). + properties: + customizable: + type: boolean + description: When true, the JSON keys are operator-overridable per provider record (Cloudflare). The wire header itself stays catalog-owned. When false, the catalog values are authoritative (Portkey and similar gateways with a fixed JSON schema). + example: true + header: + type: string + description: Wire header name carrying the JSON metadata payload. Catalog-owned (not customizable per provider record). + example: "cf-aig-metadata" + user_key: + type: string + description: JSON key for the caller's display identity. Default placeholder when `customizable` is true. + example: "netbird_user_id" + groups_key: + type: string + description: JSON key for the caller's groups CSV. Default placeholder when `customizable` is true. + example: "netbird_groups" + required: + - customizable + - header + - user_key + - groups_key + AgentNetworkCatalogExtraHeader: + type: object + description: One optional per-provider routing/config header surfaced on the dashboard. Operator-typed value lives on the provider record's `extra_values` map keyed by `name`. UI copy (input label, helper line, tooltip) is owned by the dashboard, keyed by `name`. + properties: + name: + type: string + description: Wire header name the proxy stamps with the operator-typed value. + example: "x-portkey-config" + required: + - name + AgentNetworkPolicy: + type: object + properties: + id: + type: string + description: Policy ID + example: "ainpol_d1m3kebd9pcs0c1pnu7g" + name: + type: string + description: Display name for the policy. + example: "Engineering → OpenAI" + description: + type: string + description: Optional human-readable description. + example: "Engineers can call OpenAI under production guardrails." + enabled: + type: boolean + description: Whether the policy is enabled. + example: true + source_groups: + type: array + description: NetBird group ids whose members are allowed to call the destination providers. + items: + type: string + example: ["ch8vp3o6lnna9hg0sd8g"] + destination_provider_ids: + type: array + description: Agent Network provider ids (returned by the providers API) the source groups can reach. + items: + type: string + example: ["ainp_d1m3kebd9pcs0c1pnu7g"] + guardrail_ids: + type: array + description: Agent Network guardrail ids attached to this policy. + items: + type: string + example: [] + limits: + $ref: '#/components/schemas/AgentNetworkPolicyLimits' + created_at: + type: string + format: date-time + description: Timestamp when the policy was created. + readOnly: true + example: "2026-04-26T10:30:00Z" + updated_at: + type: string + format: date-time + description: Timestamp when the policy was last updated. + readOnly: true + example: "2026-04-26T10:30:00Z" + required: + - id + - name + - description + - enabled + - source_groups + - destination_provider_ids + - guardrail_ids + - limits + - created_at + - updated_at + AgentNetworkPolicyRequest: + type: object + properties: + name: + type: string + description: Display name for the policy. + example: "Engineering → OpenAI" + description: + type: string + description: Optional human-readable description. + example: "Engineers can call OpenAI under production guardrails." + enabled: + type: boolean + description: Whether the policy is enabled. Defaults to true on create. + example: true + source_groups: + type: array + description: NetBird group ids whose members are allowed to call the destination providers. + items: + type: string + minItems: 1 + example: ["ch8vp3o6lnna9hg0sd8g"] + destination_provider_ids: + type: array + description: Agent Network provider ids the source groups can reach. + items: + type: string + minItems: 1 + example: ["ainp_d1m3kebd9pcs0c1pnu7g"] + guardrail_ids: + type: array + description: Agent Network guardrail ids to attach to this policy. + items: + type: string + example: [] + limits: + $ref: '#/components/schemas/AgentNetworkPolicyLimits' + required: + - name + - source_groups + - destination_provider_ids + AgentNetworkPolicyTokenLimit: + type: object + description: Per-policy token cap. `group_cap` is applied to each source group independently — every group in the policy's `source_groups` gets its own bucket of this size. `user_cap` is applied independently to each individual user. Caps reset to zero at the start of each window. + properties: + enabled: + type: boolean + example: true + group_cap: + type: integer + format: int64 + minimum: 0 + description: Tokens allowed per source group within the window (each group has its own bucket of this size). 0 means uncapped. + example: 10000000 + user_cap: + type: integer + format: int64 + minimum: 0 + description: Tokens allowed per individual user within the window. 0 means uncapped. + example: 1000000 + window_seconds: + type: integer + format: int64 + minimum: 60 + description: Reset frequency in seconds. The cap counter resets to zero at the start of each window. Minimum 60 (one minute) when the limit is enabled. + example: 2592000 + required: + - enabled + - group_cap + - user_cap + - window_seconds + AgentNetworkPolicyBudgetLimit: + type: object + description: Per-policy USD spend cap. `group_cap_usd` is applied to each source group independently — every group in the policy's `source_groups` gets its own bucket of this size. `user_cap_usd` is applied independently to each individual user. Caps reset to zero at the start of each window. + properties: + enabled: + type: boolean + example: true + group_cap_usd: + type: number + format: double + minimum: 0 + description: USD allowed per source group within the window (each group has its own bucket of this size). 0 means uncapped. + example: 1000 + user_cap_usd: + type: number + format: double + minimum: 0 + description: USD allowed per individual user within the window. 0 means uncapped. + example: 100 + window_seconds: + type: integer + format: int64 + minimum: 60 + description: Reset frequency in seconds. Caps reset at the start of each window. Minimum 60 (one minute) when the limit is enabled. + example: 2592000 + required: + - enabled + - group_cap_usd + - user_cap_usd + - window_seconds + AgentNetworkPolicyLimits: + type: object + description: Token and budget caps attached directly to the policy. These compose with any guardrail-level checks. + properties: + token_limit: + $ref: '#/components/schemas/AgentNetworkPolicyTokenLimit' + budget_limit: + $ref: '#/components/schemas/AgentNetworkPolicyBudgetLimit' + required: + - token_limit + - budget_limit + AgentNetworkGuardrailChecks: + type: object + description: Guardrail check parameters. Each entry has an `enabled` flag plus per-check configuration; disabled entries are inert. + properties: + model_allowlist: + type: object + properties: + enabled: + type: boolean + example: true + models: + type: array + description: Allowed catalog model ids. Requests for any other model are denied. + items: + type: string + example: ["gpt-4o-mini", "claude-haiku-4-5"] + required: + - enabled + - models + prompt_capture: + type: object + properties: + enabled: + type: boolean + example: true + redact_pii: + type: boolean + example: true + required: + - enabled + - redact_pii + required: + - model_allowlist + - prompt_capture + AgentNetworkGuardrail: + type: object + properties: + id: + type: string + description: Guardrail ID + example: "ainguard_d1m3kebd9pcs0c1pnu7g" + name: + type: string + description: Display name for the guardrail. + example: "Strict — Production" + description: + type: string + description: Optional human-readable description. + example: "Tight model allowlist, PII redaction, hard monthly budget." + checks: + $ref: '#/components/schemas/AgentNetworkGuardrailChecks' + created_at: + type: string + format: date-time + description: Timestamp when the guardrail was created. + readOnly: true + example: "2026-04-26T10:30:00Z" + updated_at: + type: string + format: date-time + description: Timestamp when the guardrail was last updated. + readOnly: true + example: "2026-04-26T10:30:00Z" + required: + - id + - name + - description + - checks + - created_at + - updated_at + AgentNetworkGuardrailRequest: + type: object + properties: + name: + type: string + description: Display name for the guardrail. + example: "Strict — Production" + description: + type: string + description: Optional human-readable description. + example: "Tight model allowlist, PII redaction, hard monthly budget." + checks: + $ref: '#/components/schemas/AgentNetworkGuardrailChecks' + required: + - name + - checks + AgentNetworkConsumption: + type: object + description: One per-(dimension, window) consumption counter row. The proxy ticks one row per dimension on every served LLM request; the dashboard reads this listing to surface live counter growth. + properties: + dimension_kind: + type: string + enum: [user, group] + description: Whether this row counts a single end user or a single source group across every member. + dimension_id: + type: string + description: NetBird user id (when `dimension_kind=user`) or NetBird group id (when `dimension_kind=group`). + example: "grp-engineers" + window_seconds: + type: integer + format: int64 + description: Length of the aligned window this counter covers, in seconds. Distinct window lengths produce independent counters even on the same dimension. + example: 86400 + window_start_utc: + type: string + format: date-time + description: UTC start of the aligned window this counter covers. Aligned to the unix epoch so every node computes the same boundary. + example: "2026-05-05T12:00:00Z" + tokens_input: + type: integer + format: int64 + description: Total input tokens consumed within the window. + example: 12000 + tokens_output: + type: integer + format: int64 + description: Total output tokens consumed within the window. + example: 6500 + cost_usd: + type: number + format: double + description: Total USD spend booked against this dimension for the window. + example: 0.4231 + updated_at: + type: string + format: date-time + description: Timestamp of the last increment recorded for this row. + readOnly: true + example: "2026-05-05T12:34:56Z" + required: + - dimension_kind + - dimension_id + - window_seconds + - window_start_utc + - tokens_input + - tokens_output + - cost_usd + AgentNetworkAccessLog: + type: object + description: One per-request agent-network (LLM) access log entry with flattened, queryable LLM dimensions. + properties: + id: + type: string + description: Unique identifier for the access log entry. + example: "ch8i4ug6lnn4g9hqv7m0" + service_id: + type: string + description: ID of the synthesised agent-network service that handled the request. + timestamp: + type: string + format: date-time + description: Timestamp when the request was made. + example: "2026-05-05T12:34:56Z" + status_code: + type: integer + description: HTTP status code returned upstream. + example: 200 + duration_ms: + type: integer + description: Duration of the request in milliseconds. + example: 850 + user_id: + type: string + description: NetBird user id of the authenticated caller, if applicable. + source_ip: + type: string + description: Source IP of the request. Empty when log collection is disabled. + method: + type: string + description: HTTP method of the request. + example: "POST" + host: + type: string + description: Upstream host the request was routed to. Empty when log collection is disabled. + path: + type: string + description: Request path. Empty when log collection is disabled. + provider: + type: string + description: LLM provider vendor (e.g. openai, anthropic). + example: "openai" + model: + type: string + description: Requested LLM model. + example: "gpt-4o" + session_id: + type: string + description: Conversation / coding-session identifier that groups related requests. Sourced from the client's session marker (e.g. OpenAI Codex client_metadata.session_id, Claude Code metadata.user_id). Empty for clients that send none. + example: "019eeb72-ab7c-7cd2-aa05-6e8eb834afcb" + resolved_provider_id: + type: string + description: NetBird agent-network provider id that served the request. + selected_policy_id: + type: string + description: Agent-network policy id that authorised (or denied) the request. + decision: + type: string + description: Policy decision for the request (e.g. allow, deny). + example: "allow" + deny_reason: + type: string + description: Raw deny reason code when the request was blocked (e.g. llm_policy.token_cap_exceeded). + input_tokens: + type: integer + format: int64 + description: Input (prompt) tokens consumed. + example: 1200 + output_tokens: + type: integer + format: int64 + description: Output (completion) tokens produced. + example: 640 + total_tokens: + type: integer + format: int64 + description: Total tokens consumed. + example: 1840 + cost_usd: + type: number + format: double + description: Estimated USD cost of the request. + example: 0.0231 + stream: + type: boolean + description: Whether the request was a streaming completion. + group_ids: + type: array + items: + type: string + description: NetBird group ids that authorised the request (the caller's groups intersected with the policy's source groups). + request_prompt: + type: string + description: Captured request prompt. Present only when prompt collection is enabled. + response_completion: + type: string + description: Captured response completion. Present only when prompt collection is enabled. + required: + - id + - service_id + - timestamp + - status_code + - duration_ms + - input_tokens + - output_tokens + - total_tokens + - cost_usd + AgentNetworkAccessLogsResponse: + type: object + properties: + data: + type: array + description: List of agent-network access log entries. + items: + $ref: "#/components/schemas/AgentNetworkAccessLog" + page: + type: integer + description: Current page number. + example: 1 + page_size: + type: integer + description: Number of items per page. + example: 50 + total_records: + type: integer + description: Total number of log records matching the filter. + example: 523 + total_pages: + type: integer + description: Total number of pages available. + example: 11 + required: + - data + - page + - page_size + - total_records + - total_pages + AgentNetworkAccessLogSession: + type: object + description: A session-grouped view of agent-network access logs — all requests sharing a session id (or a single session-less request) folded into one summary plus its ordered entries. + properties: + session_id: + type: string + description: Conversation / coding-session identifier shared by the entries. Empty for a session-less (singleton) request grouped on its own id. + example: "019eeb72-ab7c-7cd2-aa05-6e8eb834afcb" + user_id: + type: string + description: NetBird user id of the session's caller. + group_ids: + type: array + items: + type: string + description: Union of the authorising group ids across the session's entries. + started_at: + type: string + format: date-time + description: Timestamp of the session's earliest request. + example: "2026-05-05T12:30:00Z" + ended_at: + type: string + format: date-time + description: Timestamp of the session's latest request. + example: "2026-05-05T12:34:56Z" + request_count: + type: integer + description: Number of requests in the session. + example: 7 + input_tokens: + type: integer + format: int64 + description: Total input (prompt) tokens across the session. + example: 8400 + output_tokens: + type: integer + format: int64 + description: Total output (completion) tokens across the session. + example: 4480 + total_tokens: + type: integer + format: int64 + description: Total tokens across the session. + example: 12880 + cost_usd: + type: number + format: double + description: Total estimated USD cost across the session. + example: 0.1617 + providers: + type: array + items: + type: string + description: Distinct LLM provider vendors seen in the session. + models: + type: array + items: + type: string + description: Distinct models seen in the session. + decision: + type: string + description: Session decision — "deny" if any request was denied, otherwise "allow". + example: "allow" + entries: + type: array + description: The session's access-log entries, oldest first. + items: + $ref: "#/components/schemas/AgentNetworkAccessLog" + required: + - started_at + - ended_at + - request_count + - input_tokens + - output_tokens + - total_tokens + - cost_usd + - decision + - entries + AgentNetworkAccessLogSessionsResponse: + type: object + properties: + data: + type: array + description: List of session-grouped agent-network access logs. + items: + $ref: "#/components/schemas/AgentNetworkAccessLogSession" + page: + type: integer + description: Current page number. + example: 1 + page_size: + type: integer + description: Number of sessions per page. + example: 50 + total_records: + type: integer + description: Total number of sessions matching the filter. + example: 124 + total_pages: + type: integer + description: Total number of pages available. + example: 3 + required: + - data + - page + - page_size + - total_records + - total_pages + AgentNetworkUsageBucket: + type: object + description: One aggregated agent-network usage time bucket (UTC). The bucket width is set by the request's granularity. + properties: + period_start: + type: string + description: Start of the bucket in YYYY-MM-DD (UTC) — the day, the week start (Monday), or the month start, depending on granularity. + example: "2026-05-05" + input_tokens: + type: integer + format: int64 + description: Total input (prompt) tokens in the bucket. + example: 120000 + output_tokens: + type: integer + format: int64 + description: Total output (completion) tokens in the bucket. + example: 64000 + total_tokens: + type: integer + format: int64 + description: Total tokens in the bucket. + example: 184000 + cost_usd: + type: number + format: double + description: Total estimated USD spend in the bucket. + example: 2.31 + required: + - period_start + - input_tokens + - output_tokens + - total_tokens + - cost_usd + AgentNetworkSettings: + type: object + description: Per-account Agent Network gateway settings. One row per account; cluster and subdomain are auto-assigned on first provider create and immutable thereafter. + properties: + cluster: + type: string + description: Address of the NetBird proxy cluster fronting this account's agent-network endpoint. + example: "eu.proxy.netbird.io" + subdomain: + type: string + description: Auto-generated DNS-safe label that prefixes the cluster to form the agent-network endpoint. + example: "violet" + endpoint: + type: string + description: Bare hostname agents call for this account, computed as `.`. + example: "violet.eu.proxy.netbird.io" + enable_log_collection: + type: boolean + description: Whether per-request access-log entries are collected for this account's agent-network traffic. + example: false + enable_prompt_collection: + type: boolean + description: Master switch for request/response prompt capture. Capture runs only when this is on AND a policy guardrail also enables it. + example: false + redact_pii: + type: boolean + description: Whether captured prompts have PII redacted. Effective redaction is the OR of this and any policy guardrail's redact setting. + example: false + access_log_retention_days: + type: integer + description: Days to retain full access-log rows; older rows are swept. 0 or less means keep indefinitely. Usage records are retained independently. + example: 30 + created_at: + type: string + format: date-time + description: Timestamp when the settings row was created. + readOnly: true + example: "2026-04-26T10:30:00Z" + updated_at: + type: string + format: date-time + description: Timestamp when the settings row was last updated. + readOnly: true + example: "2026-04-26T10:30:00Z" + required: + - cluster + - subdomain + - endpoint + - enable_log_collection + - enable_prompt_collection + - redact_pii + - created_at + - updated_at + AgentNetworkSettingsRequest: + type: object + description: Mutable account-level Agent Network settings. Cluster and subdomain are immutable and not accepted here. + properties: + enable_log_collection: + type: boolean + description: Whether per-request access-log entries are collected for this account's agent-network traffic. + example: true + enable_prompt_collection: + type: boolean + description: Master switch for request/response prompt capture. + example: true + redact_pii: + type: boolean + description: Whether captured prompts have PII redacted. + example: true + access_log_retention_days: + type: integer + description: Days to retain full access-log rows; older rows are swept. 0 or less means keep indefinitely. + example: 30 + required: + - enable_log_collection + - enable_prompt_collection + - redact_pii + AgentNetworkBudgetRule: + type: object + description: Account-level budget rule. A limit-only rule bound to groups and/or users that applies across all policies as a min-wins ceiling. Empty targets means it applies to every caller. + properties: + id: + type: string + description: Budget rule ID. + example: "ainbud_d1m3kebd9pcs0c1pnu7g" + name: + type: string + description: Display name for the budget rule. + example: "Org monthly ceiling" + enabled: + type: boolean + description: Whether the rule is enforced. + example: true + target_groups: + type: array + description: NetBird group ids the rule binds. Empty plus empty target_users means account-wide. + items: + type: string + example: ["ch8vp3o6lnna9hg0sd8g"] + target_users: + type: array + description: NetBird user ids the rule binds directly. + items: + type: string + example: [] + limits: + $ref: '#/components/schemas/AgentNetworkPolicyLimits' + created_at: + type: string + format: date-time + readOnly: true + example: "2026-04-26T10:30:00Z" + updated_at: + type: string + format: date-time + readOnly: true + example: "2026-04-26T10:30:00Z" + required: + - id + - name + - enabled + - target_groups + - target_users + - limits + - created_at + - updated_at + AgentNetworkBudgetRuleRequest: + type: object + properties: + name: + type: string + description: Display name for the budget rule. + example: "Org monthly ceiling" + enabled: + type: boolean + description: Whether the rule is enforced. Defaults to true on create. + example: true + target_groups: + type: array + description: NetBird group ids the rule binds. Empty plus empty target_users means account-wide. + items: + type: string + example: ["ch8vp3o6lnna9hg0sd8g"] + target_users: + type: array + description: NetBird user ids the rule binds directly. + items: + type: string + example: [] + limits: + $ref: '#/components/schemas/AgentNetworkPolicyLimits' + required: + - name + - limits responses: not_found: description: Resource not found @@ -12068,3 +13116,1016 @@ paths: "$ref": "#/components/responses/not_found" '500': "$ref": "#/components/responses/internal_error" + /api/agent-network/access-logs: + get: + summary: List Agent Network access logs + description: Returns a paginated, server-side-filtered list of agent-network (LLM) access log entries. Available only when the account has log collection enabled; otherwise entries are not retained. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: query + name: page + schema: + type: integer + default: 1 + minimum: 1 + description: Page number for pagination (1-indexed). + - in: query + name: page_size + schema: + type: integer + default: 50 + minimum: 1 + maximum: 100 + description: Number of items per page (max 100). + - in: query + name: sort_by + schema: + type: string + enum: [timestamp, model, provider, status_code, duration, cost_usd, total_tokens, user_id, decision] + default: timestamp + description: Field to sort by. + - in: query + name: sort_order + schema: + type: string + enum: [asc, desc] + default: desc + description: Sort order (ascending or descending). + - in: query + name: search + schema: + type: string + description: General search across log ID, host, path, model, and user email/name. + - in: query + name: user_id + schema: + type: string + description: Filter by authenticated user ID. + - in: query + name: session_id + schema: + type: string + description: Filter to a single conversation / coding session id (groups all requests of one session). + - in: query + name: group_id + schema: + type: array + items: + type: string + style: form + explode: true + description: Filter by authorising group id. Repeat for multiple (matches any). + - in: query + name: provider_id + schema: + type: array + items: + type: string + style: form + explode: true + description: Filter by resolved provider id. Repeat for multiple (matches any). + - in: query + name: model + schema: + type: array + items: + type: string + style: form + explode: true + description: Filter by model. Repeat for multiple (matches any). + - in: query + name: decision + schema: + type: string + description: Filter by policy decision (e.g. allow, deny). + - in: query + name: path + schema: + type: string + description: Filter by request path prefix (matches entries whose path starts with this value). + - in: query + name: start_date + schema: + type: string + format: date-time + description: Filter by timestamp >= start_date (RFC3339 format). + - in: query + name: end_date + schema: + type: string + format: date-time + description: Filter by timestamp <= end_date (RFC3339 format). + responses: + '200': + description: Paginated list of agent-network access logs + content: + application/json: + schema: + $ref: "#/components/schemas/AgentNetworkAccessLogsResponse" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/agent-network/access-log-sessions: + get: + summary: List Agent Network access logs grouped by session + description: Returns a paginated, server-side-filtered list of agent-network (LLM) access logs grouped by session. The page unit is a session (total_records counts sessions); each session carries an aggregate summary and its ordered entries. Requests the client sent no session id for each form their own singleton group. Accepts the same filters as the flat access-logs endpoint. Available only when the account has log collection enabled. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: query + name: page + schema: + type: integer + default: 1 + minimum: 1 + description: Page number for pagination (1-indexed). + - in: query + name: page_size + schema: + type: integer + default: 50 + minimum: 1 + maximum: 100 + description: Number of sessions per page (max 100). + - in: query + name: sort_by + schema: + type: string + enum: [timestamp, started_at, cost_usd, total_tokens, duration, request_count, status_code, user_id, decision] + default: timestamp + description: Session-level field to sort by. "timestamp" is the session's last activity, "started_at" its first. + - in: query + name: sort_order + schema: + type: string + enum: [asc, desc] + default: desc + description: Sort order (ascending or descending). + - in: query + name: search + schema: + type: string + description: General search across log ID, host, path, model, and user email/name. + - in: query + name: user_id + schema: + type: string + description: Filter by authenticated user ID. + - in: query + name: session_id + schema: + type: string + description: Filter to a single conversation / coding session id. + - in: query + name: group_id + schema: + type: array + items: + type: string + style: form + explode: true + description: Filter by authorising group id. Repeat for multiple (matches any). + - in: query + name: provider_id + schema: + type: array + items: + type: string + style: form + explode: true + description: Filter by resolved provider id. Repeat for multiple (matches any). + - in: query + name: model + schema: + type: array + items: + type: string + style: form + explode: true + description: Filter by model. Repeat for multiple (matches any). + - in: query + name: decision + schema: + type: string + description: Filter by policy decision (e.g. allow, deny). + - in: query + name: path + schema: + type: string + description: Filter by request path prefix (matches entries whose path starts with this value). + - in: query + name: start_date + schema: + type: string + format: date-time + description: Filter by timestamp >= start_date (RFC3339 format). + - in: query + name: end_date + schema: + type: string + format: date-time + description: Filter by timestamp <= end_date (RFC3339 format). + responses: + '200': + description: Paginated list of session-grouped agent-network access logs + content: + application/json: + schema: + $ref: "#/components/schemas/AgentNetworkAccessLogSessionsResponse" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/agent-network/usage/overview: + get: + summary: Agent Network usage overview + description: Returns agent-network token and cost usage aggregated into time buckets, server-side filtered. Usage is always collected (independent of log collection). + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: query + name: granularity + schema: + type: string + enum: [day, week, month] + default: day + description: Time bucket width. Defaults to day. + - in: query + name: start_date + schema: + type: string + format: date-time + description: Filter by timestamp >= start_date (RFC3339 format). + - in: query + name: end_date + schema: + type: string + format: date-time + description: Filter by timestamp <= end_date (RFC3339 format). + - in: query + name: user_id + schema: + type: string + description: Filter by user ID. + - in: query + name: session_id + schema: + type: string + description: Filter to a single conversation / coding session id. + - in: query + name: group_id + schema: + type: array + items: + type: string + style: form + explode: true + description: Filter by authorising group id. Repeat for multiple (matches any). + - in: query + name: provider_id + schema: + type: array + items: + type: string + style: form + explode: true + description: Filter by resolved provider id. Repeat for multiple (matches any). + - in: query + name: model + schema: + type: array + items: + type: string + style: form + explode: true + description: Filter by model. Repeat for multiple (matches any). + responses: + '200': + description: A JSON array of aggregated usage buckets, ordered oldest-first. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/AgentNetworkUsageBucket' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/agent-network/consumption: + get: + summary: List Agent Network consumption counters + description: Returns every per-(dimension, window) consumption counter recorded for the account, ordered window-newest-first. Empty list when nothing has been consumed yet. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of consumption counter rows + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/AgentNetworkConsumption' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/agent-network/settings: + get: + summary: Retrieve Agent Network settings + description: Returns the per-account Agent Network gateway settings (cluster, subdomain, endpoint). Returns 404 when no provider has been created yet — settings are lazily bootstrapped on first provider create. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: Agent Network settings for the account + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkSettings' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update Agent Network settings + description: Updates the mutable account-level Agent Network settings (collection toggles). Cluster and subdomain are immutable and ignored if sent. Returns 404 when settings have not been bootstrapped (no provider created yet). + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + description: Settings update request + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkSettingsRequest' + responses: + '200': + description: Updated Agent Network settings + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkSettings' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/agent-network/budget-rules: + get: + summary: List all Agent Network budget rules + description: Returns all account-level budget rules. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of Agent Network budget rules + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/AgentNetworkBudgetRule' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create an Agent Network budget rule + description: Creates a new account-level budget rule. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + description: New budget rule request + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkBudgetRuleRequest' + responses: + '200': + description: Budget rule created + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkBudgetRule' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/agent-network/budget-rules/{ruleId}: + get: + summary: Retrieve an Agent Network budget rule + description: Get a specific account-level budget rule. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: ruleId + required: true + schema: + type: string + description: The unique identifier of a budget rule + responses: + '200': + description: An Agent Network budget rule object + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkBudgetRule' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update an Agent Network budget rule + description: Updates an existing account-level budget rule. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: ruleId + required: true + schema: + type: string + description: The unique identifier of a budget rule + requestBody: + description: Budget rule update request + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkBudgetRuleRequest' + responses: + '200': + description: Budget rule updated + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkBudgetRule' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete an Agent Network budget rule + description: Deletes an account-level budget rule. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: ruleId + required: true + schema: + type: string + description: The unique identifier of a budget rule + responses: + '200': + description: Budget rule deleted + content: + application/json: + schema: + type: object + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/agent-network/catalog/providers: + get: + summary: List Agent Network catalog providers + description: Returns the static catalog of supported Agent Network providers (OpenAI, Anthropic, …) along with their default upstream host, auth header template, brand color, and known models. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of catalog providers + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/AgentNetworkCatalogProvider' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/agent-network/providers: + get: + summary: List all Agent Network Providers + description: Returns a list of all Agent Network AI providers configured for the account. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of Agent Network providers + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/AgentNetworkProvider' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create an Agent Network Provider + description: Connects a new Agent Network AI provider for the account. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + description: New provider request + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkProviderRequest' + responses: + '200': + description: Provider created + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkProvider' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '409': + "$ref": "#/components/responses/conflict" + '500': + "$ref": "#/components/responses/internal_error" + /api/agent-network/providers/{providerId}: + get: + summary: Retrieve an Agent Network Provider + description: Get information about a specific Agent Network AI provider. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: providerId + required: true + schema: + type: string + description: The unique identifier of an Agent Network provider + responses: + '200': + description: An Agent Network provider object + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkProvider' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update an Agent Network Provider + description: Update an existing Agent Network AI provider. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: providerId + required: true + schema: + type: string + description: The unique identifier of an Agent Network provider + requestBody: + description: Provider update request + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkProviderRequest' + responses: + '200': + description: Provider updated + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkProvider' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '409': + "$ref": "#/components/responses/conflict" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete an Agent Network Provider + description: Delete an existing Agent Network AI provider. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: providerId + required: true + schema: + type: string + description: The unique identifier of an Agent Network provider + responses: + '200': + description: Provider deleted + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/agent-network/policies: + get: + summary: List all Agent Network Policies + description: Returns a list of all Agent Network policies for the account. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of Agent Network policies + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/AgentNetworkPolicy' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create an Agent Network Policy + description: Creates a new Agent Network policy binding source groups to destination providers, optionally enforced by guardrails. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + description: New policy request + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkPolicyRequest' + responses: + '200': + description: Policy created + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkPolicy' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '409': + "$ref": "#/components/responses/conflict" + '500': + "$ref": "#/components/responses/internal_error" + /api/agent-network/policies/{policyId}: + get: + summary: Retrieve an Agent Network Policy + description: Get information about a specific Agent Network policy. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: policyId + required: true + schema: + type: string + description: The unique identifier of an Agent Network policy + responses: + '200': + description: An Agent Network policy object + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkPolicy' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update an Agent Network Policy + description: Update an existing Agent Network policy. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: policyId + required: true + schema: + type: string + description: The unique identifier of an Agent Network policy + requestBody: + description: Policy update request + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkPolicyRequest' + responses: + '200': + description: Policy updated + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkPolicy' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '409': + "$ref": "#/components/responses/conflict" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete an Agent Network Policy + description: Delete an existing Agent Network policy. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: policyId + required: true + schema: + type: string + description: The unique identifier of an Agent Network policy + responses: + '200': + description: Policy deleted + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/agent-network/guardrails: + get: + summary: List all Agent Network Guardrails + description: Returns a list of all Agent Network guardrails for the account. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of Agent Network guardrails + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/AgentNetworkGuardrail' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create an Agent Network Guardrail + description: Creates a new Agent Network guardrail that can be attached to one or more policies. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + description: New guardrail request + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkGuardrailRequest' + responses: + '200': + description: Guardrail created + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkGuardrail' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '409': + "$ref": "#/components/responses/conflict" + '500': + "$ref": "#/components/responses/internal_error" + /api/agent-network/guardrails/{guardrailId}: + get: + summary: Retrieve an Agent Network Guardrail + description: Get information about a specific Agent Network guardrail. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: guardrailId + required: true + schema: + type: string + description: The unique identifier of an Agent Network guardrail + responses: + '200': + description: An Agent Network guardrail object + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkGuardrail' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update an Agent Network Guardrail + description: Update an existing Agent Network guardrail. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: guardrailId + required: true + schema: + type: string + description: The unique identifier of an Agent Network guardrail + requestBody: + description: Guardrail update request + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkGuardrailRequest' + responses: + '200': + description: Guardrail updated + content: + application/json: + schema: + $ref: '#/components/schemas/AgentNetworkGuardrail' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '409': + "$ref": "#/components/responses/conflict" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete an Agent Network Guardrail + description: Delete an existing Agent Network guardrail. + tags: [ Agent Network ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: guardrailId + required: true + schema: + type: string + description: The unique identifier of an Agent Network guardrail + responses: + '200': + description: Guardrail deleted + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index ed5060a86..7ea9514c0 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -38,6 +38,45 @@ func (e AccessRestrictionsCrowdsecMode) Valid() bool { } } +// Defines values for AgentNetworkCatalogProviderKind. +const ( + AgentNetworkCatalogProviderKindCustom AgentNetworkCatalogProviderKind = "custom" + AgentNetworkCatalogProviderKindGateway AgentNetworkCatalogProviderKind = "gateway" + AgentNetworkCatalogProviderKindProvider AgentNetworkCatalogProviderKind = "provider" +) + +// Valid indicates whether the value is a known member of the AgentNetworkCatalogProviderKind enum. +func (e AgentNetworkCatalogProviderKind) Valid() bool { + switch e { + case AgentNetworkCatalogProviderKindCustom: + return true + case AgentNetworkCatalogProviderKindGateway: + return true + case AgentNetworkCatalogProviderKindProvider: + return true + default: + return false + } +} + +// Defines values for AgentNetworkConsumptionDimensionKind. +const ( + AgentNetworkConsumptionDimensionKindGroup AgentNetworkConsumptionDimensionKind = "group" + AgentNetworkConsumptionDimensionKindUser AgentNetworkConsumptionDimensionKind = "user" +) + +// Valid indicates whether the value is a known member of the AgentNetworkConsumptionDimensionKind enum. +func (e AgentNetworkConsumptionDimensionKind) Valid() bool { + switch e { + case AgentNetworkConsumptionDimensionKindGroup: + return true + case AgentNetworkConsumptionDimensionKindUser: + return true + default: + return false + } +} + // Defines values for CreateAzureIntegrationRequestHost. const ( CreateAzureIntegrationRequestHostMicrosoftCom CreateAzureIntegrationRequestHost = "microsoft.com" @@ -1163,6 +1202,141 @@ func (e WorkloadType) Valid() bool { } } +// Defines values for GetApiAgentNetworkAccessLogSessionsParamsSortBy. +const ( + GetApiAgentNetworkAccessLogSessionsParamsSortByCostUsd GetApiAgentNetworkAccessLogSessionsParamsSortBy = "cost_usd" + GetApiAgentNetworkAccessLogSessionsParamsSortByDecision GetApiAgentNetworkAccessLogSessionsParamsSortBy = "decision" + GetApiAgentNetworkAccessLogSessionsParamsSortByDuration GetApiAgentNetworkAccessLogSessionsParamsSortBy = "duration" + GetApiAgentNetworkAccessLogSessionsParamsSortByRequestCount GetApiAgentNetworkAccessLogSessionsParamsSortBy = "request_count" + GetApiAgentNetworkAccessLogSessionsParamsSortByStartedAt GetApiAgentNetworkAccessLogSessionsParamsSortBy = "started_at" + GetApiAgentNetworkAccessLogSessionsParamsSortByStatusCode GetApiAgentNetworkAccessLogSessionsParamsSortBy = "status_code" + GetApiAgentNetworkAccessLogSessionsParamsSortByTimestamp GetApiAgentNetworkAccessLogSessionsParamsSortBy = "timestamp" + GetApiAgentNetworkAccessLogSessionsParamsSortByTotalTokens GetApiAgentNetworkAccessLogSessionsParamsSortBy = "total_tokens" + GetApiAgentNetworkAccessLogSessionsParamsSortByUserId GetApiAgentNetworkAccessLogSessionsParamsSortBy = "user_id" +) + +// Valid indicates whether the value is a known member of the GetApiAgentNetworkAccessLogSessionsParamsSortBy enum. +func (e GetApiAgentNetworkAccessLogSessionsParamsSortBy) Valid() bool { + switch e { + case GetApiAgentNetworkAccessLogSessionsParamsSortByCostUsd: + return true + case GetApiAgentNetworkAccessLogSessionsParamsSortByDecision: + return true + case GetApiAgentNetworkAccessLogSessionsParamsSortByDuration: + return true + case GetApiAgentNetworkAccessLogSessionsParamsSortByRequestCount: + return true + case GetApiAgentNetworkAccessLogSessionsParamsSortByStartedAt: + return true + case GetApiAgentNetworkAccessLogSessionsParamsSortByStatusCode: + return true + case GetApiAgentNetworkAccessLogSessionsParamsSortByTimestamp: + return true + case GetApiAgentNetworkAccessLogSessionsParamsSortByTotalTokens: + return true + case GetApiAgentNetworkAccessLogSessionsParamsSortByUserId: + return true + default: + return false + } +} + +// Defines values for GetApiAgentNetworkAccessLogSessionsParamsSortOrder. +const ( + GetApiAgentNetworkAccessLogSessionsParamsSortOrderAsc GetApiAgentNetworkAccessLogSessionsParamsSortOrder = "asc" + GetApiAgentNetworkAccessLogSessionsParamsSortOrderDesc GetApiAgentNetworkAccessLogSessionsParamsSortOrder = "desc" +) + +// Valid indicates whether the value is a known member of the GetApiAgentNetworkAccessLogSessionsParamsSortOrder enum. +func (e GetApiAgentNetworkAccessLogSessionsParamsSortOrder) Valid() bool { + switch e { + case GetApiAgentNetworkAccessLogSessionsParamsSortOrderAsc: + return true + case GetApiAgentNetworkAccessLogSessionsParamsSortOrderDesc: + return true + default: + return false + } +} + +// Defines values for GetApiAgentNetworkAccessLogsParamsSortBy. +const ( + GetApiAgentNetworkAccessLogsParamsSortByCostUsd GetApiAgentNetworkAccessLogsParamsSortBy = "cost_usd" + GetApiAgentNetworkAccessLogsParamsSortByDecision GetApiAgentNetworkAccessLogsParamsSortBy = "decision" + GetApiAgentNetworkAccessLogsParamsSortByDuration GetApiAgentNetworkAccessLogsParamsSortBy = "duration" + GetApiAgentNetworkAccessLogsParamsSortByModel GetApiAgentNetworkAccessLogsParamsSortBy = "model" + GetApiAgentNetworkAccessLogsParamsSortByProvider GetApiAgentNetworkAccessLogsParamsSortBy = "provider" + GetApiAgentNetworkAccessLogsParamsSortByStatusCode GetApiAgentNetworkAccessLogsParamsSortBy = "status_code" + GetApiAgentNetworkAccessLogsParamsSortByTimestamp GetApiAgentNetworkAccessLogsParamsSortBy = "timestamp" + GetApiAgentNetworkAccessLogsParamsSortByTotalTokens GetApiAgentNetworkAccessLogsParamsSortBy = "total_tokens" + GetApiAgentNetworkAccessLogsParamsSortByUserId GetApiAgentNetworkAccessLogsParamsSortBy = "user_id" +) + +// Valid indicates whether the value is a known member of the GetApiAgentNetworkAccessLogsParamsSortBy enum. +func (e GetApiAgentNetworkAccessLogsParamsSortBy) Valid() bool { + switch e { + case GetApiAgentNetworkAccessLogsParamsSortByCostUsd: + return true + case GetApiAgentNetworkAccessLogsParamsSortByDecision: + return true + case GetApiAgentNetworkAccessLogsParamsSortByDuration: + return true + case GetApiAgentNetworkAccessLogsParamsSortByModel: + return true + case GetApiAgentNetworkAccessLogsParamsSortByProvider: + return true + case GetApiAgentNetworkAccessLogsParamsSortByStatusCode: + return true + case GetApiAgentNetworkAccessLogsParamsSortByTimestamp: + return true + case GetApiAgentNetworkAccessLogsParamsSortByTotalTokens: + return true + case GetApiAgentNetworkAccessLogsParamsSortByUserId: + return true + default: + return false + } +} + +// Defines values for GetApiAgentNetworkAccessLogsParamsSortOrder. +const ( + GetApiAgentNetworkAccessLogsParamsSortOrderAsc GetApiAgentNetworkAccessLogsParamsSortOrder = "asc" + GetApiAgentNetworkAccessLogsParamsSortOrderDesc GetApiAgentNetworkAccessLogsParamsSortOrder = "desc" +) + +// Valid indicates whether the value is a known member of the GetApiAgentNetworkAccessLogsParamsSortOrder enum. +func (e GetApiAgentNetworkAccessLogsParamsSortOrder) Valid() bool { + switch e { + case GetApiAgentNetworkAccessLogsParamsSortOrderAsc: + return true + case GetApiAgentNetworkAccessLogsParamsSortOrderDesc: + return true + default: + return false + } +} + +// Defines values for GetApiAgentNetworkUsageOverviewParamsGranularity. +const ( + GetApiAgentNetworkUsageOverviewParamsGranularityDay GetApiAgentNetworkUsageOverviewParamsGranularity = "day" + GetApiAgentNetworkUsageOverviewParamsGranularityMonth GetApiAgentNetworkUsageOverviewParamsGranularity = "month" + GetApiAgentNetworkUsageOverviewParamsGranularityWeek GetApiAgentNetworkUsageOverviewParamsGranularity = "week" +) + +// Valid indicates whether the value is a known member of the GetApiAgentNetworkUsageOverviewParamsGranularity enum. +func (e GetApiAgentNetworkUsageOverviewParamsGranularity) Valid() bool { + switch e { + case GetApiAgentNetworkUsageOverviewParamsGranularityDay: + return true + case GetApiAgentNetworkUsageOverviewParamsGranularityMonth: + return true + case GetApiAgentNetworkUsageOverviewParamsGranularityWeek: + return true + default: + return false + } +} + // Defines values for GetApiEventsNetworkTrafficParamsType. const ( GetApiEventsNetworkTrafficParamsTypeTYPEDROP GetApiEventsNetworkTrafficParamsType = "TYPE_DROP" @@ -1541,6 +1715,627 @@ type AccountSettings struct { RoutingPeerDnsResolutionEnabled *bool `json:"routing_peer_dns_resolution_enabled,omitempty"` } +// AgentNetworkAccessLog One per-request agent-network (LLM) access log entry with flattened, queryable LLM dimensions. +type AgentNetworkAccessLog struct { + // CostUsd Estimated USD cost of the request. + CostUsd float64 `json:"cost_usd"` + + // Decision Policy decision for the request (e.g. allow, deny). + Decision *string `json:"decision,omitempty"` + + // DenyReason Raw deny reason code when the request was blocked (e.g. llm_policy.token_cap_exceeded). + DenyReason *string `json:"deny_reason,omitempty"` + + // DurationMs Duration of the request in milliseconds. + DurationMs int `json:"duration_ms"` + + // GroupIds NetBird group ids that authorised the request (the caller's groups intersected with the policy's source groups). + GroupIds *[]string `json:"group_ids,omitempty"` + + // Host Upstream host the request was routed to. Empty when log collection is disabled. + Host *string `json:"host,omitempty"` + + // Id Unique identifier for the access log entry. + Id string `json:"id"` + + // InputTokens Input (prompt) tokens consumed. + InputTokens int64 `json:"input_tokens"` + + // Method HTTP method of the request. + Method *string `json:"method,omitempty"` + + // Model Requested LLM model. + Model *string `json:"model,omitempty"` + + // OutputTokens Output (completion) tokens produced. + OutputTokens int64 `json:"output_tokens"` + + // Path Request path. Empty when log collection is disabled. + Path *string `json:"path,omitempty"` + + // Provider LLM provider vendor (e.g. openai, anthropic). + Provider *string `json:"provider,omitempty"` + + // RequestPrompt Captured request prompt. Present only when prompt collection is enabled. + RequestPrompt *string `json:"request_prompt,omitempty"` + + // ResolvedProviderId NetBird agent-network provider id that served the request. + ResolvedProviderId *string `json:"resolved_provider_id,omitempty"` + + // ResponseCompletion Captured response completion. Present only when prompt collection is enabled. + ResponseCompletion *string `json:"response_completion,omitempty"` + + // SelectedPolicyId Agent-network policy id that authorised (or denied) the request. + SelectedPolicyId *string `json:"selected_policy_id,omitempty"` + + // ServiceId ID of the synthesised agent-network service that handled the request. + ServiceId string `json:"service_id"` + + // SessionId Conversation / coding-session identifier that groups related requests. Sourced from the client's session marker (e.g. OpenAI Codex client_metadata.session_id, Claude Code metadata.user_id). Empty for clients that send none. + SessionId *string `json:"session_id,omitempty"` + + // SourceIp Source IP of the request. Empty when log collection is disabled. + SourceIp *string `json:"source_ip,omitempty"` + + // StatusCode HTTP status code returned upstream. + StatusCode int `json:"status_code"` + + // Stream Whether the request was a streaming completion. + Stream *bool `json:"stream,omitempty"` + + // Timestamp Timestamp when the request was made. + Timestamp time.Time `json:"timestamp"` + + // TotalTokens Total tokens consumed. + TotalTokens int64 `json:"total_tokens"` + + // UserId NetBird user id of the authenticated caller, if applicable. + UserId *string `json:"user_id,omitempty"` +} + +// AgentNetworkAccessLogSession A session-grouped view of agent-network access logs — all requests sharing a session id (or a single session-less request) folded into one summary plus its ordered entries. +type AgentNetworkAccessLogSession struct { + // CostUsd Total estimated USD cost across the session. + CostUsd float64 `json:"cost_usd"` + + // Decision Session decision — "deny" if any request was denied, otherwise "allow". + Decision string `json:"decision"` + + // EndedAt Timestamp of the session's latest request. + EndedAt time.Time `json:"ended_at"` + + // Entries The session's access-log entries, oldest first. + Entries []AgentNetworkAccessLog `json:"entries"` + + // GroupIds Union of the authorising group ids across the session's entries. + GroupIds *[]string `json:"group_ids,omitempty"` + + // InputTokens Total input (prompt) tokens across the session. + InputTokens int64 `json:"input_tokens"` + + // Models Distinct models seen in the session. + Models *[]string `json:"models,omitempty"` + + // OutputTokens Total output (completion) tokens across the session. + OutputTokens int64 `json:"output_tokens"` + + // Providers Distinct LLM provider vendors seen in the session. + Providers *[]string `json:"providers,omitempty"` + + // RequestCount Number of requests in the session. + RequestCount int `json:"request_count"` + + // SessionId Conversation / coding-session identifier shared by the entries. Empty for a session-less (singleton) request grouped on its own id. + SessionId *string `json:"session_id,omitempty"` + + // StartedAt Timestamp of the session's earliest request. + StartedAt time.Time `json:"started_at"` + + // TotalTokens Total tokens across the session. + TotalTokens int64 `json:"total_tokens"` + + // UserId NetBird user id of the session's caller. + UserId *string `json:"user_id,omitempty"` +} + +// AgentNetworkAccessLogSessionsResponse defines model for AgentNetworkAccessLogSessionsResponse. +type AgentNetworkAccessLogSessionsResponse struct { + // Data List of session-grouped agent-network access logs. + Data []AgentNetworkAccessLogSession `json:"data"` + + // Page Current page number. + Page int `json:"page"` + + // PageSize Number of sessions per page. + PageSize int `json:"page_size"` + + // TotalPages Total number of pages available. + TotalPages int `json:"total_pages"` + + // TotalRecords Total number of sessions matching the filter. + TotalRecords int `json:"total_records"` +} + +// AgentNetworkAccessLogsResponse defines model for AgentNetworkAccessLogsResponse. +type AgentNetworkAccessLogsResponse struct { + // Data List of agent-network access log entries. + Data []AgentNetworkAccessLog `json:"data"` + + // Page Current page number. + Page int `json:"page"` + + // PageSize Number of items per page. + PageSize int `json:"page_size"` + + // TotalPages Total number of pages available. + TotalPages int `json:"total_pages"` + + // TotalRecords Total number of log records matching the filter. + TotalRecords int `json:"total_records"` +} + +// AgentNetworkBudgetRule Account-level budget rule. A limit-only rule bound to groups and/or users that applies across all policies as a min-wins ceiling. Empty targets means it applies to every caller. +type AgentNetworkBudgetRule struct { + CreatedAt *time.Time `json:"created_at,omitempty"` + + // Enabled Whether the rule is enforced. + Enabled bool `json:"enabled"` + + // Id Budget rule ID. + Id string `json:"id"` + + // Limits Token and budget caps attached directly to the policy. These compose with any guardrail-level checks. + Limits AgentNetworkPolicyLimits `json:"limits"` + + // Name Display name for the budget rule. + Name string `json:"name"` + + // TargetGroups NetBird group ids the rule binds. Empty plus empty target_users means account-wide. + TargetGroups []string `json:"target_groups"` + + // TargetUsers NetBird user ids the rule binds directly. + TargetUsers []string `json:"target_users"` + UpdatedAt *time.Time `json:"updated_at,omitempty"` +} + +// AgentNetworkBudgetRuleRequest defines model for AgentNetworkBudgetRuleRequest. +type AgentNetworkBudgetRuleRequest struct { + // Enabled Whether the rule is enforced. Defaults to true on create. + Enabled *bool `json:"enabled,omitempty"` + + // Limits Token and budget caps attached directly to the policy. These compose with any guardrail-level checks. + Limits AgentNetworkPolicyLimits `json:"limits"` + + // Name Display name for the budget rule. + Name string `json:"name"` + + // TargetGroups NetBird group ids the rule binds. Empty plus empty target_users means account-wide. + TargetGroups *[]string `json:"target_groups,omitempty"` + + // TargetUsers NetBird user ids the rule binds directly. + TargetUsers *[]string `json:"target_users,omitempty"` +} + +// AgentNetworkCatalogExtraHeader One optional per-provider routing/config header surfaced on the dashboard. Operator-typed value lives on the provider record's `extra_values` map keyed by `name`. UI copy (input label, helper line, tooltip) is owned by the dashboard, keyed by `name`. +type AgentNetworkCatalogExtraHeader struct { + // Name Wire header name the proxy stamps with the operator-typed value. + Name string `json:"name"` +} + +// AgentNetworkCatalogHeaderPairInjection HeaderPair identity-injection shape — separate per-dimension headers (LiteLLM-style, Bifrost). +type AgentNetworkCatalogHeaderPairInjection struct { + // Customizable When true, the wire header names are operator-overridable per provider record (Bifrost). When false, the catalog values are authoritative (LiteLLM and similar gateways with a fixed wire protocol). + Customizable bool `json:"customizable"` + + // EndUserIdHeader Wire header name for the caller's display identity. Default placeholder when `customizable` is true. + EndUserIdHeader string `json:"end_user_id_header"` + + // TagsHeader Wire header name for the caller's groups CSV. Default placeholder when `customizable` is true. + TagsHeader string `json:"tags_header"` +} + +// AgentNetworkCatalogIdentityInjection Catalog-declared identity-injection shape. Present when this provider supports stamping the caller's NetBird identity onto upstream requests. Exactly one of `header_pair` or `json_metadata` is set per provider entry. The dashboard reads the `customizable` flag on whichever shape is present to decide whether to surface the labels as editable inputs (true → editable with the catalog values shown as placeholders; false → fixed and read-only). +type AgentNetworkCatalogIdentityInjection struct { + // HeaderPair HeaderPair identity-injection shape — separate per-dimension headers (LiteLLM-style, Bifrost). + HeaderPair *AgentNetworkCatalogHeaderPairInjection `json:"header_pair,omitempty"` + + // JsonMetadata JSONMetadata identity-injection shape — one wire header carrying a JSON object whose keys label each dimension (Portkey-style, Cloudflare AI Gateway). + JsonMetadata *AgentNetworkCatalogJSONMetadataInjection `json:"json_metadata,omitempty"` +} + +// AgentNetworkCatalogJSONMetadataInjection JSONMetadata identity-injection shape — one wire header carrying a JSON object whose keys label each dimension (Portkey-style, Cloudflare AI Gateway). +type AgentNetworkCatalogJSONMetadataInjection struct { + // Customizable When true, the JSON keys are operator-overridable per provider record (Cloudflare). The wire header itself stays catalog-owned. When false, the catalog values are authoritative (Portkey and similar gateways with a fixed JSON schema). + Customizable bool `json:"customizable"` + + // GroupsKey JSON key for the caller's groups CSV. Default placeholder when `customizable` is true. + GroupsKey string `json:"groups_key"` + + // Header Wire header name carrying the JSON metadata payload. Catalog-owned (not customizable per provider record). + Header string `json:"header"` + + // UserKey JSON key for the caller's display identity. Default placeholder when `customizable` is true. + UserKey string `json:"user_key"` +} + +// AgentNetworkCatalogModel defines model for AgentNetworkCatalogModel. +type AgentNetworkCatalogModel struct { + // ContextWindow Maximum context window in tokens. + ContextWindow int `json:"context_window"` + + // Id Catalog model identifier as exposed by the upstream provider. + Id string `json:"id"` + + // InputPer1k Input token price per 1k tokens, in USD. + InputPer1k float64 `json:"input_per_1k"` + + // Label Human-friendly model name for the dashboard. + Label string `json:"label"` + + // OutputPer1k Output token price per 1k tokens, in USD. + OutputPer1k float64 `json:"output_per_1k"` +} + +// AgentNetworkCatalogProvider defines model for AgentNetworkCatalogProvider. +type AgentNetworkCatalogProvider struct { + // AuthHeaderTemplate Template the proxy uses to inject the API key (the literal string ${API_KEY} is replaced at request time). + AuthHeaderTemplate string `json:"auth_header_template"` + + // BrandColor Hex brand color used to render the provider badge in the dashboard. + BrandColor string `json:"brand_color"` + + // DefaultContentType Default Content-Type for upstream requests. + DefaultContentType string `json:"default_content_type"` + + // DefaultHost Default upstream host suggested when adding a provider of this type. + DefaultHost string `json:"default_host"` + + // Description Short description shown in the provider picker. + Description string `json:"description"` + + // ExtraHeaders Catalog-declared list of optional per-provider routing/config headers the proxy stamps on every upstream request. Each entry surfaces an input on the dashboard's provider modal (one per item, labeled with `label`). Operators fill any subset; values land on the provider record's `extra_values` map keyed by `name`. Used by gateways like Portkey for `x-portkey-config: pc-...` (saved-config id resolving upstream provider + virtual key). + ExtraHeaders *[]AgentNetworkCatalogExtraHeader `json:"extra_headers,omitempty"` + + // Id Catalog provider identifier (referenced by AgentNetworkProvider.provider_id). + Id string `json:"id"` + + // IdentityInjection Catalog-declared identity-injection shape. Present when this provider supports stamping the caller's NetBird identity onto upstream requests. Exactly one of `header_pair` or `json_metadata` is set per provider entry. The dashboard reads the `customizable` flag on whichever shape is present to decide whether to surface the labels as editable inputs (true → editable with the catalog values shown as placeholders; false → fixed and read-only). + IdentityInjection *AgentNetworkCatalogIdentityInjection `json:"identity_injection,omitempty"` + + // Kind Presentation grouping for the provider Select on the dashboard. + // "provider" — first-party vendor API (OpenAI, Anthropic, …); the upstream is the model itself. + // "gateway" — routing/aggregation layer in front of multiple providers (LiteLLM, Portkey, …); typically pairs with NetBird identity stamping. + // "custom" — generic OpenAI-compatible self-hosted endpoint catch-all. + Kind AgentNetworkCatalogProviderKind `json:"kind"` + + // Models Catalog models available for this provider. + Models []AgentNetworkCatalogModel `json:"models"` + + // Name Display name for the provider. + Name string `json:"name"` +} + +// AgentNetworkCatalogProviderKind Presentation grouping for the provider Select on the dashboard. +// "provider" — first-party vendor API (OpenAI, Anthropic, …); the upstream is the model itself. +// "gateway" — routing/aggregation layer in front of multiple providers (LiteLLM, Portkey, …); typically pairs with NetBird identity stamping. +// "custom" — generic OpenAI-compatible self-hosted endpoint catch-all. +type AgentNetworkCatalogProviderKind string + +// AgentNetworkConsumption One per-(dimension, window) consumption counter row. The proxy ticks one row per dimension on every served LLM request; the dashboard reads this listing to surface live counter growth. +type AgentNetworkConsumption struct { + // CostUsd Total USD spend booked against this dimension for the window. + CostUsd float64 `json:"cost_usd"` + + // DimensionId NetBird user id (when `dimension_kind=user`) or NetBird group id (when `dimension_kind=group`). + DimensionId string `json:"dimension_id"` + + // DimensionKind Whether this row counts a single end user or a single source group across every member. + DimensionKind AgentNetworkConsumptionDimensionKind `json:"dimension_kind"` + + // TokensInput Total input tokens consumed within the window. + TokensInput int64 `json:"tokens_input"` + + // TokensOutput Total output tokens consumed within the window. + TokensOutput int64 `json:"tokens_output"` + + // UpdatedAt Timestamp of the last increment recorded for this row. + UpdatedAt *time.Time `json:"updated_at,omitempty"` + + // WindowSeconds Length of the aligned window this counter covers, in seconds. Distinct window lengths produce independent counters even on the same dimension. + WindowSeconds int64 `json:"window_seconds"` + + // WindowStartUtc UTC start of the aligned window this counter covers. Aligned to the unix epoch so every node computes the same boundary. + WindowStartUtc time.Time `json:"window_start_utc"` +} + +// AgentNetworkConsumptionDimensionKind Whether this row counts a single end user or a single source group across every member. +type AgentNetworkConsumptionDimensionKind string + +// AgentNetworkGuardrail defines model for AgentNetworkGuardrail. +type AgentNetworkGuardrail struct { + // Checks Guardrail check parameters. Each entry has an `enabled` flag plus per-check configuration; disabled entries are inert. + Checks AgentNetworkGuardrailChecks `json:"checks"` + + // CreatedAt Timestamp when the guardrail was created. + CreatedAt *time.Time `json:"created_at,omitempty"` + + // Description Optional human-readable description. + Description string `json:"description"` + + // Id Guardrail ID + Id string `json:"id"` + + // Name Display name for the guardrail. + Name string `json:"name"` + + // UpdatedAt Timestamp when the guardrail was last updated. + UpdatedAt *time.Time `json:"updated_at,omitempty"` +} + +// AgentNetworkGuardrailChecks Guardrail check parameters. Each entry has an `enabled` flag plus per-check configuration; disabled entries are inert. +type AgentNetworkGuardrailChecks struct { + ModelAllowlist struct { + Enabled bool `json:"enabled"` + + // Models Allowed catalog model ids. Requests for any other model are denied. + Models []string `json:"models"` + } `json:"model_allowlist"` + PromptCapture struct { + Enabled bool `json:"enabled"` + RedactPii bool `json:"redact_pii"` + } `json:"prompt_capture"` +} + +// AgentNetworkGuardrailRequest defines model for AgentNetworkGuardrailRequest. +type AgentNetworkGuardrailRequest struct { + // Checks Guardrail check parameters. Each entry has an `enabled` flag plus per-check configuration; disabled entries are inert. + Checks AgentNetworkGuardrailChecks `json:"checks"` + + // Description Optional human-readable description. + Description *string `json:"description,omitempty"` + + // Name Display name for the guardrail. + Name string `json:"name"` +} + +// AgentNetworkPolicy defines model for AgentNetworkPolicy. +type AgentNetworkPolicy struct { + // CreatedAt Timestamp when the policy was created. + CreatedAt *time.Time `json:"created_at,omitempty"` + + // Description Optional human-readable description. + Description string `json:"description"` + + // DestinationProviderIds Agent Network provider ids (returned by the providers API) the source groups can reach. + DestinationProviderIds []string `json:"destination_provider_ids"` + + // Enabled Whether the policy is enabled. + Enabled bool `json:"enabled"` + + // GuardrailIds Agent Network guardrail ids attached to this policy. + GuardrailIds []string `json:"guardrail_ids"` + + // Id Policy ID + Id string `json:"id"` + + // Limits Token and budget caps attached directly to the policy. These compose with any guardrail-level checks. + Limits AgentNetworkPolicyLimits `json:"limits"` + + // Name Display name for the policy. + Name string `json:"name"` + + // SourceGroups NetBird group ids whose members are allowed to call the destination providers. + SourceGroups []string `json:"source_groups"` + + // UpdatedAt Timestamp when the policy was last updated. + UpdatedAt *time.Time `json:"updated_at,omitempty"` +} + +// AgentNetworkPolicyBudgetLimit Per-policy USD spend cap. `group_cap_usd` is applied to each source group independently — every group in the policy's `source_groups` gets its own bucket of this size. `user_cap_usd` is applied independently to each individual user. Caps reset to zero at the start of each window. +type AgentNetworkPolicyBudgetLimit struct { + Enabled bool `json:"enabled"` + + // GroupCapUsd USD allowed per source group within the window (each group has its own bucket of this size). 0 means uncapped. + GroupCapUsd float64 `json:"group_cap_usd"` + + // UserCapUsd USD allowed per individual user within the window. 0 means uncapped. + UserCapUsd float64 `json:"user_cap_usd"` + + // WindowSeconds Reset frequency in seconds. Caps reset at the start of each window. Minimum 60 (one minute) when the limit is enabled. + WindowSeconds int64 `json:"window_seconds"` +} + +// AgentNetworkPolicyLimits Token and budget caps attached directly to the policy. These compose with any guardrail-level checks. +type AgentNetworkPolicyLimits struct { + // BudgetLimit Per-policy USD spend cap. `group_cap_usd` is applied to each source group independently — every group in the policy's `source_groups` gets its own bucket of this size. `user_cap_usd` is applied independently to each individual user. Caps reset to zero at the start of each window. + BudgetLimit AgentNetworkPolicyBudgetLimit `json:"budget_limit"` + + // TokenLimit Per-policy token cap. `group_cap` is applied to each source group independently — every group in the policy's `source_groups` gets its own bucket of this size. `user_cap` is applied independently to each individual user. Caps reset to zero at the start of each window. + TokenLimit AgentNetworkPolicyTokenLimit `json:"token_limit"` +} + +// AgentNetworkPolicyRequest defines model for AgentNetworkPolicyRequest. +type AgentNetworkPolicyRequest struct { + // Description Optional human-readable description. + Description *string `json:"description,omitempty"` + + // DestinationProviderIds Agent Network provider ids the source groups can reach. + DestinationProviderIds []string `json:"destination_provider_ids"` + + // Enabled Whether the policy is enabled. Defaults to true on create. + Enabled *bool `json:"enabled,omitempty"` + + // GuardrailIds Agent Network guardrail ids to attach to this policy. + GuardrailIds *[]string `json:"guardrail_ids,omitempty"` + + // Limits Token and budget caps attached directly to the policy. These compose with any guardrail-level checks. + Limits *AgentNetworkPolicyLimits `json:"limits,omitempty"` + + // Name Display name for the policy. + Name string `json:"name"` + + // SourceGroups NetBird group ids whose members are allowed to call the destination providers. + SourceGroups []string `json:"source_groups"` +} + +// AgentNetworkPolicyTokenLimit Per-policy token cap. `group_cap` is applied to each source group independently — every group in the policy's `source_groups` gets its own bucket of this size. `user_cap` is applied independently to each individual user. Caps reset to zero at the start of each window. +type AgentNetworkPolicyTokenLimit struct { + Enabled bool `json:"enabled"` + + // GroupCap Tokens allowed per source group within the window (each group has its own bucket of this size). 0 means uncapped. + GroupCap int64 `json:"group_cap"` + + // UserCap Tokens allowed per individual user within the window. 0 means uncapped. + UserCap int64 `json:"user_cap"` + + // WindowSeconds Reset frequency in seconds. The cap counter resets to zero at the start of each window. Minimum 60 (one minute) when the limit is enabled. + WindowSeconds int64 `json:"window_seconds"` +} + +// AgentNetworkProvider defines model for AgentNetworkProvider. +type AgentNetworkProvider struct { + // CreatedAt Timestamp when the provider was created. + CreatedAt *time.Time `json:"created_at,omitempty"` + + // Enabled Whether the provider is enabled. + Enabled bool `json:"enabled"` + + // ExtraValues Operator-typed values for catalog-declared extra headers. Keys are wire header names (e.g. `x-portkey-config`); values are the strings the proxy stamps on every upstream request to this provider. Catalog (AgentNetworkCatalogProvider.extra_headers) declares which keys are accepted; values not declared by the catalog are ignored at synth time. Empty / missing values mean no header stamped. + ExtraValues *map[string]string `json:"extra_values,omitempty"` + + // Id Provider ID + Id string `json:"id"` + + // IdentityHeaderGroups Wire header name the proxy stamps with the caller's NetBird groups as a comma-separated list (sorted) when the catalog entry's HeaderPair is `customizable`. Empty disables stamping for this dimension. Same per-catalog semantics as `identity_header_user_id`. + IdentityHeaderGroups *string `json:"identity_header_groups,omitempty"` + + // IdentityHeaderUserId Wire header name the proxy stamps with the caller's display identity (user email or peer name) when the catalog entry's HeaderPair is `customizable`. Empty disables stamping for this dimension. Ignored when the catalog entry has a fixed HeaderPair (e.g. LiteLLM, Portkey). Used today by Bifrost: typical values are `x-bf-lh-netbird_user_id` (always-on log metadata) or `x-bf-dim-netbird_user_id` (Prometheus / OTEL — requires the label to be pre-declared in the gateway's `client.prometheus_labels` config). + IdentityHeaderUserId *string `json:"identity_header_user_id,omitempty"` + + // Models Models exposed through this endpoint, with the operator's per-1k input/output prices. Empty means all catalog models are allowed at catalog prices. + Models []AgentNetworkProviderModel `json:"models"` + + // Name Display name shown in the dashboard. + Name string `json:"name"` + + // ProviderId Catalog identifier for the upstream AI provider (e.g. openai_api, anthropic_api, azure_openai_api, bedrock_api, vertex_ai_api, mistral_api, custom). + ProviderId string `json:"provider_id"` + + // UpdatedAt Timestamp when the provider was last updated. + UpdatedAt *time.Time `json:"updated_at,omitempty"` + + // UpstreamUrl Full upstream URL (with scheme) that NetBird forwards traffic to. + UpstreamUrl string `json:"upstream_url"` +} + +// AgentNetworkProviderModel A model exposed by the provider, with the operator's per-1k input/output prices in USD. +type AgentNetworkProviderModel struct { + // Id Model identifier (e.g. "gpt-4o-mini"). + Id string `json:"id"` + + // InputPer1k Cost per 1k input tokens, in USD. + InputPer1k float64 `json:"input_per_1k"` + + // OutputPer1k Cost per 1k output tokens, in USD. + OutputPer1k float64 `json:"output_per_1k"` +} + +// AgentNetworkProviderRequest defines model for AgentNetworkProviderRequest. +type AgentNetworkProviderRequest struct { + // ApiKey Upstream provider API key. Sealed at rest on the management server and never returned in responses. Required on create; optional on update (omit to keep the existing key). + ApiKey *string `json:"api_key,omitempty"` + + // BootstrapCluster Proxy cluster used to bootstrap the per-account agent-network endpoint when the first provider is created. Ignored on subsequent creates and on updates because the cluster is pinned on the account-level Settings row. + BootstrapCluster *string `json:"bootstrap_cluster,omitempty"` + + // Enabled Whether the provider is enabled. Defaults to true on create. + Enabled *bool `json:"enabled,omitempty"` + + // ExtraValues Operator-typed values for catalog-declared extra headers (see AgentNetworkProvider.extra_values). When present on a request, the whole map replaces the stored values. Empty strings drop the corresponding key. + ExtraValues *map[string]string `json:"extra_values,omitempty"` + + // IdentityHeaderGroups Wire header name for the caller's groups CSV. See AgentNetworkProvider.identity_header_groups. Same omit / empty semantics as `identity_header_user_id`. + IdentityHeaderGroups *string `json:"identity_header_groups,omitempty"` + + // IdentityHeaderUserId Wire header name for the caller's display identity. See AgentNetworkProvider.identity_header_user_id. When omitted on a request, the stored value is left unchanged; pass an empty string explicitly to clear it (which disables stamping for this dimension). + IdentityHeaderUserId *string `json:"identity_header_user_id,omitempty"` + + // Models Models exposed through this endpoint, with the operator's per-1k input/output prices. Empty means all catalog models are allowed at catalog prices. + Models *[]AgentNetworkProviderModel `json:"models,omitempty"` + + // Name Display name for the provider. + Name string `json:"name"` + + // ProviderId Catalog identifier for the upstream AI provider (e.g. openai_api, anthropic_api, azure_openai_api, bedrock_api, vertex_ai_api, mistral_api, custom). + ProviderId string `json:"provider_id"` + + // UpstreamUrl Full upstream URL (with scheme) that NetBird forwards traffic to. + UpstreamUrl string `json:"upstream_url"` +} + +// AgentNetworkSettings Per-account Agent Network gateway settings. One row per account; cluster and subdomain are auto-assigned on first provider create and immutable thereafter. +type AgentNetworkSettings struct { + // AccessLogRetentionDays Days to retain full access-log rows; older rows are swept. 0 or less means keep indefinitely. Usage records are retained independently. + AccessLogRetentionDays *int `json:"access_log_retention_days,omitempty"` + + // Cluster Address of the NetBird proxy cluster fronting this account's agent-network endpoint. + Cluster string `json:"cluster"` + + // CreatedAt Timestamp when the settings row was created. + CreatedAt *time.Time `json:"created_at,omitempty"` + + // EnableLogCollection Whether per-request access-log entries are collected for this account's agent-network traffic. + EnableLogCollection bool `json:"enable_log_collection"` + + // EnablePromptCollection Master switch for request/response prompt capture. Capture runs only when this is on AND a policy guardrail also enables it. + EnablePromptCollection bool `json:"enable_prompt_collection"` + + // Endpoint Bare hostname agents call for this account, computed as `.`. + Endpoint string `json:"endpoint"` + + // RedactPii Whether captured prompts have PII redacted. Effective redaction is the OR of this and any policy guardrail's redact setting. + RedactPii bool `json:"redact_pii"` + + // Subdomain Auto-generated DNS-safe label that prefixes the cluster to form the agent-network endpoint. + Subdomain string `json:"subdomain"` + + // UpdatedAt Timestamp when the settings row was last updated. + UpdatedAt *time.Time `json:"updated_at,omitempty"` +} + +// AgentNetworkSettingsRequest Mutable account-level Agent Network settings. Cluster and subdomain are immutable and not accepted here. +type AgentNetworkSettingsRequest struct { + // AccessLogRetentionDays Days to retain full access-log rows; older rows are swept. 0 or less means keep indefinitely. + AccessLogRetentionDays *int `json:"access_log_retention_days,omitempty"` + + // EnableLogCollection Whether per-request access-log entries are collected for this account's agent-network traffic. + EnableLogCollection bool `json:"enable_log_collection"` + + // EnablePromptCollection Master switch for request/response prompt capture. + EnablePromptCollection bool `json:"enable_prompt_collection"` + + // RedactPii Whether captured prompts have PII redacted. + RedactPii bool `json:"redact_pii"` +} + +// AgentNetworkUsageBucket One aggregated agent-network usage time bucket (UTC). The bucket width is set by the request's granularity. +type AgentNetworkUsageBucket struct { + // CostUsd Total estimated USD spend in the bucket. + CostUsd float64 `json:"cost_usd"` + + // InputTokens Total input (prompt) tokens in the bucket. + InputTokens int64 `json:"input_tokens"` + + // OutputTokens Total output (completion) tokens in the bucket. + OutputTokens int64 `json:"output_tokens"` + + // PeriodStart Start of the bucket in YYYY-MM-DD (UTC) — the day, the week start (Monday), or the month start, depending on granularity. + PeriodStart string `json:"period_start"` + + // TotalTokens Total tokens in the bucket. + TotalTokens int64 `json:"total_tokens"` +} + // AvailablePorts defines model for AvailablePorts. type AvailablePorts struct { // Tcp Number of available TCP ports left on the ingress peer @@ -4892,6 +5687,138 @@ type bearerAuthContextKey string // tokenAuthContextKey is the context key for TokenAuth security scheme type tokenAuthContextKey string +// GetApiAgentNetworkAccessLogSessionsParams defines parameters for GetApiAgentNetworkAccessLogSessions. +type GetApiAgentNetworkAccessLogSessionsParams struct { + // Page Page number for pagination (1-indexed). + Page *int `form:"page,omitempty" json:"page,omitempty"` + + // PageSize Number of sessions per page (max 100). + PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"` + + // SortBy Session-level field to sort by. "timestamp" is the session's last activity, "started_at" its first. + SortBy *GetApiAgentNetworkAccessLogSessionsParamsSortBy `form:"sort_by,omitempty" json:"sort_by,omitempty"` + + // SortOrder Sort order (ascending or descending). + SortOrder *GetApiAgentNetworkAccessLogSessionsParamsSortOrder `form:"sort_order,omitempty" json:"sort_order,omitempty"` + + // Search General search across log ID, host, path, model, and user email/name. + Search *string `form:"search,omitempty" json:"search,omitempty"` + + // UserId Filter by authenticated user ID. + UserId *string `form:"user_id,omitempty" json:"user_id,omitempty"` + + // SessionId Filter to a single conversation / coding session id. + SessionId *string `form:"session_id,omitempty" json:"session_id,omitempty"` + + // GroupId Filter by authorising group id. Repeat for multiple (matches any). + GroupId *[]string `form:"group_id,omitempty" json:"group_id,omitempty"` + + // ProviderId Filter by resolved provider id. Repeat for multiple (matches any). + ProviderId *[]string `form:"provider_id,omitempty" json:"provider_id,omitempty"` + + // Model Filter by model. Repeat for multiple (matches any). + Model *[]string `form:"model,omitempty" json:"model,omitempty"` + + // Decision Filter by policy decision (e.g. allow, deny). + Decision *string `form:"decision,omitempty" json:"decision,omitempty"` + + // Path Filter by request path prefix (matches entries whose path starts with this value). + Path *string `form:"path,omitempty" json:"path,omitempty"` + + // StartDate Filter by timestamp >= start_date (RFC3339 format). + StartDate *time.Time `form:"start_date,omitempty" json:"start_date,omitempty"` + + // EndDate Filter by timestamp <= end_date (RFC3339 format). + EndDate *time.Time `form:"end_date,omitempty" json:"end_date,omitempty"` +} + +// GetApiAgentNetworkAccessLogSessionsParamsSortBy defines parameters for GetApiAgentNetworkAccessLogSessions. +type GetApiAgentNetworkAccessLogSessionsParamsSortBy string + +// GetApiAgentNetworkAccessLogSessionsParamsSortOrder defines parameters for GetApiAgentNetworkAccessLogSessions. +type GetApiAgentNetworkAccessLogSessionsParamsSortOrder string + +// GetApiAgentNetworkAccessLogsParams defines parameters for GetApiAgentNetworkAccessLogs. +type GetApiAgentNetworkAccessLogsParams struct { + // Page Page number for pagination (1-indexed). + Page *int `form:"page,omitempty" json:"page,omitempty"` + + // PageSize Number of items per page (max 100). + PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"` + + // SortBy Field to sort by. + SortBy *GetApiAgentNetworkAccessLogsParamsSortBy `form:"sort_by,omitempty" json:"sort_by,omitempty"` + + // SortOrder Sort order (ascending or descending). + SortOrder *GetApiAgentNetworkAccessLogsParamsSortOrder `form:"sort_order,omitempty" json:"sort_order,omitempty"` + + // Search General search across log ID, host, path, model, and user email/name. + Search *string `form:"search,omitempty" json:"search,omitempty"` + + // UserId Filter by authenticated user ID. + UserId *string `form:"user_id,omitempty" json:"user_id,omitempty"` + + // SessionId Filter to a single conversation / coding session id (groups all requests of one session). + SessionId *string `form:"session_id,omitempty" json:"session_id,omitempty"` + + // GroupId Filter by authorising group id. Repeat for multiple (matches any). + GroupId *[]string `form:"group_id,omitempty" json:"group_id,omitempty"` + + // ProviderId Filter by resolved provider id. Repeat for multiple (matches any). + ProviderId *[]string `form:"provider_id,omitempty" json:"provider_id,omitempty"` + + // Model Filter by model. Repeat for multiple (matches any). + Model *[]string `form:"model,omitempty" json:"model,omitempty"` + + // Decision Filter by policy decision (e.g. allow, deny). + Decision *string `form:"decision,omitempty" json:"decision,omitempty"` + + // Path Filter by request path prefix (matches entries whose path starts with this value). + Path *string `form:"path,omitempty" json:"path,omitempty"` + + // StartDate Filter by timestamp >= start_date (RFC3339 format). + StartDate *time.Time `form:"start_date,omitempty" json:"start_date,omitempty"` + + // EndDate Filter by timestamp <= end_date (RFC3339 format). + EndDate *time.Time `form:"end_date,omitempty" json:"end_date,omitempty"` +} + +// GetApiAgentNetworkAccessLogsParamsSortBy defines parameters for GetApiAgentNetworkAccessLogs. +type GetApiAgentNetworkAccessLogsParamsSortBy string + +// GetApiAgentNetworkAccessLogsParamsSortOrder defines parameters for GetApiAgentNetworkAccessLogs. +type GetApiAgentNetworkAccessLogsParamsSortOrder string + +// GetApiAgentNetworkUsageOverviewParams defines parameters for GetApiAgentNetworkUsageOverview. +type GetApiAgentNetworkUsageOverviewParams struct { + // Granularity Time bucket width. Defaults to day. + Granularity *GetApiAgentNetworkUsageOverviewParamsGranularity `form:"granularity,omitempty" json:"granularity,omitempty"` + + // StartDate Filter by timestamp >= start_date (RFC3339 format). + StartDate *time.Time `form:"start_date,omitempty" json:"start_date,omitempty"` + + // EndDate Filter by timestamp <= end_date (RFC3339 format). + EndDate *time.Time `form:"end_date,omitempty" json:"end_date,omitempty"` + + // UserId Filter by user ID. + UserId *string `form:"user_id,omitempty" json:"user_id,omitempty"` + + // SessionId Filter to a single conversation / coding session id. + SessionId *string `form:"session_id,omitempty" json:"session_id,omitempty"` + + // GroupId Filter by authorising group id. Repeat for multiple (matches any). + GroupId *[]string `form:"group_id,omitempty" json:"group_id,omitempty"` + + // ProviderId Filter by resolved provider id. Repeat for multiple (matches any). + ProviderId *[]string `form:"provider_id,omitempty" json:"provider_id,omitempty"` + + // Model Filter by model. Repeat for multiple (matches any). + Model *[]string `form:"model,omitempty" json:"model,omitempty"` +} + +// GetApiAgentNetworkUsageOverviewParamsGranularity defines parameters for GetApiAgentNetworkUsageOverview. +type GetApiAgentNetworkUsageOverviewParamsGranularity string + // GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic. type GetApiEventsNetworkTrafficParams struct { // Page Page number @@ -5090,6 +6017,33 @@ type GetApiUsersParams struct { // PutApiAccountsAccountIdJSONRequestBody defines body for PutApiAccountsAccountId for application/json ContentType. type PutApiAccountsAccountIdJSONRequestBody = AccountRequest +// PostApiAgentNetworkBudgetRulesJSONRequestBody defines body for PostApiAgentNetworkBudgetRules for application/json ContentType. +type PostApiAgentNetworkBudgetRulesJSONRequestBody = AgentNetworkBudgetRuleRequest + +// PutApiAgentNetworkBudgetRulesRuleIdJSONRequestBody defines body for PutApiAgentNetworkBudgetRulesRuleId for application/json ContentType. +type PutApiAgentNetworkBudgetRulesRuleIdJSONRequestBody = AgentNetworkBudgetRuleRequest + +// PostApiAgentNetworkGuardrailsJSONRequestBody defines body for PostApiAgentNetworkGuardrails for application/json ContentType. +type PostApiAgentNetworkGuardrailsJSONRequestBody = AgentNetworkGuardrailRequest + +// PutApiAgentNetworkGuardrailsGuardrailIdJSONRequestBody defines body for PutApiAgentNetworkGuardrailsGuardrailId for application/json ContentType. +type PutApiAgentNetworkGuardrailsGuardrailIdJSONRequestBody = AgentNetworkGuardrailRequest + +// PostApiAgentNetworkPoliciesJSONRequestBody defines body for PostApiAgentNetworkPolicies for application/json ContentType. +type PostApiAgentNetworkPoliciesJSONRequestBody = AgentNetworkPolicyRequest + +// PutApiAgentNetworkPoliciesPolicyIdJSONRequestBody defines body for PutApiAgentNetworkPoliciesPolicyId for application/json ContentType. +type PutApiAgentNetworkPoliciesPolicyIdJSONRequestBody = AgentNetworkPolicyRequest + +// PostApiAgentNetworkProvidersJSONRequestBody defines body for PostApiAgentNetworkProviders for application/json ContentType. +type PostApiAgentNetworkProvidersJSONRequestBody = AgentNetworkProviderRequest + +// PutApiAgentNetworkProvidersProviderIdJSONRequestBody defines body for PutApiAgentNetworkProvidersProviderId for application/json ContentType. +type PutApiAgentNetworkProvidersProviderIdJSONRequestBody = AgentNetworkProviderRequest + +// PutApiAgentNetworkSettingsJSONRequestBody defines body for PutApiAgentNetworkSettings for application/json ContentType. +type PutApiAgentNetworkSettingsJSONRequestBody = AgentNetworkSettingsRequest + // PostApiDnsNameserversJSONRequestBody defines body for PostApiDnsNameservers for application/json ContentType. type PostApiDnsNameserversJSONRequestBody = NameserverGroupRequest diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 5dd529407..8027c0db6 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -843,10 +843,14 @@ type SyncResponse struct { NetworkMap *NetworkMap `protobuf:"bytes,5,opt,name=NetworkMap,proto3" json:"NetworkMap,omitempty"` // Posture checks to be evaluated by client Checks []*Checks `protobuf:"bytes,6,rep,name=Checks,proto3" json:"Checks,omitempty"` - // Absolute UTC instant at which the peer's SSO session expires. - // Unset when the peer is not SSO-registered or login expiration is disabled. - // Carried on every Sync snapshot so admin-side changes propagate live without - // a client reconnect. + // 3-state session deadline. Carried on every Sync snapshot so admin-side + // changes propagate live without a client reconnect. + // + // field unset (nil) → snapshot carries no info; client keeps the + // deadline it already had + // set, seconds=0 nanos=0 → explicit "expiry disabled" or peer is not + // SSO-registered; client clears its anchor + // set, valid timestamp → new absolute UTC deadline SessionExpiresAt *timestamppb.Timestamp `protobuf:"bytes,7,opt,name=sessionExpiresAt,proto3" json:"sessionExpiresAt,omitempty"` } @@ -1608,8 +1612,11 @@ type LoginResponse struct { PeerConfig *PeerConfig `protobuf:"bytes,2,opt,name=peerConfig,proto3" json:"peerConfig,omitempty"` // Posture checks to be evaluated by client Checks []*Checks `protobuf:"bytes,3,rep,name=Checks,proto3" json:"Checks,omitempty"` - // Absolute UTC instant at which the peer's SSO session expires. - // Unset when the peer is not SSO-registered or login expiration is disabled. + // 3-state session deadline; same encoding as SyncResponse.sessionExpiresAt. + // + // field unset (nil) → no info; client keeps any deadline it had + // set, seconds=0 nanos=0 → explicit "expiry disabled" / non-SSO peer + // set, valid timestamp → new absolute UTC deadline SessionExpiresAt *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=sessionExpiresAt,proto3" json:"sessionExpiresAt,omitempty"` } @@ -1739,7 +1746,10 @@ type ExtendAuthSessionResponse struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // Absolute UTC instant at which the peer's SSO session now expires. + // 3-state session deadline; same encoding as SyncResponse.sessionExpiresAt. + // In practice ExtendAuthSession only succeeds for SSO peers with expiry + // enabled, so this carries a valid timestamp on the success path. The + // 3-state encoding is documented here for symmetry with Login/Sync. SessionExpiresAt *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=sessionExpiresAt,proto3" json:"sessionExpiresAt,omitempty"` } diff --git a/shared/management/proto/proxy_service.pb.go b/shared/management/proto/proxy_service.pb.go index 22c215074..df42d78ff 100644 --- a/shared/management/proto/proxy_service.pb.go +++ b/shared/management/proto/proxy_service.pb.go @@ -117,6 +117,60 @@ func (PathRewriteMode) EnumDescriptor() ([]byte, []int) { return file_proxy_service_proto_rawDescGZIP(), []int{1} } +// MiddlewareSlot identifies where in the request lifecycle a middleware +// runs. Mirrors proxy/internal/middleware.Slot. +type MiddlewareSlot int32 + +const ( + MiddlewareSlot_MIDDLEWARE_SLOT_UNSPECIFIED MiddlewareSlot = 0 + MiddlewareSlot_MIDDLEWARE_SLOT_ON_REQUEST MiddlewareSlot = 1 + MiddlewareSlot_MIDDLEWARE_SLOT_ON_RESPONSE MiddlewareSlot = 2 + MiddlewareSlot_MIDDLEWARE_SLOT_TERMINAL MiddlewareSlot = 3 +) + +// Enum value maps for MiddlewareSlot. +var ( + MiddlewareSlot_name = map[int32]string{ + 0: "MIDDLEWARE_SLOT_UNSPECIFIED", + 1: "MIDDLEWARE_SLOT_ON_REQUEST", + 2: "MIDDLEWARE_SLOT_ON_RESPONSE", + 3: "MIDDLEWARE_SLOT_TERMINAL", + } + MiddlewareSlot_value = map[string]int32{ + "MIDDLEWARE_SLOT_UNSPECIFIED": 0, + "MIDDLEWARE_SLOT_ON_REQUEST": 1, + "MIDDLEWARE_SLOT_ON_RESPONSE": 2, + "MIDDLEWARE_SLOT_TERMINAL": 3, + } +) + +func (x MiddlewareSlot) Enum() *MiddlewareSlot { + p := new(MiddlewareSlot) + *p = x + return p +} + +func (x MiddlewareSlot) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (MiddlewareSlot) Descriptor() protoreflect.EnumDescriptor { + return file_proxy_service_proto_enumTypes[2].Descriptor() +} + +func (MiddlewareSlot) Type() protoreflect.EnumType { + return &file_proxy_service_proto_enumTypes[2] +} + +func (x MiddlewareSlot) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use MiddlewareSlot.Descriptor instead. +func (MiddlewareSlot) EnumDescriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{2} +} + type ProxyStatus int32 const ( @@ -159,11 +213,11 @@ func (x ProxyStatus) String() string { } func (ProxyStatus) Descriptor() protoreflect.EnumDescriptor { - return file_proxy_service_proto_enumTypes[2].Descriptor() + return file_proxy_service_proto_enumTypes[3].Descriptor() } func (ProxyStatus) Type() protoreflect.EnumType { - return &file_proxy_service_proto_enumTypes[2] + return &file_proxy_service_proto_enumTypes[3] } func (x ProxyStatus) Number() protoreflect.EnumNumber { @@ -172,7 +226,53 @@ func (x ProxyStatus) Number() protoreflect.EnumNumber { // Deprecated: Use ProxyStatus.Descriptor instead. func (ProxyStatus) EnumDescriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{2} + return file_proxy_service_proto_rawDescGZIP(), []int{3} +} + +type MiddlewareConfig_FailMode int32 + +const ( + MiddlewareConfig_FAIL_OPEN MiddlewareConfig_FailMode = 0 + MiddlewareConfig_FAIL_CLOSED MiddlewareConfig_FailMode = 1 +) + +// Enum value maps for MiddlewareConfig_FailMode. +var ( + MiddlewareConfig_FailMode_name = map[int32]string{ + 0: "FAIL_OPEN", + 1: "FAIL_CLOSED", + } + MiddlewareConfig_FailMode_value = map[string]int32{ + "FAIL_OPEN": 0, + "FAIL_CLOSED": 1, + } +) + +func (x MiddlewareConfig_FailMode) Enum() *MiddlewareConfig_FailMode { + p := new(MiddlewareConfig_FailMode) + *p = x + return p +} + +func (x MiddlewareConfig_FailMode) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (MiddlewareConfig_FailMode) Descriptor() protoreflect.EnumDescriptor { + return file_proxy_service_proto_enumTypes[4].Descriptor() +} + +func (MiddlewareConfig_FailMode) Type() protoreflect.EnumType { + return &file_proxy_service_proto_enumTypes[4] +} + +func (x MiddlewareConfig_FailMode) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use MiddlewareConfig_FailMode.Descriptor instead. +func (MiddlewareConfig_FailMode) EnumDescriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{4, 0} } // ProxyCapabilities describes what a proxy can handle. @@ -422,6 +522,25 @@ type PathTargetOptions struct { // reachable without WireGuard (public APIs, LAN services, localhost // sidecars). Defaults to false — embedded client is the standard path. DirectUpstream bool `protobuf:"varint,7,opt,name=direct_upstream,json=directUpstream,proto3" json:"direct_upstream,omitempty"` + // Proxy clamps to [0, proxy-wide max (1 MiB)] at apply time. Agent-network + // synthesized targets only; private services leave these zero. + CaptureMaxRequestBytes int64 `protobuf:"varint,8,opt,name=capture_max_request_bytes,json=captureMaxRequestBytes,proto3" json:"capture_max_request_bytes,omitempty"` + // Proxy clamps to [0, proxy-wide max (1 MiB)] at apply time. + CaptureMaxResponseBytes int64 `protobuf:"varint,9,opt,name=capture_max_response_bytes,json=captureMaxResponseBytes,proto3" json:"capture_max_response_bytes,omitempty"` + // Content types eligible for body capture (e.g. "application/json"). + CaptureContentTypes []string `protobuf:"bytes,10,rep,name=capture_content_types,json=captureContentTypes,proto3" json:"capture_content_types,omitempty"` + // Per-target middleware configurations populated by the agent-network + // synthesizer. Validated and clamped by the proxy at apply time. + Middlewares []*MiddlewareConfig `protobuf:"bytes,11,rep,name=middlewares,proto3" json:"middlewares,omitempty"` + // When true, the proxy stamps agent_network=true on access-log entries + // for this target so management routes them to the agent-network log + // surface. + AgentNetwork bool `protobuf:"varint,12,opt,name=agent_network,json=agentNetwork,proto3" json:"agent_network,omitempty"` + // When true, the proxy suppresses the per-request access-log emission for + // this target. Defaults false to preserve existing access-log behavior for + // every non-agent-network target. The agent-network synth target sets this + // true only when the account's EnableLogCollection toggle is off. + DisableAccessLog bool `protobuf:"varint,13,opt,name=disable_access_log,json=disableAccessLog,proto3" json:"disable_access_log,omitempty"` } func (x *PathTargetOptions) Reset() { @@ -505,6 +624,154 @@ func (x *PathTargetOptions) GetDirectUpstream() bool { return false } +func (x *PathTargetOptions) GetCaptureMaxRequestBytes() int64 { + if x != nil { + return x.CaptureMaxRequestBytes + } + return 0 +} + +func (x *PathTargetOptions) GetCaptureMaxResponseBytes() int64 { + if x != nil { + return x.CaptureMaxResponseBytes + } + return 0 +} + +func (x *PathTargetOptions) GetCaptureContentTypes() []string { + if x != nil { + return x.CaptureContentTypes + } + return nil +} + +func (x *PathTargetOptions) GetMiddlewares() []*MiddlewareConfig { + if x != nil { + return x.Middlewares + } + return nil +} + +func (x *PathTargetOptions) GetAgentNetwork() bool { + if x != nil { + return x.AgentNetwork + } + return false +} + +func (x *PathTargetOptions) GetDisableAccessLog() bool { + if x != nil { + return x.DisableAccessLog + } + return false +} + +// MiddlewareConfig is the per-target configuration for a single middleware. +// The proxy validates every incoming MiddlewareConfig at apply time: +// unknown ids are rejected, timeout is clamped to [10ms, 5s], and the +// declared slot must match the registered middleware's slot. +type MiddlewareConfig struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Middleware id; must match the proxy-local compiled-in registry. + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + Enabled bool `protobuf:"varint,2,opt,name=enabled,proto3" json:"enabled,omitempty"` + Slot MiddlewareSlot `protobuf:"varint,3,opt,name=slot,proto3,enum=management.MiddlewareSlot" json:"slot,omitempty"` + // Free-form JSON unmarshalled by the middleware factory into its own typed + // config struct. Empty / null / {} are valid (zero-value config). + ConfigJson []byte `protobuf:"bytes,4,opt,name=config_json,json=configJson,proto3" json:"config_json,omitempty"` + FailMode MiddlewareConfig_FailMode `protobuf:"varint,5,opt,name=fail_mode,json=failMode,proto3,enum=management.MiddlewareConfig_FailMode" json:"fail_mode,omitempty"` + // Clamped to [10ms, 5s] at apply time; zero → 500ms default. + Timeout *durationpb.Duration `protobuf:"bytes,6,opt,name=timeout,proto3" json:"timeout,omitempty"` + // When true, the middleware may mutate request headers or body (subject to + // policy). Honoured only when the implementation also declares + // MutationsSupported. + CanMutate bool `protobuf:"varint,7,opt,name=can_mutate,json=canMutate,proto3" json:"can_mutate,omitempty"` +} + +func (x *MiddlewareConfig) Reset() { + *x = MiddlewareConfig{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MiddlewareConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MiddlewareConfig) ProtoMessage() {} + +func (x *MiddlewareConfig) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MiddlewareConfig.ProtoReflect.Descriptor instead. +func (*MiddlewareConfig) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{4} +} + +func (x *MiddlewareConfig) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *MiddlewareConfig) GetEnabled() bool { + if x != nil { + return x.Enabled + } + return false +} + +func (x *MiddlewareConfig) GetSlot() MiddlewareSlot { + if x != nil { + return x.Slot + } + return MiddlewareSlot_MIDDLEWARE_SLOT_UNSPECIFIED +} + +func (x *MiddlewareConfig) GetConfigJson() []byte { + if x != nil { + return x.ConfigJson + } + return nil +} + +func (x *MiddlewareConfig) GetFailMode() MiddlewareConfig_FailMode { + if x != nil { + return x.FailMode + } + return MiddlewareConfig_FAIL_OPEN +} + +func (x *MiddlewareConfig) GetTimeout() *durationpb.Duration { + if x != nil { + return x.Timeout + } + return nil +} + +func (x *MiddlewareConfig) GetCanMutate() bool { + if x != nil { + return x.CanMutate + } + return false +} + type PathMapping struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -518,7 +785,7 @@ type PathMapping struct { func (x *PathMapping) Reset() { *x = PathMapping{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[4] + mi := &file_proxy_service_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -531,7 +798,7 @@ func (x *PathMapping) String() string { func (*PathMapping) ProtoMessage() {} func (x *PathMapping) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[4] + mi := &file_proxy_service_proto_msgTypes[5] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -544,7 +811,7 @@ func (x *PathMapping) ProtoReflect() protoreflect.Message { // Deprecated: Use PathMapping.ProtoReflect.Descriptor instead. func (*PathMapping) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{4} + return file_proxy_service_proto_rawDescGZIP(), []int{5} } func (x *PathMapping) GetPath() string { @@ -582,7 +849,7 @@ type HeaderAuth struct { func (x *HeaderAuth) Reset() { *x = HeaderAuth{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[5] + mi := &file_proxy_service_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -595,7 +862,7 @@ func (x *HeaderAuth) String() string { func (*HeaderAuth) ProtoMessage() {} func (x *HeaderAuth) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[5] + mi := &file_proxy_service_proto_msgTypes[6] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -608,7 +875,7 @@ func (x *HeaderAuth) ProtoReflect() protoreflect.Message { // Deprecated: Use HeaderAuth.ProtoReflect.Descriptor instead. func (*HeaderAuth) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{5} + return file_proxy_service_proto_rawDescGZIP(), []int{6} } func (x *HeaderAuth) GetHeader() string { @@ -641,7 +908,7 @@ type Authentication struct { func (x *Authentication) Reset() { *x = Authentication{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[6] + mi := &file_proxy_service_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -654,7 +921,7 @@ func (x *Authentication) String() string { func (*Authentication) ProtoMessage() {} func (x *Authentication) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[6] + mi := &file_proxy_service_proto_msgTypes[7] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -667,7 +934,7 @@ func (x *Authentication) ProtoReflect() protoreflect.Message { // Deprecated: Use Authentication.ProtoReflect.Descriptor instead. func (*Authentication) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{6} + return file_proxy_service_proto_rawDescGZIP(), []int{7} } func (x *Authentication) GetSessionKey() string { @@ -728,7 +995,7 @@ type AccessRestrictions struct { func (x *AccessRestrictions) Reset() { *x = AccessRestrictions{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[7] + mi := &file_proxy_service_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -741,7 +1008,7 @@ func (x *AccessRestrictions) String() string { func (*AccessRestrictions) ProtoMessage() {} func (x *AccessRestrictions) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[7] + mi := &file_proxy_service_proto_msgTypes[8] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -754,7 +1021,7 @@ func (x *AccessRestrictions) ProtoReflect() protoreflect.Message { // Deprecated: Use AccessRestrictions.ProtoReflect.Descriptor instead. func (*AccessRestrictions) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{7} + return file_proxy_service_proto_rawDescGZIP(), []int{8} } func (x *AccessRestrictions) GetAllowedCidrs() []string { @@ -822,7 +1089,7 @@ type ProxyMapping struct { func (x *ProxyMapping) Reset() { *x = ProxyMapping{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[8] + mi := &file_proxy_service_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -835,7 +1102,7 @@ func (x *ProxyMapping) String() string { func (*ProxyMapping) ProtoMessage() {} func (x *ProxyMapping) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[8] + mi := &file_proxy_service_proto_msgTypes[9] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -848,7 +1115,7 @@ func (x *ProxyMapping) ProtoReflect() protoreflect.Message { // Deprecated: Use ProxyMapping.ProtoReflect.Descriptor instead. func (*ProxyMapping) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{8} + return file_proxy_service_proto_rawDescGZIP(), []int{9} } func (x *ProxyMapping) GetType() ProxyMappingUpdateType { @@ -954,7 +1221,7 @@ type SendAccessLogRequest struct { func (x *SendAccessLogRequest) Reset() { *x = SendAccessLogRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[9] + mi := &file_proxy_service_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -967,7 +1234,7 @@ func (x *SendAccessLogRequest) String() string { func (*SendAccessLogRequest) ProtoMessage() {} func (x *SendAccessLogRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[9] + mi := &file_proxy_service_proto_msgTypes[10] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -980,7 +1247,7 @@ func (x *SendAccessLogRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SendAccessLogRequest.ProtoReflect.Descriptor instead. func (*SendAccessLogRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{9} + return file_proxy_service_proto_rawDescGZIP(), []int{10} } func (x *SendAccessLogRequest) GetLog() *AccessLog { @@ -1000,7 +1267,7 @@ type SendAccessLogResponse struct { func (x *SendAccessLogResponse) Reset() { *x = SendAccessLogResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[10] + mi := &file_proxy_service_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1013,7 +1280,7 @@ func (x *SendAccessLogResponse) String() string { func (*SendAccessLogResponse) ProtoMessage() {} func (x *SendAccessLogResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[10] + mi := &file_proxy_service_proto_msgTypes[11] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1026,7 +1293,7 @@ func (x *SendAccessLogResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SendAccessLogResponse.ProtoReflect.Descriptor instead. func (*SendAccessLogResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{10} + return file_proxy_service_proto_rawDescGZIP(), []int{11} } type AccessLog struct { @@ -1052,12 +1319,16 @@ type AccessLog struct { Protocol string `protobuf:"bytes,16,opt,name=protocol,proto3" json:"protocol,omitempty"` // Extra key-value metadata for the access log entry (e.g. crowdsec_verdict, scenario). Metadata map[string]string `protobuf:"bytes,17,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + // When true, the entry was emitted by an agent-network synth service. + // Management routes these to the agent-network access-log surface instead + // of the standard service log. + AgentNetwork bool `protobuf:"varint,18,opt,name=agent_network,json=agentNetwork,proto3" json:"agent_network,omitempty"` } func (x *AccessLog) Reset() { *x = AccessLog{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[11] + mi := &file_proxy_service_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1070,7 +1341,7 @@ func (x *AccessLog) String() string { func (*AccessLog) ProtoMessage() {} func (x *AccessLog) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[11] + mi := &file_proxy_service_proto_msgTypes[12] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1083,7 +1354,7 @@ func (x *AccessLog) ProtoReflect() protoreflect.Message { // Deprecated: Use AccessLog.ProtoReflect.Descriptor instead. func (*AccessLog) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{11} + return file_proxy_service_proto_rawDescGZIP(), []int{12} } func (x *AccessLog) GetTimestamp() *timestamppb.Timestamp { @@ -1205,6 +1476,13 @@ func (x *AccessLog) GetMetadata() map[string]string { return nil } +func (x *AccessLog) GetAgentNetwork() bool { + if x != nil { + return x.AgentNetwork + } + return false +} + type AuthenticateRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1223,7 +1501,7 @@ type AuthenticateRequest struct { func (x *AuthenticateRequest) Reset() { *x = AuthenticateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[12] + mi := &file_proxy_service_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1236,7 +1514,7 @@ func (x *AuthenticateRequest) String() string { func (*AuthenticateRequest) ProtoMessage() {} func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[12] + mi := &file_proxy_service_proto_msgTypes[13] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1249,7 +1527,7 @@ func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use AuthenticateRequest.ProtoReflect.Descriptor instead. func (*AuthenticateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{12} + return file_proxy_service_proto_rawDescGZIP(), []int{13} } func (x *AuthenticateRequest) GetId() string { @@ -1328,7 +1606,7 @@ type HeaderAuthRequest struct { func (x *HeaderAuthRequest) Reset() { *x = HeaderAuthRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[13] + mi := &file_proxy_service_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1341,7 +1619,7 @@ func (x *HeaderAuthRequest) String() string { func (*HeaderAuthRequest) ProtoMessage() {} func (x *HeaderAuthRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[13] + mi := &file_proxy_service_proto_msgTypes[14] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1354,7 +1632,7 @@ func (x *HeaderAuthRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use HeaderAuthRequest.ProtoReflect.Descriptor instead. func (*HeaderAuthRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{13} + return file_proxy_service_proto_rawDescGZIP(), []int{14} } func (x *HeaderAuthRequest) GetHeaderValue() string { @@ -1382,7 +1660,7 @@ type PasswordRequest struct { func (x *PasswordRequest) Reset() { *x = PasswordRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[14] + mi := &file_proxy_service_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1395,7 +1673,7 @@ func (x *PasswordRequest) String() string { func (*PasswordRequest) ProtoMessage() {} func (x *PasswordRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[14] + mi := &file_proxy_service_proto_msgTypes[15] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1408,7 +1686,7 @@ func (x *PasswordRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PasswordRequest.ProtoReflect.Descriptor instead. func (*PasswordRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{14} + return file_proxy_service_proto_rawDescGZIP(), []int{15} } func (x *PasswordRequest) GetPassword() string { @@ -1429,7 +1707,7 @@ type PinRequest struct { func (x *PinRequest) Reset() { *x = PinRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[15] + mi := &file_proxy_service_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1442,7 +1720,7 @@ func (x *PinRequest) String() string { func (*PinRequest) ProtoMessage() {} func (x *PinRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[15] + mi := &file_proxy_service_proto_msgTypes[16] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1455,7 +1733,7 @@ func (x *PinRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PinRequest.ProtoReflect.Descriptor instead. func (*PinRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{15} + return file_proxy_service_proto_rawDescGZIP(), []int{16} } func (x *PinRequest) GetPin() string { @@ -1477,7 +1755,7 @@ type AuthenticateResponse struct { func (x *AuthenticateResponse) Reset() { *x = AuthenticateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[16] + mi := &file_proxy_service_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1490,7 +1768,7 @@ func (x *AuthenticateResponse) String() string { func (*AuthenticateResponse) ProtoMessage() {} func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[16] + mi := &file_proxy_service_proto_msgTypes[17] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1503,7 +1781,7 @@ func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use AuthenticateResponse.ProtoReflect.Descriptor instead. func (*AuthenticateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{16} + return file_proxy_service_proto_rawDescGZIP(), []int{17} } func (x *AuthenticateResponse) GetSuccess() bool { @@ -1532,7 +1810,7 @@ type SendStatusUpdateRequest struct { CertificateIssued bool `protobuf:"varint,4,opt,name=certificate_issued,json=certificateIssued,proto3" json:"certificate_issued,omitempty"` ErrorMessage *string `protobuf:"bytes,5,opt,name=error_message,json=errorMessage,proto3,oneof" json:"error_message,omitempty"` // Per-account inbound listener state for the account that owns - // service_id. Populated only when --private-inbound is enabled and the + // service_id. Populated only when --private is enabled and the // embedded client for the account is up. Field numbers >=50 reserved // for observability extensions. InboundListener *ProxyInboundListener `protobuf:"bytes,50,opt,name=inbound_listener,json=inboundListener,proto3,oneof" json:"inbound_listener,omitempty"` @@ -1541,7 +1819,7 @@ type SendStatusUpdateRequest struct { func (x *SendStatusUpdateRequest) Reset() { *x = SendStatusUpdateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[17] + mi := &file_proxy_service_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1554,7 +1832,7 @@ func (x *SendStatusUpdateRequest) String() string { func (*SendStatusUpdateRequest) ProtoMessage() {} func (x *SendStatusUpdateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[17] + mi := &file_proxy_service_proto_msgTypes[18] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1567,7 +1845,7 @@ func (x *SendStatusUpdateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SendStatusUpdateRequest.ProtoReflect.Descriptor instead. func (*SendStatusUpdateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{17} + return file_proxy_service_proto_rawDescGZIP(), []int{18} } func (x *SendStatusUpdateRequest) GetServiceId() string { @@ -1633,7 +1911,7 @@ type ProxyInboundListener struct { func (x *ProxyInboundListener) Reset() { *x = ProxyInboundListener{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[18] + mi := &file_proxy_service_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1646,7 +1924,7 @@ func (x *ProxyInboundListener) String() string { func (*ProxyInboundListener) ProtoMessage() {} func (x *ProxyInboundListener) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[18] + mi := &file_proxy_service_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1659,7 +1937,7 @@ func (x *ProxyInboundListener) ProtoReflect() protoreflect.Message { // Deprecated: Use ProxyInboundListener.ProtoReflect.Descriptor instead. func (*ProxyInboundListener) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{18} + return file_proxy_service_proto_rawDescGZIP(), []int{19} } func (x *ProxyInboundListener) GetTunnelIp() string { @@ -1693,7 +1971,7 @@ type SendStatusUpdateResponse struct { func (x *SendStatusUpdateResponse) Reset() { *x = SendStatusUpdateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[19] + mi := &file_proxy_service_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1706,7 +1984,7 @@ func (x *SendStatusUpdateResponse) String() string { func (*SendStatusUpdateResponse) ProtoMessage() {} func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[19] + mi := &file_proxy_service_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1719,7 +1997,7 @@ func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SendStatusUpdateResponse.ProtoReflect.Descriptor instead. func (*SendStatusUpdateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{19} + return file_proxy_service_proto_rawDescGZIP(), []int{20} } // CreateProxyPeerRequest is sent by the proxy to create a peer connection @@ -1739,7 +2017,7 @@ type CreateProxyPeerRequest struct { func (x *CreateProxyPeerRequest) Reset() { *x = CreateProxyPeerRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[20] + mi := &file_proxy_service_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1752,7 +2030,7 @@ func (x *CreateProxyPeerRequest) String() string { func (*CreateProxyPeerRequest) ProtoMessage() {} func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[20] + mi := &file_proxy_service_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1765,7 +2043,7 @@ func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateProxyPeerRequest.ProtoReflect.Descriptor instead. func (*CreateProxyPeerRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{20} + return file_proxy_service_proto_rawDescGZIP(), []int{21} } func (x *CreateProxyPeerRequest) GetServiceId() string { @@ -1816,7 +2094,7 @@ type CreateProxyPeerResponse struct { func (x *CreateProxyPeerResponse) Reset() { *x = CreateProxyPeerResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[21] + mi := &file_proxy_service_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1829,7 +2107,7 @@ func (x *CreateProxyPeerResponse) String() string { func (*CreateProxyPeerResponse) ProtoMessage() {} func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[21] + mi := &file_proxy_service_proto_msgTypes[22] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1842,7 +2120,7 @@ func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateProxyPeerResponse.ProtoReflect.Descriptor instead. func (*CreateProxyPeerResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{21} + return file_proxy_service_proto_rawDescGZIP(), []int{22} } func (x *CreateProxyPeerResponse) GetSuccess() bool { @@ -1872,7 +2150,7 @@ type GetOIDCURLRequest struct { func (x *GetOIDCURLRequest) Reset() { *x = GetOIDCURLRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[22] + mi := &file_proxy_service_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1885,7 +2163,7 @@ func (x *GetOIDCURLRequest) String() string { func (*GetOIDCURLRequest) ProtoMessage() {} func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[22] + mi := &file_proxy_service_proto_msgTypes[23] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1898,7 +2176,7 @@ func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLRequest.ProtoReflect.Descriptor instead. func (*GetOIDCURLRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{22} + return file_proxy_service_proto_rawDescGZIP(), []int{23} } func (x *GetOIDCURLRequest) GetId() string { @@ -1933,7 +2211,7 @@ type GetOIDCURLResponse struct { func (x *GetOIDCURLResponse) Reset() { *x = GetOIDCURLResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[23] + mi := &file_proxy_service_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1946,7 +2224,7 @@ func (x *GetOIDCURLResponse) String() string { func (*GetOIDCURLResponse) ProtoMessage() {} func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[23] + mi := &file_proxy_service_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1959,7 +2237,7 @@ func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLResponse.ProtoReflect.Descriptor instead. func (*GetOIDCURLResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{23} + return file_proxy_service_proto_rawDescGZIP(), []int{24} } func (x *GetOIDCURLResponse) GetUrl() string { @@ -1981,7 +2259,7 @@ type ValidateSessionRequest struct { func (x *ValidateSessionRequest) Reset() { *x = ValidateSessionRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[24] + mi := &file_proxy_service_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1994,7 +2272,7 @@ func (x *ValidateSessionRequest) String() string { func (*ValidateSessionRequest) ProtoMessage() {} func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[24] + mi := &file_proxy_service_proto_msgTypes[25] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2007,7 +2285,7 @@ func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionRequest.ProtoReflect.Descriptor instead. func (*ValidateSessionRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{24} + return file_proxy_service_proto_rawDescGZIP(), []int{25} } func (x *ValidateSessionRequest) GetDomain() string { @@ -2047,7 +2325,7 @@ type ValidateSessionResponse struct { func (x *ValidateSessionResponse) Reset() { *x = ValidateSessionResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[25] + mi := &file_proxy_service_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2060,7 +2338,7 @@ func (x *ValidateSessionResponse) String() string { func (*ValidateSessionResponse) ProtoMessage() {} func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[25] + mi := &file_proxy_service_proto_msgTypes[26] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2073,7 +2351,7 @@ func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionResponse.ProtoReflect.Descriptor instead. func (*ValidateSessionResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{25} + return file_proxy_service_proto_rawDescGZIP(), []int{26} } func (x *ValidateSessionResponse) GetValid() bool { @@ -2133,7 +2411,7 @@ type ValidateTunnelPeerRequest struct { func (x *ValidateTunnelPeerRequest) Reset() { *x = ValidateTunnelPeerRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[26] + mi := &file_proxy_service_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2146,7 +2424,7 @@ func (x *ValidateTunnelPeerRequest) String() string { func (*ValidateTunnelPeerRequest) ProtoMessage() {} func (x *ValidateTunnelPeerRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[26] + mi := &file_proxy_service_proto_msgTypes[27] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2159,7 +2437,7 @@ func (x *ValidateTunnelPeerRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateTunnelPeerRequest.ProtoReflect.Descriptor instead. func (*ValidateTunnelPeerRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{26} + return file_proxy_service_proto_rawDescGZIP(), []int{27} } func (x *ValidateTunnelPeerRequest) GetTunnelIp() string { @@ -2213,7 +2491,7 @@ type ValidateTunnelPeerResponse struct { func (x *ValidateTunnelPeerResponse) Reset() { *x = ValidateTunnelPeerResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[27] + mi := &file_proxy_service_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2226,7 +2504,7 @@ func (x *ValidateTunnelPeerResponse) String() string { func (*ValidateTunnelPeerResponse) ProtoMessage() {} func (x *ValidateTunnelPeerResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[27] + mi := &file_proxy_service_proto_msgTypes[28] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2239,7 +2517,7 @@ func (x *ValidateTunnelPeerResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateTunnelPeerResponse.ProtoReflect.Descriptor instead. func (*ValidateTunnelPeerResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{27} + return file_proxy_service_proto_rawDescGZIP(), []int{28} } func (x *ValidateTunnelPeerResponse) GetValid() bool { @@ -2309,7 +2587,7 @@ type SyncMappingsRequest struct { func (x *SyncMappingsRequest) Reset() { *x = SyncMappingsRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[28] + mi := &file_proxy_service_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2322,7 +2600,7 @@ func (x *SyncMappingsRequest) String() string { func (*SyncMappingsRequest) ProtoMessage() {} func (x *SyncMappingsRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[28] + mi := &file_proxy_service_proto_msgTypes[29] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2335,7 +2613,7 @@ func (x *SyncMappingsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SyncMappingsRequest.ProtoReflect.Descriptor instead. func (*SyncMappingsRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{28} + return file_proxy_service_proto_rawDescGZIP(), []int{29} } func (m *SyncMappingsRequest) GetMsg() isSyncMappingsRequest_Msg { @@ -2392,7 +2670,7 @@ type SyncMappingsInit struct { func (x *SyncMappingsInit) Reset() { *x = SyncMappingsInit{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[29] + mi := &file_proxy_service_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2405,7 +2683,7 @@ func (x *SyncMappingsInit) String() string { func (*SyncMappingsInit) ProtoMessage() {} func (x *SyncMappingsInit) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[29] + mi := &file_proxy_service_proto_msgTypes[30] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2418,7 +2696,7 @@ func (x *SyncMappingsInit) ProtoReflect() protoreflect.Message { // Deprecated: Use SyncMappingsInit.ProtoReflect.Descriptor instead. func (*SyncMappingsInit) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{29} + return file_proxy_service_proto_rawDescGZIP(), []int{30} } func (x *SyncMappingsInit) GetProxyId() string { @@ -2467,7 +2745,7 @@ type SyncMappingsAck struct { func (x *SyncMappingsAck) Reset() { *x = SyncMappingsAck{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[30] + mi := &file_proxy_service_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2480,7 +2758,7 @@ func (x *SyncMappingsAck) String() string { func (*SyncMappingsAck) ProtoMessage() {} func (x *SyncMappingsAck) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[30] + mi := &file_proxy_service_proto_msgTypes[31] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2493,7 +2771,7 @@ func (x *SyncMappingsAck) ProtoReflect() protoreflect.Message { // Deprecated: Use SyncMappingsAck.ProtoReflect.Descriptor instead. func (*SyncMappingsAck) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{30} + return file_proxy_service_proto_rawDescGZIP(), []int{31} } // SyncMappingsResponse is a batch of mappings sent by management. @@ -2511,7 +2789,7 @@ type SyncMappingsResponse struct { func (x *SyncMappingsResponse) Reset() { *x = SyncMappingsResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[31] + mi := &file_proxy_service_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2524,7 +2802,7 @@ func (x *SyncMappingsResponse) String() string { func (*SyncMappingsResponse) ProtoMessage() {} func (x *SyncMappingsResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[31] + mi := &file_proxy_service_proto_msgTypes[32] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2537,7 +2815,7 @@ func (x *SyncMappingsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SyncMappingsResponse.ProtoReflect.Descriptor instead. func (*SyncMappingsResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{31} + return file_proxy_service_proto_rawDescGZIP(), []int{32} } func (x *SyncMappingsResponse) GetMapping() []*ProxyMapping { @@ -2554,6 +2832,338 @@ func (x *SyncMappingsResponse) GetInitialSyncComplete() bool { return false } +// CheckLLMPolicyLimitsRequest carries the resolved caller identity and the +// upstream provider already chosen by llm_router. Management computes which +// policies authorise the request, picks the one with the most remaining +// headroom, and returns the attribution decision. +type CheckLLMPolicyLimitsRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // account_id is the netbird account the request belongs to. + AccountId string `protobuf:"bytes,1,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` + // user_id is the netbird user id of the caller. May be empty when the + // principal is a tunnel-peer that isn't bound to a user; group membership + // still gates the request in that case. + UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + // group_ids is the caller's full group membership at request time. + GroupIds []string `protobuf:"bytes,3,rep,name=group_ids,json=groupIds,proto3" json:"group_ids,omitempty"` + // provider_id is the agent-network provider record id chosen by llm_router. + ProviderId string `protobuf:"bytes,4,opt,name=provider_id,json=providerId,proto3" json:"provider_id,omitempty"` + // model is the upstream model identifier extracted from the request body. + Model string `protobuf:"bytes,5,opt,name=model,proto3" json:"model,omitempty"` +} + +func (x *CheckLLMPolicyLimitsRequest) Reset() { + *x = CheckLLMPolicyLimitsRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[33] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CheckLLMPolicyLimitsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CheckLLMPolicyLimitsRequest) ProtoMessage() {} + +func (x *CheckLLMPolicyLimitsRequest) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[33] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CheckLLMPolicyLimitsRequest.ProtoReflect.Descriptor instead. +func (*CheckLLMPolicyLimitsRequest) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{33} +} + +func (x *CheckLLMPolicyLimitsRequest) GetAccountId() string { + if x != nil { + return x.AccountId + } + return "" +} + +func (x *CheckLLMPolicyLimitsRequest) GetUserId() string { + if x != nil { + return x.UserId + } + return "" +} + +func (x *CheckLLMPolicyLimitsRequest) GetGroupIds() []string { + if x != nil { + return x.GroupIds + } + return nil +} + +func (x *CheckLLMPolicyLimitsRequest) GetProviderId() string { + if x != nil { + return x.ProviderId + } + return "" +} + +func (x *CheckLLMPolicyLimitsRequest) GetModel() string { + if x != nil { + return x.Model + } + return "" +} + +// CheckLLMPolicyLimitsResponse is management's allow-or-deny decision for a +// pre-flight check. +type CheckLLMPolicyLimitsResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // decision is "allow" or "deny". + Decision string `protobuf:"bytes,1,opt,name=decision,proto3" json:"decision,omitempty"` + // selected_policy_id names the policy that paid for this request. + SelectedPolicyId string `protobuf:"bytes,2,opt,name=selected_policy_id,json=selectedPolicyId,proto3" json:"selected_policy_id,omitempty"` + // attribution_group_id is the source group the request booked against. + AttributionGroupId string `protobuf:"bytes,3,opt,name=attribution_group_id,json=attributionGroupId,proto3" json:"attribution_group_id,omitempty"` + // window_seconds is the cap window length the selected policy uses. + WindowSeconds int64 `protobuf:"varint,4,opt,name=window_seconds,json=windowSeconds,proto3" json:"window_seconds,omitempty"` + // deny_code is set on decision="deny" with a stable label. + DenyCode string `protobuf:"bytes,5,opt,name=deny_code,json=denyCode,proto3" json:"deny_code,omitempty"` + // deny_reason is a short human-readable explanation paired with deny_code. + DenyReason string `protobuf:"bytes,6,opt,name=deny_reason,json=denyReason,proto3" json:"deny_reason,omitempty"` +} + +func (x *CheckLLMPolicyLimitsResponse) Reset() { + *x = CheckLLMPolicyLimitsResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[34] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CheckLLMPolicyLimitsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CheckLLMPolicyLimitsResponse) ProtoMessage() {} + +func (x *CheckLLMPolicyLimitsResponse) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[34] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CheckLLMPolicyLimitsResponse.ProtoReflect.Descriptor instead. +func (*CheckLLMPolicyLimitsResponse) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{34} +} + +func (x *CheckLLMPolicyLimitsResponse) GetDecision() string { + if x != nil { + return x.Decision + } + return "" +} + +func (x *CheckLLMPolicyLimitsResponse) GetSelectedPolicyId() string { + if x != nil { + return x.SelectedPolicyId + } + return "" +} + +func (x *CheckLLMPolicyLimitsResponse) GetAttributionGroupId() string { + if x != nil { + return x.AttributionGroupId + } + return "" +} + +func (x *CheckLLMPolicyLimitsResponse) GetWindowSeconds() int64 { + if x != nil { + return x.WindowSeconds + } + return 0 +} + +func (x *CheckLLMPolicyLimitsResponse) GetDenyCode() string { + if x != nil { + return x.DenyCode + } + return "" +} + +func (x *CheckLLMPolicyLimitsResponse) GetDenyReason() string { + if x != nil { + return x.DenyReason + } + return "" +} + +// RecordLLMUsageRequest is the post-flight increment the proxy posts after +// the upstream call. Counters are keyed on (account, dimension, window). +type RecordLLMUsageRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + AccountId string `protobuf:"bytes,1,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` + UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + // group_id is the selected policy's attribution group, recorded against the + // policy window (window_seconds). + GroupId string `protobuf:"bytes,3,opt,name=group_id,json=groupId,proto3" json:"group_id,omitempty"` + WindowSeconds int64 `protobuf:"varint,4,opt,name=window_seconds,json=windowSeconds,proto3" json:"window_seconds,omitempty"` + TokensInput int64 `protobuf:"varint,5,opt,name=tokens_input,json=tokensInput,proto3" json:"tokens_input,omitempty"` + TokensOutput int64 `protobuf:"varint,6,opt,name=tokens_output,json=tokensOutput,proto3" json:"tokens_output,omitempty"` + CostUsd float64 `protobuf:"fixed64,7,opt,name=cost_usd,json=costUsd,proto3" json:"cost_usd,omitempty"` + // group_ids is the caller's full group membership, used to fan the same + // usage out to every applicable account-level budget rule's own window. + GroupIds []string `protobuf:"bytes,8,rep,name=group_ids,json=groupIds,proto3" json:"group_ids,omitempty"` +} + +func (x *RecordLLMUsageRequest) Reset() { + *x = RecordLLMUsageRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[35] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RecordLLMUsageRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RecordLLMUsageRequest) ProtoMessage() {} + +func (x *RecordLLMUsageRequest) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[35] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RecordLLMUsageRequest.ProtoReflect.Descriptor instead. +func (*RecordLLMUsageRequest) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{35} +} + +func (x *RecordLLMUsageRequest) GetAccountId() string { + if x != nil { + return x.AccountId + } + return "" +} + +func (x *RecordLLMUsageRequest) GetUserId() string { + if x != nil { + return x.UserId + } + return "" +} + +func (x *RecordLLMUsageRequest) GetGroupId() string { + if x != nil { + return x.GroupId + } + return "" +} + +func (x *RecordLLMUsageRequest) GetWindowSeconds() int64 { + if x != nil { + return x.WindowSeconds + } + return 0 +} + +func (x *RecordLLMUsageRequest) GetTokensInput() int64 { + if x != nil { + return x.TokensInput + } + return 0 +} + +func (x *RecordLLMUsageRequest) GetTokensOutput() int64 { + if x != nil { + return x.TokensOutput + } + return 0 +} + +func (x *RecordLLMUsageRequest) GetCostUsd() float64 { + if x != nil { + return x.CostUsd + } + return 0 +} + +func (x *RecordLLMUsageRequest) GetGroupIds() []string { + if x != nil { + return x.GroupIds + } + return nil +} + +type RecordLLMUsageResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *RecordLLMUsageResponse) Reset() { + *x = RecordLLMUsageResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[36] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RecordLLMUsageResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RecordLLMUsageResponse) ProtoMessage() {} + +func (x *RecordLLMUsageResponse) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[36] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RecordLLMUsageResponse.ProtoReflect.Descriptor instead. +func (*RecordLLMUsageResponse) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{36} +} + var File_proxy_service_proto protoreflect.FileDescriptor var file_proxy_service_proto_rawDesc = []byte{ @@ -2610,7 +3220,7 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, - 0x22, 0xf7, 0x03, 0x0a, 0x11, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, + 0x22, 0xb6, 0x06, 0x0a, 0x11, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0f, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x74, 0x6c, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x54, 0x6c, 0x73, 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x12, 0x42, @@ -2637,368 +3247,479 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x6c, 0x65, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x27, 0x0a, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x70, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, - 0x55, 0x70, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x1a, 0x40, 0x0a, 0x12, 0x43, 0x75, 0x73, 0x74, - 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, - 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, - 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x72, 0x0a, 0x0b, 0x50, 0x61, - 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, - 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x16, 0x0a, - 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, - 0x61, 0x72, 0x67, 0x65, 0x74, 0x12, 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, - 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0x47, - 0x0a, 0x0a, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x12, 0x16, 0x0a, 0x06, - 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x68, 0x65, - 0x61, 0x64, 0x65, 0x72, 0x12, 0x21, 0x0a, 0x0c, 0x68, 0x61, 0x73, 0x68, 0x65, 0x64, 0x5f, 0x76, - 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x68, 0x61, 0x73, 0x68, - 0x65, 0x64, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x22, 0xe5, 0x01, 0x0a, 0x0e, 0x41, 0x75, 0x74, 0x68, - 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, - 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x12, 0x35, 0x0a, 0x17, 0x6d, - 0x61, 0x78, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x61, 0x67, 0x65, 0x5f, 0x73, - 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x14, 0x6d, 0x61, - 0x78, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x41, 0x67, 0x65, 0x53, 0x65, 0x63, 0x6f, 0x6e, - 0x64, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x10, - 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x70, 0x69, 0x6e, - 0x12, 0x12, 0x0a, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, - 0x6f, 0x69, 0x64, 0x63, 0x12, 0x39, 0x0a, 0x0c, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x61, - 0x75, 0x74, 0x68, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, - 0x74, 0x68, 0x52, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x73, 0x22, - 0xdd, 0x01, 0x0a, 0x12, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x74, 0x72, 0x69, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, - 0x64, 0x5f, 0x63, 0x69, 0x64, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x61, - 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x43, 0x69, 0x64, 0x72, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x62, - 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x5f, 0x63, 0x69, 0x64, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x0c, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x43, 0x69, 0x64, 0x72, 0x73, - 0x12, 0x2b, 0x0a, 0x11, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x63, 0x6f, 0x75, 0x6e, - 0x74, 0x72, 0x69, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x10, 0x61, 0x6c, 0x6c, - 0x6f, 0x77, 0x65, 0x64, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, 0x12, 0x2b, 0x0a, - 0x11, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x69, - 0x65, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x10, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, - 0x64, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x63, 0x72, - 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0c, 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x4d, 0x6f, 0x64, 0x65, 0x22, - 0x80, 0x04, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, - 0x12, 0x36, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, - 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, - 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, - 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, - 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, - 0x2b, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x4d, - 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1d, 0x0a, 0x0a, - 0x61, 0x75, 0x74, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x09, 0x61, 0x75, 0x74, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2e, 0x0a, 0x04, 0x61, - 0x75, 0x74, 0x68, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x04, 0x61, 0x75, 0x74, 0x68, 0x12, 0x28, 0x0a, 0x10, 0x70, - 0x61, 0x73, 0x73, 0x5f, 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, - 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x70, 0x61, 0x73, 0x73, 0x48, 0x6f, 0x73, 0x74, 0x48, - 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, 0x2b, 0x0a, 0x11, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, - 0x5f, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x10, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, - 0x74, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, - 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x6c, 0x69, 0x73, - 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x4f, 0x0a, 0x13, 0x61, 0x63, 0x63, 0x65, 0x73, - 0x73, 0x5f, 0x72, 0x65, 0x73, 0x74, 0x72, 0x69, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x0c, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x74, 0x72, 0x69, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x12, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x74, - 0x72, 0x69, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x72, 0x69, 0x76, - 0x61, 0x74, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, - 0x74, 0x65, 0x22, 0x3f, 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, - 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x03, 0x6c, 0x6f, - 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x03, - 0x6c, 0x6f, 0x67, 0x22, 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, - 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x84, 0x05, 0x0a, - 0x09, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, - 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, - 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, - 0x74, 0x61, 0x6d, 0x70, 0x12, 0x15, 0x0a, 0x06, 0x6c, 0x6f, 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6c, 0x6f, 0x67, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, + 0x55, 0x70, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x39, 0x0a, 0x19, 0x63, 0x61, 0x70, 0x74, + 0x75, 0x72, 0x65, 0x5f, 0x6d, 0x61, 0x78, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, + 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x08, 0x20, 0x01, 0x28, 0x03, 0x52, 0x16, 0x63, 0x61, 0x70, + 0x74, 0x75, 0x72, 0x65, 0x4d, 0x61, 0x78, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x42, 0x79, + 0x74, 0x65, 0x73, 0x12, 0x3b, 0x0a, 0x1a, 0x63, 0x61, 0x70, 0x74, 0x75, 0x72, 0x65, 0x5f, 0x6d, + 0x61, 0x78, 0x5f, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x62, 0x79, 0x74, 0x65, + 0x73, 0x18, 0x09, 0x20, 0x01, 0x28, 0x03, 0x52, 0x17, 0x63, 0x61, 0x70, 0x74, 0x75, 0x72, 0x65, + 0x4d, 0x61, 0x78, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x79, 0x74, 0x65, 0x73, + 0x12, 0x32, 0x0a, 0x15, 0x63, 0x61, 0x70, 0x74, 0x75, 0x72, 0x65, 0x5f, 0x63, 0x6f, 0x6e, 0x74, + 0x65, 0x6e, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, + 0x13, 0x63, 0x61, 0x70, 0x74, 0x75, 0x72, 0x65, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x54, + 0x79, 0x70, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x0b, 0x6d, 0x69, 0x64, 0x64, 0x6c, 0x65, 0x77, 0x61, + 0x72, 0x65, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x69, 0x64, 0x64, 0x6c, 0x65, 0x77, 0x61, 0x72, + 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x6d, 0x69, 0x64, 0x64, 0x6c, 0x65, 0x77, + 0x61, 0x72, 0x65, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x5f, 0x6e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x61, 0x67, 0x65, + 0x6e, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x2c, 0x0a, 0x12, 0x64, 0x69, 0x73, + 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x6c, 0x6f, 0x67, 0x18, + 0x0d, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x63, + 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x1a, 0x40, 0x0a, 0x12, 0x43, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, + 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, + 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xd1, 0x02, 0x0a, 0x10, 0x4d, 0x69, + 0x64, 0x64, 0x6c, 0x65, 0x77, 0x61, 0x72, 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x0e, + 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x18, + 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x2e, 0x0a, 0x04, 0x73, 0x6c, 0x6f, 0x74, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x69, 0x64, 0x64, 0x6c, 0x65, 0x77, 0x61, 0x72, 0x65, 0x53, 0x6c, + 0x6f, 0x74, 0x52, 0x04, 0x73, 0x6c, 0x6f, 0x74, 0x12, 0x1f, 0x0a, 0x0b, 0x63, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x5f, 0x6a, 0x73, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x63, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x4a, 0x73, 0x6f, 0x6e, 0x12, 0x42, 0x0a, 0x09, 0x66, 0x61, 0x69, + 0x6c, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x25, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x69, 0x64, 0x64, 0x6c, 0x65, + 0x77, 0x61, 0x72, 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x46, 0x61, 0x69, 0x6c, 0x4d, + 0x6f, 0x64, 0x65, 0x52, 0x08, 0x66, 0x61, 0x69, 0x6c, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x33, 0x0a, + 0x07, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x74, 0x69, 0x6d, 0x65, 0x6f, + 0x75, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x63, 0x61, 0x6e, 0x5f, 0x6d, 0x75, 0x74, 0x61, 0x74, 0x65, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x61, 0x6e, 0x4d, 0x75, 0x74, 0x61, 0x74, + 0x65, 0x22, 0x2a, 0x0a, 0x08, 0x46, 0x61, 0x69, 0x6c, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x0d, 0x0a, + 0x09, 0x46, 0x41, 0x49, 0x4c, 0x5f, 0x4f, 0x50, 0x45, 0x4e, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, + 0x46, 0x41, 0x49, 0x4c, 0x5f, 0x43, 0x4c, 0x4f, 0x53, 0x45, 0x44, 0x10, 0x01, 0x22, 0x72, 0x0a, + 0x0b, 0x50, 0x61, 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x12, 0x0a, 0x04, + 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, + 0x12, 0x16, 0x0a, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x12, 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, + 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x22, 0x47, 0x0a, 0x0a, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x12, + 0x16, 0x0a, 0x06, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x06, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, 0x21, 0x0a, 0x0c, 0x68, 0x61, 0x73, 0x68, 0x65, + 0x64, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x68, + 0x61, 0x73, 0x68, 0x65, 0x64, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x22, 0xe5, 0x01, 0x0a, 0x0e, 0x41, + 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1f, 0x0a, + 0x0b, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x12, 0x35, + 0x0a, 0x17, 0x6d, 0x61, 0x78, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x61, 0x67, + 0x65, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x14, 0x6d, 0x61, 0x78, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x41, 0x67, 0x65, 0x53, 0x65, + 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, + 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, + 0x64, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, + 0x70, 0x69, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x12, 0x39, 0x0a, 0x0c, 0x68, 0x65, 0x61, 0x64, 0x65, + 0x72, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x65, + 0x72, 0x41, 0x75, 0x74, 0x68, 0x52, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, + 0x68, 0x73, 0x22, 0xdd, 0x01, 0x0a, 0x12, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, + 0x74, 0x72, 0x69, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x61, 0x6c, 0x6c, + 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x63, 0x69, 0x64, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0c, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x43, 0x69, 0x64, 0x72, 0x73, 0x12, 0x23, + 0x0a, 0x0d, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x5f, 0x63, 0x69, 0x64, 0x72, 0x73, 0x18, + 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x43, 0x69, + 0x64, 0x72, 0x73, 0x12, 0x2b, 0x0a, 0x11, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x63, + 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x10, + 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, + 0x12, 0x2b, 0x0a, 0x11, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x5f, 0x63, 0x6f, 0x75, 0x6e, + 0x74, 0x72, 0x69, 0x65, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x10, 0x62, 0x6c, 0x6f, + 0x63, 0x6b, 0x65, 0x64, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, 0x12, 0x23, 0x0a, + 0x0d, 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x4d, 0x6f, + 0x64, 0x65, 0x22, 0x80, 0x04, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, + 0x69, 0x6e, 0x67, 0x12, 0x36, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, + 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, + 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, + 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, - 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x6f, 0x73, - 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, 0x12, 0x0a, - 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, - 0x68, 0x12, 0x1f, 0x0a, 0x0b, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x73, - 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x4d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x08, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, - 0x05, 0x52, 0x0c, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x43, 0x6f, 0x64, 0x65, 0x12, - 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x0a, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, - 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x18, 0x0b, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x61, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x63, 0x68, 0x61, 0x6e, - 0x69, 0x73, 0x6d, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x0c, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, - 0x61, 0x75, 0x74, 0x68, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x0d, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x0b, 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, - 0x21, 0x0a, 0x0c, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x18, - 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, 0x65, 0x73, 0x55, 0x70, 0x6c, 0x6f, - 0x61, 0x64, 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x64, 0x6f, 0x77, 0x6e, - 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, - 0x73, 0x44, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3f, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, - 0x61, 0x18, 0x11, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x2e, 0x4d, - 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, - 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x3b, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, - 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, - 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, - 0x02, 0x38, 0x01, 0x22, 0xf8, 0x01, 0x0a, 0x13, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, - 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, - 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, - 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x39, 0x0a, 0x08, 0x70, 0x61, - 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, - 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x08, 0x70, 0x61, 0x73, - 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x2a, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x03, 0x70, 0x69, - 0x6e, 0x12, 0x40, 0x0a, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x61, 0x75, 0x74, 0x68, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, - 0x75, 0x74, 0x68, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x57, - 0x0a, 0x11, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x12, 0x21, 0x0a, 0x0c, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x76, 0x61, - 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, - 0x72, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, - 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x68, 0x65, 0x61, - 0x64, 0x65, 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x2d, 0x0a, 0x0f, 0x50, 0x61, 0x73, 0x73, 0x77, - 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, - 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, - 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x1e, 0x0a, 0x0a, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x22, 0x55, 0x0a, 0x14, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, - 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, - 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, - 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xda, 0x02, - 0x0a, 0x17, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, - 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, - 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, - 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, - 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x2f, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x2d, 0x0a, 0x12, 0x63, 0x65, 0x72, 0x74, - 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x5f, 0x69, 0x73, 0x73, 0x75, 0x65, 0x64, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, - 0x65, 0x49, 0x73, 0x73, 0x75, 0x65, 0x64, 0x12, 0x28, 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, - 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, - 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, - 0x01, 0x12, 0x50, 0x0a, 0x10, 0x69, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x5f, 0x6c, 0x69, 0x73, - 0x74, 0x65, 0x6e, 0x65, 0x72, 0x18, 0x32, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x49, 0x6e, - 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x65, 0x72, 0x48, 0x01, 0x52, - 0x0f, 0x69, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x65, 0x72, - 0x88, 0x01, 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x69, 0x6e, 0x62, 0x6f, 0x75, 0x6e, - 0x64, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x65, 0x72, 0x22, 0x6f, 0x0a, 0x14, 0x50, 0x72, - 0x6f, 0x78, 0x79, 0x49, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e, - 0x65, 0x72, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x5f, 0x69, 0x70, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x74, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x49, 0x70, 0x12, - 0x1d, 0x0a, 0x0a, 0x68, 0x74, 0x74, 0x70, 0x73, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x0d, 0x52, 0x09, 0x68, 0x74, 0x74, 0x70, 0x73, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1b, - 0x0a, 0x09, 0x68, 0x74, 0x74, 0x70, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x0d, 0x52, 0x08, 0x68, 0x74, 0x74, 0x70, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x1a, 0x0a, 0x18, 0x53, - 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xb8, 0x01, 0x0a, 0x16, 0x43, 0x72, 0x65, 0x61, - 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, - 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, - 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x30, 0x0a, 0x14, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, - 0x61, 0x72, 0x64, 0x5f, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, - 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6c, 0x75, 0x73, - 0x74, 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, - 0x65, 0x72, 0x22, 0x6f, 0x0a, 0x17, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, - 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, - 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, - 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, - 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, - 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, - 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x22, 0x65, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, - 0x4c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, - 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, - 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, - 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, - 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x72, 0x6c, 0x22, 0x26, 0x0a, 0x12, 0x47, 0x65, - 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, - 0x72, 0x6c, 0x22, 0x55, 0x0a, 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, - 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, - 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, - 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, - 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xdc, 0x01, 0x0a, 0x17, 0x56, 0x61, - 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x75, - 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, - 0x65, 0x72, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x65, 0x6d, 0x61, - 0x69, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x45, 0x6d, - 0x61, 0x69, 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x5f, 0x72, 0x65, - 0x61, 0x73, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x6e, 0x69, - 0x65, 0x64, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0e, 0x70, 0x65, 0x65, 0x72, - 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x0c, 0x70, 0x65, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x28, - 0x0a, 0x10, 0x70, 0x65, 0x65, 0x72, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x6e, 0x61, 0x6d, - 0x65, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0e, 0x70, 0x65, 0x65, 0x72, 0x47, 0x72, - 0x6f, 0x75, 0x70, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x22, 0x50, 0x0a, 0x19, 0x56, 0x61, 0x6c, 0x69, - 0x64, 0x61, 0x74, 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x5f, - 0x69, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x74, 0x75, 0x6e, 0x6e, 0x65, 0x6c, - 0x49, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x84, 0x02, 0x0a, 0x1a, 0x56, - 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x50, 0x65, 0x65, - 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, - 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, - 0x5f, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x73, - 0x65, 0x72, 0x45, 0x6d, 0x61, 0x69, 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, 0x65, - 0x64, 0x5f, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, - 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x12, 0x23, 0x0a, 0x0d, - 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, - 0x6e, 0x12, 0x24, 0x0a, 0x0e, 0x70, 0x65, 0x65, 0x72, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, - 0x69, 0x64, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x65, 0x65, 0x72, 0x47, - 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x28, 0x0a, 0x10, 0x70, 0x65, 0x65, 0x72, 0x5f, - 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x0e, 0x70, 0x65, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x4e, 0x61, 0x6d, 0x65, - 0x73, 0x22, 0x81, 0x01, 0x0a, 0x13, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, - 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x32, 0x0a, 0x04, 0x69, 0x6e, 0x69, - 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, - 0x73, 0x49, 0x6e, 0x69, 0x74, 0x48, 0x00, 0x52, 0x04, 0x69, 0x6e, 0x69, 0x74, 0x12, 0x2f, 0x0a, - 0x03, 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, - 0x69, 0x6e, 0x67, 0x73, 0x41, 0x63, 0x6b, 0x48, 0x00, 0x52, 0x03, 0x61, 0x63, 0x6b, 0x42, 0x05, - 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x22, 0xdf, 0x01, 0x0a, 0x10, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, - 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x70, 0x72, - 0x6f, 0x78, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x72, - 0x6f, 0x78, 0x79, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, - 0x39, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, - 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, - 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, - 0x72, 0x65, 0x73, 0x73, 0x12, 0x41, 0x0a, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, - 0x74, 0x69, 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, - 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x52, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, - 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x22, 0x11, 0x0a, 0x0f, 0x53, 0x79, 0x6e, 0x63, 0x4d, - 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x41, 0x63, 0x6b, 0x22, 0x7e, 0x0a, 0x14, 0x53, 0x79, - 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x07, 0x6d, - 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, - 0x6c, 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x79, - 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x2a, 0x64, 0x0a, 0x16, 0x50, 0x72, - 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, - 0x54, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, - 0x59, 0x50, 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, - 0x14, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x4f, 0x44, - 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, - 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, 0x56, 0x45, 0x44, 0x10, 0x02, - 0x2a, 0x46, 0x0a, 0x0f, 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, - 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, - 0x49, 0x54, 0x45, 0x5f, 0x44, 0x45, 0x46, 0x41, 0x55, 0x4c, 0x54, 0x10, 0x00, 0x12, 0x19, 0x0a, - 0x15, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, 0x50, 0x52, - 0x45, 0x53, 0x45, 0x52, 0x56, 0x45, 0x10, 0x01, 0x2a, 0xc8, 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, - 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x52, 0x4f, 0x58, - 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, - 0x10, 0x00, 0x12, 0x17, 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, - 0x55, 0x53, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x50, - 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x54, 0x55, 0x4e, 0x4e, - 0x45, 0x4c, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x02, - 0x12, 0x24, 0x0a, 0x20, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, - 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x50, 0x45, 0x4e, - 0x44, 0x49, 0x4e, 0x47, 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, - 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, - 0x54, 0x45, 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, 0x12, 0x16, 0x0a, 0x12, 0x50, - 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x45, 0x52, 0x52, 0x4f, - 0x52, 0x10, 0x05, 0x32, 0xb8, 0x06, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, - 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, - 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, - 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x55, 0x0a, 0x0c, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, - 0x70, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, 0x30, 0x01, 0x12, 0x54, 0x0a, 0x0d, - 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x20, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, - 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, - 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x51, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, - 0x74, 0x65, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x12, 0x2b, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, + 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, + 0x1d, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x06, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x75, 0x74, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2e, + 0x0a, 0x04, 0x61, 0x75, 0x74, 0x68, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, + 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x04, 0x61, 0x75, 0x74, 0x68, 0x12, 0x28, + 0x0a, 0x10, 0x70, 0x61, 0x73, 0x73, 0x5f, 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x68, 0x65, 0x61, 0x64, + 0x65, 0x72, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x70, 0x61, 0x73, 0x73, 0x48, 0x6f, + 0x73, 0x74, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, 0x2b, 0x0a, 0x11, 0x72, 0x65, 0x77, 0x72, + 0x69, 0x74, 0x65, 0x5f, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x52, 0x65, 0x64, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x0a, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, + 0x74, 0x65, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, + 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x4f, 0x0a, 0x13, 0x61, 0x63, + 0x63, 0x65, 0x73, 0x73, 0x5f, 0x72, 0x65, 0x73, 0x74, 0x72, 0x69, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x74, 0x72, + 0x69, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x12, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, + 0x65, 0x73, 0x74, 0x72, 0x69, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x70, + 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x70, 0x72, + 0x69, 0x76, 0x61, 0x74, 0x65, 0x22, 0x3f, 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, + 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, + 0x03, 0x6c, 0x6f, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, + 0x67, 0x52, 0x03, 0x6c, 0x6f, 0x67, 0x22, 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, + 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0xa9, 0x05, 0x0a, 0x09, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x38, 0x0a, + 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, + 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x15, 0x0a, 0x06, 0x6c, 0x6f, 0x67, 0x5f, 0x69, + 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6c, 0x6f, 0x67, 0x49, 0x64, 0x12, 0x1d, + 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1d, 0x0a, + 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, + 0x68, 0x6f, 0x73, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, + 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x70, 0x61, 0x74, 0x68, 0x12, 0x1f, 0x0a, 0x0b, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x5f, 0x6d, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x64, 0x75, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x4d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, + 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x23, 0x0a, + 0x0d, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x43, 0x6f, + 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, + 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, + 0x25, 0x0a, 0x0e, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, + 0x6d, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x61, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x63, + 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, + 0x64, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, + 0x21, 0x0a, 0x0c, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, + 0x0d, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, 0x63, 0x63, 0x65, + 0x73, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, 0x70, 0x6c, 0x6f, + 0x61, 0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, 0x65, 0x73, 0x55, + 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x64, + 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x62, + 0x79, 0x74, 0x65, 0x73, 0x44, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x1a, 0x0a, 0x08, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3f, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, + 0x64, 0x61, 0x74, 0x61, 0x18, 0x11, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, + 0x67, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, + 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x23, 0x0a, 0x0d, 0x61, 0x67, 0x65, + 0x6e, 0x74, 0x5f, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x12, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x1a, 0x3b, + 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, + 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, + 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xf8, 0x01, 0x0a, 0x13, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5d, 0x0a, 0x10, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, - 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, - 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, - 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x4b, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x12, 0x1d, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, - 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, - 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, - 0x0f, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, - 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, + 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, + 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, + 0x49, 0x64, 0x12, 0x39, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x48, 0x00, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x2a, 0x0a, + 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x48, 0x00, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x40, 0x0a, 0x0b, 0x68, 0x65, 0x61, + 0x64, 0x65, 0x72, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x65, 0x61, 0x64, + 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, + 0x0a, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x42, 0x09, 0x0a, 0x07, 0x72, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x57, 0x0a, 0x11, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, + 0x41, 0x75, 0x74, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x21, 0x0a, 0x0c, 0x68, + 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x1f, + 0x0a, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0a, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x22, + 0x2d, 0x0a, 0x0f, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x1e, + 0x0a, 0x0a, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, + 0x70, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x22, 0x55, + 0x0a, 0x14, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, + 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, + 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xda, 0x02, 0x0a, 0x17, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, + 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, + 0x2f, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, + 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x12, 0x2d, 0x0a, 0x12, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x5f, + 0x69, 0x73, 0x73, 0x75, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x63, 0x65, + 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x49, 0x73, 0x73, 0x75, 0x65, 0x64, 0x12, + 0x28, 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, 0x12, 0x50, 0x0a, 0x10, 0x69, 0x6e, 0x62, + 0x6f, 0x75, 0x6e, 0x64, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x65, 0x72, 0x18, 0x32, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x49, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x4c, 0x69, 0x73, + 0x74, 0x65, 0x6e, 0x65, 0x72, 0x48, 0x01, 0x52, 0x0f, 0x69, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, + 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x65, 0x72, 0x88, 0x01, 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, + 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x42, 0x13, 0x0a, + 0x11, 0x5f, 0x69, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, + 0x65, 0x72, 0x22, 0x6f, 0x0a, 0x14, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x49, 0x6e, 0x62, 0x6f, 0x75, + 0x6e, 0x64, 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x65, 0x72, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x75, + 0x6e, 0x6e, 0x65, 0x6c, 0x5f, 0x69, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x74, + 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x49, 0x70, 0x12, 0x1d, 0x0a, 0x0a, 0x68, 0x74, 0x74, 0x70, 0x73, + 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x68, 0x74, 0x74, + 0x70, 0x73, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x68, 0x74, 0x74, 0x70, 0x5f, 0x70, + 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x68, 0x74, 0x74, 0x70, 0x50, + 0x6f, 0x72, 0x74, 0x22, 0x1a, 0x0a, 0x18, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0xb8, 0x01, 0x0a, 0x16, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, + 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, + 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, + 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, + 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x30, + 0x0a, 0x14, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x5f, 0x70, 0x75, 0x62, 0x6c, + 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x77, 0x69, + 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, + 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x22, 0x6f, 0x0a, 0x17, 0x43, 0x72, + 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, + 0x28, 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, + 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x65, 0x0a, 0x11, 0x47, + 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, + 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, + 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x6c, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, + 0x72, 0x6c, 0x22, 0x26, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x22, 0x55, 0x0a, 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, - 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x63, 0x0a, 0x12, 0x56, 0x61, 0x6c, - 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x12, - 0x25, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, - 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x26, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x23, 0x0a, 0x0d, + 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x22, 0xdc, 0x01, 0x0a, 0x17, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, + 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, + 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, + 0x6c, 0x69, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, + 0x75, 0x73, 0x65, 0x72, 0x5f, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x45, 0x6d, 0x61, 0x69, 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, + 0x65, 0x6e, 0x69, 0x65, 0x64, 0x5f, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, + 0x12, 0x24, 0x0a, 0x0e, 0x70, 0x65, 0x65, 0x72, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, + 0x64, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x65, 0x65, 0x72, 0x47, 0x72, + 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x28, 0x0a, 0x10, 0x70, 0x65, 0x65, 0x72, 0x5f, 0x67, + 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0e, 0x70, 0x65, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x4e, 0x61, 0x6d, 0x65, 0x73, + 0x22, 0x50, 0x0a, 0x19, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x75, 0x6e, 0x6e, + 0x65, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, + 0x09, 0x74, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x5f, 0x69, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x08, 0x74, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x49, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x22, 0x84, 0x02, 0x0a, 0x1a, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, + 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, + 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, + 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x45, 0x6d, 0x61, 0x69, 0x6c, 0x12, + 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x5f, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x52, 0x65, + 0x61, 0x73, 0x6f, 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, + 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x24, 0x0a, 0x0e, 0x70, 0x65, 0x65, + 0x72, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x0c, 0x70, 0x65, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, + 0x28, 0x0a, 0x10, 0x70, 0x65, 0x65, 0x72, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x6e, 0x61, + 0x6d, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0e, 0x70, 0x65, 0x65, 0x72, 0x47, + 0x72, 0x6f, 0x75, 0x70, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x22, 0x81, 0x01, 0x0a, 0x13, 0x53, 0x79, + 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x32, 0x0a, 0x04, 0x69, 0x6e, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, + 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x48, 0x00, 0x52, + 0x04, 0x69, 0x6e, 0x69, 0x74, 0x12, 0x2f, 0x0a, 0x03, 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x41, 0x63, 0x6b, 0x48, + 0x00, 0x52, 0x03, 0x61, 0x63, 0x6b, 0x42, 0x05, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x22, 0xdf, 0x01, + 0x0a, 0x10, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x49, 0x6e, + 0x69, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x49, 0x64, 0x12, 0x18, 0x0a, + 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, + 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x72, 0x74, + 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, + 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, + 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x41, 0x0a, 0x0c, + 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, + 0x73, 0x52, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x22, + 0x11, 0x0a, 0x0f, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x41, + 0x63, 0x6b, 0x22, 0x7e, 0x0a, 0x14, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, + 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, + 0x70, 0x70, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, + 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, + 0x0a, 0x15, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, + 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, + 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, + 0x74, 0x65, 0x22, 0xa9, 0x01, 0x0a, 0x1b, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x4c, 0x4c, 0x4d, 0x50, + 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, + 0x64, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x67, 0x72, + 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x67, + 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x1f, 0x0a, 0x0b, 0x70, 0x72, 0x6f, 0x76, 0x69, + 0x64, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x70, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x6d, 0x6f, 0x64, 0x65, + 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x22, 0xff, + 0x01, 0x0a, 0x1c, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x4c, 0x4c, 0x4d, 0x50, 0x6f, 0x6c, 0x69, 0x63, + 0x79, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x1a, 0x0a, 0x08, 0x64, 0x65, 0x63, 0x69, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x64, 0x65, 0x63, 0x69, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x2c, 0x0a, 0x12, 0x73, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x5f, 0x70, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x5f, 0x69, + 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, + 0x64, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x64, 0x12, 0x30, 0x0a, 0x14, 0x61, 0x74, 0x74, + 0x72, 0x69, 0x62, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, + 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x61, 0x74, 0x74, 0x72, 0x69, 0x62, 0x75, + 0x74, 0x69, 0x6f, 0x6e, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x12, 0x25, 0x0a, 0x0e, 0x77, + 0x69, 0x6e, 0x64, 0x6f, 0x77, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x0d, 0x77, 0x69, 0x6e, 0x64, 0x6f, 0x77, 0x53, 0x65, 0x63, 0x6f, 0x6e, + 0x64, 0x73, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x65, 0x6e, 0x79, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, + 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x64, 0x65, 0x6e, 0x79, 0x43, 0x6f, 0x64, 0x65, 0x12, + 0x1f, 0x0a, 0x0b, 0x64, 0x65, 0x6e, 0x79, 0x5f, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x64, 0x65, 0x6e, 0x79, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, + 0x22, 0x91, 0x02, 0x0a, 0x15, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x4c, 0x4c, 0x4d, 0x55, 0x73, + 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, + 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, + 0x49, 0x64, 0x12, 0x19, 0x0a, 0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x12, 0x25, 0x0a, + 0x0e, 0x77, 0x69, 0x6e, 0x64, 0x6f, 0x77, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x77, 0x69, 0x6e, 0x64, 0x6f, 0x77, 0x53, 0x65, 0x63, + 0x6f, 0x6e, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x5f, 0x69, + 0x6e, 0x70, 0x75, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x74, 0x6f, 0x6b, 0x65, + 0x6e, 0x73, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x73, 0x5f, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0c, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x12, 0x19, 0x0a, 0x08, + 0x63, 0x6f, 0x73, 0x74, 0x5f, 0x75, 0x73, 0x64, 0x18, 0x07, 0x20, 0x01, 0x28, 0x01, 0x52, 0x07, + 0x63, 0x6f, 0x73, 0x74, 0x55, 0x73, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x67, 0x72, 0x6f, 0x75, 0x70, + 0x5f, 0x69, 0x64, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x67, 0x72, 0x6f, 0x75, + 0x70, 0x49, 0x64, 0x73, 0x22, 0x18, 0x0a, 0x16, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x4c, 0x4c, + 0x4d, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x64, + 0x0a, 0x16, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, + 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, + 0x00, 0x12, 0x18, 0x0a, 0x14, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, + 0x5f, 0x4d, 0x4f, 0x44, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x55, + 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, 0x56, + 0x45, 0x44, 0x10, 0x02, 0x2a, 0x46, 0x0a, 0x0f, 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, + 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x41, 0x54, 0x48, 0x5f, + 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, 0x44, 0x45, 0x46, 0x41, 0x55, 0x4c, 0x54, 0x10, + 0x00, 0x12, 0x19, 0x0a, 0x15, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, + 0x45, 0x5f, 0x50, 0x52, 0x45, 0x53, 0x45, 0x52, 0x56, 0x45, 0x10, 0x01, 0x2a, 0x90, 0x01, 0x0a, + 0x0e, 0x4d, 0x69, 0x64, 0x64, 0x6c, 0x65, 0x77, 0x61, 0x72, 0x65, 0x53, 0x6c, 0x6f, 0x74, 0x12, + 0x1f, 0x0a, 0x1b, 0x4d, 0x49, 0x44, 0x44, 0x4c, 0x45, 0x57, 0x41, 0x52, 0x45, 0x5f, 0x53, 0x4c, + 0x4f, 0x54, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, + 0x12, 0x1e, 0x0a, 0x1a, 0x4d, 0x49, 0x44, 0x44, 0x4c, 0x45, 0x57, 0x41, 0x52, 0x45, 0x5f, 0x53, + 0x4c, 0x4f, 0x54, 0x5f, 0x4f, 0x4e, 0x5f, 0x52, 0x45, 0x51, 0x55, 0x45, 0x53, 0x54, 0x10, 0x01, + 0x12, 0x1f, 0x0a, 0x1b, 0x4d, 0x49, 0x44, 0x44, 0x4c, 0x45, 0x57, 0x41, 0x52, 0x45, 0x5f, 0x53, + 0x4c, 0x4f, 0x54, 0x5f, 0x4f, 0x4e, 0x5f, 0x52, 0x45, 0x53, 0x50, 0x4f, 0x4e, 0x53, 0x45, 0x10, + 0x02, 0x12, 0x1c, 0x0a, 0x18, 0x4d, 0x49, 0x44, 0x44, 0x4c, 0x45, 0x57, 0x41, 0x52, 0x45, 0x5f, + 0x53, 0x4c, 0x4f, 0x54, 0x5f, 0x54, 0x45, 0x52, 0x4d, 0x49, 0x4e, 0x41, 0x4c, 0x10, 0x03, 0x2a, + 0xc8, 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, + 0x18, 0x0a, 0x14, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, + 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x00, 0x12, 0x17, 0x0a, 0x13, 0x50, 0x52, 0x4f, + 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, + 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, + 0x55, 0x53, 0x5f, 0x54, 0x55, 0x4e, 0x4e, 0x45, 0x4c, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x43, 0x52, + 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x02, 0x12, 0x24, 0x0a, 0x20, 0x50, 0x52, 0x4f, 0x58, 0x59, + 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, + 0x41, 0x54, 0x45, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x03, 0x12, 0x23, 0x0a, + 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, + 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, + 0x10, 0x04, 0x12, 0x16, 0x0a, 0x12, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, + 0x55, 0x53, 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x05, 0x32, 0xfc, 0x07, 0x0a, 0x0c, 0x50, + 0x72, 0x6f, 0x78, 0x79, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x5f, 0x0a, 0x10, 0x47, + 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, + 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, + 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, + 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x55, 0x0a, 0x0c, + 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x1f, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, + 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, + 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, + 0x01, 0x30, 0x01, 0x12, 0x54, 0x0a, 0x0d, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, + 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x51, 0x0a, 0x0c, 0x41, 0x75, 0x74, + 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, + 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5d, 0x0a, 0x10, + 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, + 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x43, + 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x12, 0x22, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, + 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x4f, 0x49, + 0x44, 0x43, 0x55, 0x52, 0x4c, 0x12, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, + 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, + 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, + 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x63, 0x0a, 0x12, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x75, 0x6e, 0x6e, + 0x65, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x12, 0x25, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x75, 0x6e, 0x6e, - 0x65, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x08, - 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x26, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, + 0x61, 0x74, 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x69, 0x0a, 0x14, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x4c, 0x4c, + 0x4d, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x12, 0x27, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, + 0x4c, 0x4c, 0x4d, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x4c, 0x4c, 0x4d, 0x50, 0x6f, 0x6c, 0x69, + 0x63, 0x79, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x57, 0x0a, 0x0e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x4c, 0x4c, 0x4d, 0x55, 0x73, 0x61, + 0x67, 0x65, 0x12, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x4c, 0x4c, 0x4d, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x4c, 0x4c, 0x4d, 0x55, 0x73, 0x61, 0x67, + 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -3013,99 +3734,114 @@ func file_proxy_service_proto_rawDescGZIP() []byte { return file_proxy_service_proto_rawDescData } -var file_proxy_service_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 34) +var file_proxy_service_proto_enumTypes = make([]protoimpl.EnumInfo, 5) +var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 39) var file_proxy_service_proto_goTypes = []interface{}{ - (ProxyMappingUpdateType)(0), // 0: management.ProxyMappingUpdateType - (PathRewriteMode)(0), // 1: management.PathRewriteMode - (ProxyStatus)(0), // 2: management.ProxyStatus - (*ProxyCapabilities)(nil), // 3: management.ProxyCapabilities - (*GetMappingUpdateRequest)(nil), // 4: management.GetMappingUpdateRequest - (*GetMappingUpdateResponse)(nil), // 5: management.GetMappingUpdateResponse - (*PathTargetOptions)(nil), // 6: management.PathTargetOptions - (*PathMapping)(nil), // 7: management.PathMapping - (*HeaderAuth)(nil), // 8: management.HeaderAuth - (*Authentication)(nil), // 9: management.Authentication - (*AccessRestrictions)(nil), // 10: management.AccessRestrictions - (*ProxyMapping)(nil), // 11: management.ProxyMapping - (*SendAccessLogRequest)(nil), // 12: management.SendAccessLogRequest - (*SendAccessLogResponse)(nil), // 13: management.SendAccessLogResponse - (*AccessLog)(nil), // 14: management.AccessLog - (*AuthenticateRequest)(nil), // 15: management.AuthenticateRequest - (*HeaderAuthRequest)(nil), // 16: management.HeaderAuthRequest - (*PasswordRequest)(nil), // 17: management.PasswordRequest - (*PinRequest)(nil), // 18: management.PinRequest - (*AuthenticateResponse)(nil), // 19: management.AuthenticateResponse - (*SendStatusUpdateRequest)(nil), // 20: management.SendStatusUpdateRequest - (*ProxyInboundListener)(nil), // 21: management.ProxyInboundListener - (*SendStatusUpdateResponse)(nil), // 22: management.SendStatusUpdateResponse - (*CreateProxyPeerRequest)(nil), // 23: management.CreateProxyPeerRequest - (*CreateProxyPeerResponse)(nil), // 24: management.CreateProxyPeerResponse - (*GetOIDCURLRequest)(nil), // 25: management.GetOIDCURLRequest - (*GetOIDCURLResponse)(nil), // 26: management.GetOIDCURLResponse - (*ValidateSessionRequest)(nil), // 27: management.ValidateSessionRequest - (*ValidateSessionResponse)(nil), // 28: management.ValidateSessionResponse - (*ValidateTunnelPeerRequest)(nil), // 29: management.ValidateTunnelPeerRequest - (*ValidateTunnelPeerResponse)(nil), // 30: management.ValidateTunnelPeerResponse - (*SyncMappingsRequest)(nil), // 31: management.SyncMappingsRequest - (*SyncMappingsInit)(nil), // 32: management.SyncMappingsInit - (*SyncMappingsAck)(nil), // 33: management.SyncMappingsAck - (*SyncMappingsResponse)(nil), // 34: management.SyncMappingsResponse - nil, // 35: management.PathTargetOptions.CustomHeadersEntry - nil, // 36: management.AccessLog.MetadataEntry - (*timestamppb.Timestamp)(nil), // 37: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 38: google.protobuf.Duration + (ProxyMappingUpdateType)(0), // 0: management.ProxyMappingUpdateType + (PathRewriteMode)(0), // 1: management.PathRewriteMode + (MiddlewareSlot)(0), // 2: management.MiddlewareSlot + (ProxyStatus)(0), // 3: management.ProxyStatus + (MiddlewareConfig_FailMode)(0), // 4: management.MiddlewareConfig.FailMode + (*ProxyCapabilities)(nil), // 5: management.ProxyCapabilities + (*GetMappingUpdateRequest)(nil), // 6: management.GetMappingUpdateRequest + (*GetMappingUpdateResponse)(nil), // 7: management.GetMappingUpdateResponse + (*PathTargetOptions)(nil), // 8: management.PathTargetOptions + (*MiddlewareConfig)(nil), // 9: management.MiddlewareConfig + (*PathMapping)(nil), // 10: management.PathMapping + (*HeaderAuth)(nil), // 11: management.HeaderAuth + (*Authentication)(nil), // 12: management.Authentication + (*AccessRestrictions)(nil), // 13: management.AccessRestrictions + (*ProxyMapping)(nil), // 14: management.ProxyMapping + (*SendAccessLogRequest)(nil), // 15: management.SendAccessLogRequest + (*SendAccessLogResponse)(nil), // 16: management.SendAccessLogResponse + (*AccessLog)(nil), // 17: management.AccessLog + (*AuthenticateRequest)(nil), // 18: management.AuthenticateRequest + (*HeaderAuthRequest)(nil), // 19: management.HeaderAuthRequest + (*PasswordRequest)(nil), // 20: management.PasswordRequest + (*PinRequest)(nil), // 21: management.PinRequest + (*AuthenticateResponse)(nil), // 22: management.AuthenticateResponse + (*SendStatusUpdateRequest)(nil), // 23: management.SendStatusUpdateRequest + (*ProxyInboundListener)(nil), // 24: management.ProxyInboundListener + (*SendStatusUpdateResponse)(nil), // 25: management.SendStatusUpdateResponse + (*CreateProxyPeerRequest)(nil), // 26: management.CreateProxyPeerRequest + (*CreateProxyPeerResponse)(nil), // 27: management.CreateProxyPeerResponse + (*GetOIDCURLRequest)(nil), // 28: management.GetOIDCURLRequest + (*GetOIDCURLResponse)(nil), // 29: management.GetOIDCURLResponse + (*ValidateSessionRequest)(nil), // 30: management.ValidateSessionRequest + (*ValidateSessionResponse)(nil), // 31: management.ValidateSessionResponse + (*ValidateTunnelPeerRequest)(nil), // 32: management.ValidateTunnelPeerRequest + (*ValidateTunnelPeerResponse)(nil), // 33: management.ValidateTunnelPeerResponse + (*SyncMappingsRequest)(nil), // 34: management.SyncMappingsRequest + (*SyncMappingsInit)(nil), // 35: management.SyncMappingsInit + (*SyncMappingsAck)(nil), // 36: management.SyncMappingsAck + (*SyncMappingsResponse)(nil), // 37: management.SyncMappingsResponse + (*CheckLLMPolicyLimitsRequest)(nil), // 38: management.CheckLLMPolicyLimitsRequest + (*CheckLLMPolicyLimitsResponse)(nil), // 39: management.CheckLLMPolicyLimitsResponse + (*RecordLLMUsageRequest)(nil), // 40: management.RecordLLMUsageRequest + (*RecordLLMUsageResponse)(nil), // 41: management.RecordLLMUsageResponse + nil, // 42: management.PathTargetOptions.CustomHeadersEntry + nil, // 43: management.AccessLog.MetadataEntry + (*timestamppb.Timestamp)(nil), // 44: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 45: google.protobuf.Duration } var file_proxy_service_proto_depIdxs = []int32{ - 37, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp - 3, // 1: management.GetMappingUpdateRequest.capabilities:type_name -> management.ProxyCapabilities - 11, // 2: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping - 38, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration + 44, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp + 5, // 1: management.GetMappingUpdateRequest.capabilities:type_name -> management.ProxyCapabilities + 14, // 2: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping + 45, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration 1, // 4: management.PathTargetOptions.path_rewrite:type_name -> management.PathRewriteMode - 35, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry - 38, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration - 6, // 7: management.PathMapping.options:type_name -> management.PathTargetOptions - 8, // 8: management.Authentication.header_auths:type_name -> management.HeaderAuth - 0, // 9: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType - 7, // 10: management.ProxyMapping.path:type_name -> management.PathMapping - 9, // 11: management.ProxyMapping.auth:type_name -> management.Authentication - 10, // 12: management.ProxyMapping.access_restrictions:type_name -> management.AccessRestrictions - 14, // 13: management.SendAccessLogRequest.log:type_name -> management.AccessLog - 37, // 14: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp - 36, // 15: management.AccessLog.metadata:type_name -> management.AccessLog.MetadataEntry - 17, // 16: management.AuthenticateRequest.password:type_name -> management.PasswordRequest - 18, // 17: management.AuthenticateRequest.pin:type_name -> management.PinRequest - 16, // 18: management.AuthenticateRequest.header_auth:type_name -> management.HeaderAuthRequest - 2, // 19: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus - 21, // 20: management.SendStatusUpdateRequest.inbound_listener:type_name -> management.ProxyInboundListener - 32, // 21: management.SyncMappingsRequest.init:type_name -> management.SyncMappingsInit - 33, // 22: management.SyncMappingsRequest.ack:type_name -> management.SyncMappingsAck - 37, // 23: management.SyncMappingsInit.started_at:type_name -> google.protobuf.Timestamp - 3, // 24: management.SyncMappingsInit.capabilities:type_name -> management.ProxyCapabilities - 11, // 25: management.SyncMappingsResponse.mapping:type_name -> management.ProxyMapping - 4, // 26: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest - 31, // 27: management.ProxyService.SyncMappings:input_type -> management.SyncMappingsRequest - 12, // 28: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest - 15, // 29: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest - 20, // 30: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest - 23, // 31: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest - 25, // 32: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest - 27, // 33: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest - 29, // 34: management.ProxyService.ValidateTunnelPeer:input_type -> management.ValidateTunnelPeerRequest - 5, // 35: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse - 34, // 36: management.ProxyService.SyncMappings:output_type -> management.SyncMappingsResponse - 13, // 37: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse - 19, // 38: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse - 22, // 39: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse - 24, // 40: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse - 26, // 41: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse - 28, // 42: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse - 30, // 43: management.ProxyService.ValidateTunnelPeer:output_type -> management.ValidateTunnelPeerResponse - 35, // [35:44] is the sub-list for method output_type - 26, // [26:35] is the sub-list for method input_type - 26, // [26:26] is the sub-list for extension type_name - 26, // [26:26] is the sub-list for extension extendee - 0, // [0:26] is the sub-list for field type_name + 42, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry + 45, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration + 9, // 7: management.PathTargetOptions.middlewares:type_name -> management.MiddlewareConfig + 2, // 8: management.MiddlewareConfig.slot:type_name -> management.MiddlewareSlot + 4, // 9: management.MiddlewareConfig.fail_mode:type_name -> management.MiddlewareConfig.FailMode + 45, // 10: management.MiddlewareConfig.timeout:type_name -> google.protobuf.Duration + 8, // 11: management.PathMapping.options:type_name -> management.PathTargetOptions + 11, // 12: management.Authentication.header_auths:type_name -> management.HeaderAuth + 0, // 13: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType + 10, // 14: management.ProxyMapping.path:type_name -> management.PathMapping + 12, // 15: management.ProxyMapping.auth:type_name -> management.Authentication + 13, // 16: management.ProxyMapping.access_restrictions:type_name -> management.AccessRestrictions + 17, // 17: management.SendAccessLogRequest.log:type_name -> management.AccessLog + 44, // 18: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp + 43, // 19: management.AccessLog.metadata:type_name -> management.AccessLog.MetadataEntry + 20, // 20: management.AuthenticateRequest.password:type_name -> management.PasswordRequest + 21, // 21: management.AuthenticateRequest.pin:type_name -> management.PinRequest + 19, // 22: management.AuthenticateRequest.header_auth:type_name -> management.HeaderAuthRequest + 3, // 23: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus + 24, // 24: management.SendStatusUpdateRequest.inbound_listener:type_name -> management.ProxyInboundListener + 35, // 25: management.SyncMappingsRequest.init:type_name -> management.SyncMappingsInit + 36, // 26: management.SyncMappingsRequest.ack:type_name -> management.SyncMappingsAck + 44, // 27: management.SyncMappingsInit.started_at:type_name -> google.protobuf.Timestamp + 5, // 28: management.SyncMappingsInit.capabilities:type_name -> management.ProxyCapabilities + 14, // 29: management.SyncMappingsResponse.mapping:type_name -> management.ProxyMapping + 6, // 30: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest + 34, // 31: management.ProxyService.SyncMappings:input_type -> management.SyncMappingsRequest + 15, // 32: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest + 18, // 33: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest + 23, // 34: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest + 26, // 35: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest + 28, // 36: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest + 30, // 37: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest + 32, // 38: management.ProxyService.ValidateTunnelPeer:input_type -> management.ValidateTunnelPeerRequest + 38, // 39: management.ProxyService.CheckLLMPolicyLimits:input_type -> management.CheckLLMPolicyLimitsRequest + 40, // 40: management.ProxyService.RecordLLMUsage:input_type -> management.RecordLLMUsageRequest + 7, // 41: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse + 37, // 42: management.ProxyService.SyncMappings:output_type -> management.SyncMappingsResponse + 16, // 43: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse + 22, // 44: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse + 25, // 45: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse + 27, // 46: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse + 29, // 47: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse + 31, // 48: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse + 33, // 49: management.ProxyService.ValidateTunnelPeer:output_type -> management.ValidateTunnelPeerResponse + 39, // 50: management.ProxyService.CheckLLMPolicyLimits:output_type -> management.CheckLLMPolicyLimitsResponse + 41, // 51: management.ProxyService.RecordLLMUsage:output_type -> management.RecordLLMUsageResponse + 41, // [41:52] is the sub-list for method output_type + 30, // [30:41] is the sub-list for method input_type + 30, // [30:30] is the sub-list for extension type_name + 30, // [30:30] is the sub-list for extension extendee + 0, // [0:30] is the sub-list for field type_name } func init() { file_proxy_service_proto_init() } @@ -3163,7 +3899,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PathMapping); i { + switch v := v.(*MiddlewareConfig); i { case 0: return &v.state case 1: @@ -3175,7 +3911,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*HeaderAuth); i { + switch v := v.(*PathMapping); i { case 0: return &v.state case 1: @@ -3187,7 +3923,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Authentication); i { + switch v := v.(*HeaderAuth); i { case 0: return &v.state case 1: @@ -3199,7 +3935,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AccessRestrictions); i { + switch v := v.(*Authentication); i { case 0: return &v.state case 1: @@ -3211,7 +3947,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProxyMapping); i { + switch v := v.(*AccessRestrictions); i { case 0: return &v.state case 1: @@ -3223,7 +3959,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendAccessLogRequest); i { + switch v := v.(*ProxyMapping); i { case 0: return &v.state case 1: @@ -3235,7 +3971,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendAccessLogResponse); i { + switch v := v.(*SendAccessLogRequest); i { case 0: return &v.state case 1: @@ -3247,7 +3983,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AccessLog); i { + switch v := v.(*SendAccessLogResponse); i { case 0: return &v.state case 1: @@ -3259,7 +3995,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AuthenticateRequest); i { + switch v := v.(*AccessLog); i { case 0: return &v.state case 1: @@ -3271,7 +4007,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[13].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*HeaderAuthRequest); i { + switch v := v.(*AuthenticateRequest); i { case 0: return &v.state case 1: @@ -3283,7 +4019,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PasswordRequest); i { + switch v := v.(*HeaderAuthRequest); i { case 0: return &v.state case 1: @@ -3295,7 +4031,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PinRequest); i { + switch v := v.(*PasswordRequest); i { case 0: return &v.state case 1: @@ -3307,7 +4043,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[16].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AuthenticateResponse); i { + switch v := v.(*PinRequest); i { case 0: return &v.state case 1: @@ -3319,7 +4055,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendStatusUpdateRequest); i { + switch v := v.(*AuthenticateResponse); i { case 0: return &v.state case 1: @@ -3331,7 +4067,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProxyInboundListener); i { + switch v := v.(*SendStatusUpdateRequest); i { case 0: return &v.state case 1: @@ -3343,7 +4079,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendStatusUpdateResponse); i { + switch v := v.(*ProxyInboundListener); i { case 0: return &v.state case 1: @@ -3355,7 +4091,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerRequest); i { + switch v := v.(*SendStatusUpdateResponse); i { case 0: return &v.state case 1: @@ -3367,7 +4103,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerResponse); i { + switch v := v.(*CreateProxyPeerRequest); i { case 0: return &v.state case 1: @@ -3379,7 +4115,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLRequest); i { + switch v := v.(*CreateProxyPeerResponse); i { case 0: return &v.state case 1: @@ -3391,7 +4127,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLResponse); i { + switch v := v.(*GetOIDCURLRequest); i { case 0: return &v.state case 1: @@ -3403,7 +4139,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ValidateSessionRequest); i { + switch v := v.(*GetOIDCURLResponse); i { case 0: return &v.state case 1: @@ -3415,7 +4151,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ValidateSessionResponse); i { + switch v := v.(*ValidateSessionRequest); i { case 0: return &v.state case 1: @@ -3427,7 +4163,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ValidateTunnelPeerRequest); i { + switch v := v.(*ValidateSessionResponse); i { case 0: return &v.state case 1: @@ -3439,7 +4175,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ValidateTunnelPeerResponse); i { + switch v := v.(*ValidateTunnelPeerRequest); i { case 0: return &v.state case 1: @@ -3451,7 +4187,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SyncMappingsRequest); i { + switch v := v.(*ValidateTunnelPeerResponse); i { case 0: return &v.state case 1: @@ -3463,7 +4199,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SyncMappingsInit); i { + switch v := v.(*SyncMappingsRequest); i { case 0: return &v.state case 1: @@ -3475,7 +4211,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SyncMappingsAck); i { + switch v := v.(*SyncMappingsInit); i { case 0: return &v.state case 1: @@ -3487,6 +4223,18 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMappingsAck); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*SyncMappingsResponse); i { case 0: return &v.state @@ -3498,16 +4246,64 @@ func file_proxy_service_proto_init() { return nil } } + file_proxy_service_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CheckLLMPolicyLimitsRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CheckLLMPolicyLimitsResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RecordLLMUsageRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RecordLLMUsageResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_proxy_service_proto_msgTypes[0].OneofWrappers = []interface{}{} - file_proxy_service_proto_msgTypes[12].OneofWrappers = []interface{}{ + file_proxy_service_proto_msgTypes[13].OneofWrappers = []interface{}{ (*AuthenticateRequest_Password)(nil), (*AuthenticateRequest_Pin)(nil), (*AuthenticateRequest_HeaderAuth)(nil), } - file_proxy_service_proto_msgTypes[17].OneofWrappers = []interface{}{} - file_proxy_service_proto_msgTypes[21].OneofWrappers = []interface{}{} - file_proxy_service_proto_msgTypes[28].OneofWrappers = []interface{}{ + file_proxy_service_proto_msgTypes[18].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[22].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[29].OneofWrappers = []interface{}{ (*SyncMappingsRequest_Init)(nil), (*SyncMappingsRequest_Ack)(nil), } @@ -3516,8 +4312,8 @@ func file_proxy_service_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_proxy_service_proto_rawDesc, - NumEnums: 3, - NumMessages: 34, + NumEnums: 5, + NumMessages: 39, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/proxy_service.proto b/shared/management/proto/proxy_service.proto index 14d188877..89d1f8749 100644 --- a/shared/management/proto/proxy_service.proto +++ b/shared/management/proto/proxy_service.proto @@ -43,6 +43,18 @@ service ProxyService { // issue a session cookie without redirecting through the OIDC flow. // Mirrors ValidateSession's response shape. rpc ValidateTunnelPeer(ValidateTunnelPeerRequest) returns (ValidateTunnelPeerResponse); + + // CheckLLMPolicyLimits is the pre-flight RPC the proxy calls before each + // LLM request. Management runs the per-policy headroom selection across + // every policy authorising the caller's user / groups for the resolved + // provider and returns the chosen attribution policy + group, or a deny + // when no applicable policy has headroom > 0. + rpc CheckLLMPolicyLimits(CheckLLMPolicyLimitsRequest) returns (CheckLLMPolicyLimitsResponse); + + // RecordLLMUsage is the post-flight RPC the proxy calls after the upstream + // returns. Increments the per-(dimension, window) counters for the + // attribution policy chosen by CheckLLMPolicyLimits. + rpc RecordLLMUsage(RecordLLMUsageRequest) returns (RecordLLMUsageResponse); } // ProxyCapabilities describes what a proxy can handle. @@ -107,6 +119,59 @@ message PathTargetOptions { // reachable without WireGuard (public APIs, LAN services, localhost // sidecars). Defaults to false — embedded client is the standard path. bool direct_upstream = 7; + // Proxy clamps to [0, proxy-wide max (1 MiB)] at apply time. Agent-network + // synthesized targets only; private services leave these zero. + int64 capture_max_request_bytes = 8; + // Proxy clamps to [0, proxy-wide max (1 MiB)] at apply time. + int64 capture_max_response_bytes = 9; + // Content types eligible for body capture (e.g. "application/json"). + repeated string capture_content_types = 10; + // Per-target middleware configurations populated by the agent-network + // synthesizer. Validated and clamped by the proxy at apply time. + repeated MiddlewareConfig middlewares = 11; + // When true, the proxy stamps agent_network=true on access-log entries + // for this target so management routes them to the agent-network log + // surface. + bool agent_network = 12; + // When true, the proxy suppresses the per-request access-log emission for + // this target. Defaults false to preserve existing access-log behavior for + // every non-agent-network target. The agent-network synth target sets this + // true only when the account's EnableLogCollection toggle is off. + bool disable_access_log = 13; +} + +// MiddlewareSlot identifies where in the request lifecycle a middleware +// runs. Mirrors proxy/internal/middleware.Slot. +enum MiddlewareSlot { + MIDDLEWARE_SLOT_UNSPECIFIED = 0; + MIDDLEWARE_SLOT_ON_REQUEST = 1; + MIDDLEWARE_SLOT_ON_RESPONSE = 2; + MIDDLEWARE_SLOT_TERMINAL = 3; +} + +// MiddlewareConfig is the per-target configuration for a single middleware. +// The proxy validates every incoming MiddlewareConfig at apply time: +// unknown ids are rejected, timeout is clamped to [10ms, 5s], and the +// declared slot must match the registered middleware's slot. +message MiddlewareConfig { + // Middleware id; must match the proxy-local compiled-in registry. + string id = 1; + bool enabled = 2; + MiddlewareSlot slot = 3; + // Free-form JSON unmarshalled by the middleware factory into its own typed + // config struct. Empty / null / {} are valid (zero-value config). + bytes config_json = 4; + enum FailMode { + FAIL_OPEN = 0; + FAIL_CLOSED = 1; + } + FailMode fail_mode = 5; + // Clamped to [10ms, 5s] at apply time; zero → 500ms default. + google.protobuf.Duration timeout = 6; + // When true, the middleware may mutate request headers or body (subject to + // policy). Honoured only when the implementation also declares + // MutationsSupported. + bool can_mutate = 7; } message PathMapping { @@ -190,6 +255,10 @@ message AccessLog { string protocol = 16; // Extra key-value metadata for the access log entry (e.g. crowdsec_verdict, scenario). map metadata = 17; + // When true, the entry was emitted by an agent-network synth service. + // Management routes these to the agent-network access-log surface instead + // of the standard service log. + bool agent_network = 18; } message AuthenticateRequest { @@ -376,3 +445,59 @@ message SyncMappingsResponse { bool initial_sync_complete = 2; } +// CheckLLMPolicyLimitsRequest carries the resolved caller identity and the +// upstream provider already chosen by llm_router. Management computes which +// policies authorise the request, picks the one with the most remaining +// headroom, and returns the attribution decision. +message CheckLLMPolicyLimitsRequest { + // account_id is the netbird account the request belongs to. + string account_id = 1; + // user_id is the netbird user id of the caller. May be empty when the + // principal is a tunnel-peer that isn't bound to a user; group membership + // still gates the request in that case. + string user_id = 2; + // group_ids is the caller's full group membership at request time. + repeated string group_ids = 3; + // provider_id is the agent-network provider record id chosen by llm_router. + string provider_id = 4; + // model is the upstream model identifier extracted from the request body. + string model = 5; +} + +// CheckLLMPolicyLimitsResponse is management's allow-or-deny decision for a +// pre-flight check. +message CheckLLMPolicyLimitsResponse { + // decision is "allow" or "deny". + string decision = 1; + // selected_policy_id names the policy that paid for this request. + string selected_policy_id = 2; + // attribution_group_id is the source group the request booked against. + string attribution_group_id = 3; + // window_seconds is the cap window length the selected policy uses. + int64 window_seconds = 4; + // deny_code is set on decision="deny" with a stable label. + string deny_code = 5; + // deny_reason is a short human-readable explanation paired with deny_code. + string deny_reason = 6; +} + +// RecordLLMUsageRequest is the post-flight increment the proxy posts after +// the upstream call. Counters are keyed on (account, dimension, window). +message RecordLLMUsageRequest { + string account_id = 1; + string user_id = 2; + // group_id is the selected policy's attribution group, recorded against the + // policy window (window_seconds). + string group_id = 3; + int64 window_seconds = 4; + int64 tokens_input = 5; + int64 tokens_output = 6; + double cost_usd = 7; + // group_ids is the caller's full group membership, used to fan the same + // usage out to every applicable account-level budget rule's own window. + repeated string group_ids = 8; +} + +message RecordLLMUsageResponse { +} + diff --git a/shared/management/proto/proxy_service_grpc.pb.go b/shared/management/proto/proxy_service_grpc.pb.go index 40064fe61..76c1f005f 100644 --- a/shared/management/proto/proxy_service_grpc.pb.go +++ b/shared/management/proto/proxy_service_grpc.pb.go @@ -43,6 +43,16 @@ type ProxyServiceClient interface { // issue a session cookie without redirecting through the OIDC flow. // Mirrors ValidateSession's response shape. ValidateTunnelPeer(ctx context.Context, in *ValidateTunnelPeerRequest, opts ...grpc.CallOption) (*ValidateTunnelPeerResponse, error) + // CheckLLMPolicyLimits is the pre-flight RPC the proxy calls before each + // LLM request. Management runs the per-policy headroom selection across + // every policy authorising the caller's user / groups for the resolved + // provider and returns the chosen attribution policy + group, or a deny + // when no applicable policy has headroom > 0. + CheckLLMPolicyLimits(ctx context.Context, in *CheckLLMPolicyLimitsRequest, opts ...grpc.CallOption) (*CheckLLMPolicyLimitsResponse, error) + // RecordLLMUsage is the post-flight RPC the proxy calls after the upstream + // returns. Increments the per-(dimension, window) counters for the + // attribution policy chosen by CheckLLMPolicyLimits. + RecordLLMUsage(ctx context.Context, in *RecordLLMUsageRequest, opts ...grpc.CallOption) (*RecordLLMUsageResponse, error) } type proxyServiceClient struct { @@ -179,6 +189,24 @@ func (c *proxyServiceClient) ValidateTunnelPeer(ctx context.Context, in *Validat return out, nil } +func (c *proxyServiceClient) CheckLLMPolicyLimits(ctx context.Context, in *CheckLLMPolicyLimitsRequest, opts ...grpc.CallOption) (*CheckLLMPolicyLimitsResponse, error) { + out := new(CheckLLMPolicyLimitsResponse) + err := c.cc.Invoke(ctx, "/management.ProxyService/CheckLLMPolicyLimits", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *proxyServiceClient) RecordLLMUsage(ctx context.Context, in *RecordLLMUsageRequest, opts ...grpc.CallOption) (*RecordLLMUsageResponse, error) { + out := new(RecordLLMUsageResponse) + err := c.cc.Invoke(ctx, "/management.ProxyService/RecordLLMUsage", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // ProxyServiceServer is the server API for ProxyService service. // All implementations must embed UnimplementedProxyServiceServer // for forward compatibility @@ -208,6 +236,16 @@ type ProxyServiceServer interface { // issue a session cookie without redirecting through the OIDC flow. // Mirrors ValidateSession's response shape. ValidateTunnelPeer(context.Context, *ValidateTunnelPeerRequest) (*ValidateTunnelPeerResponse, error) + // CheckLLMPolicyLimits is the pre-flight RPC the proxy calls before each + // LLM request. Management runs the per-policy headroom selection across + // every policy authorising the caller's user / groups for the resolved + // provider and returns the chosen attribution policy + group, or a deny + // when no applicable policy has headroom > 0. + CheckLLMPolicyLimits(context.Context, *CheckLLMPolicyLimitsRequest) (*CheckLLMPolicyLimitsResponse, error) + // RecordLLMUsage is the post-flight RPC the proxy calls after the upstream + // returns. Increments the per-(dimension, window) counters for the + // attribution policy chosen by CheckLLMPolicyLimits. + RecordLLMUsage(context.Context, *RecordLLMUsageRequest) (*RecordLLMUsageResponse, error) mustEmbedUnimplementedProxyServiceServer() } @@ -242,6 +280,12 @@ func (UnimplementedProxyServiceServer) ValidateSession(context.Context, *Validat func (UnimplementedProxyServiceServer) ValidateTunnelPeer(context.Context, *ValidateTunnelPeerRequest) (*ValidateTunnelPeerResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method ValidateTunnelPeer not implemented") } +func (UnimplementedProxyServiceServer) CheckLLMPolicyLimits(context.Context, *CheckLLMPolicyLimitsRequest) (*CheckLLMPolicyLimitsResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method CheckLLMPolicyLimits not implemented") +} +func (UnimplementedProxyServiceServer) RecordLLMUsage(context.Context, *RecordLLMUsageRequest) (*RecordLLMUsageResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method RecordLLMUsage not implemented") +} func (UnimplementedProxyServiceServer) mustEmbedUnimplementedProxyServiceServer() {} // UnsafeProxyServiceServer may be embedded to opt out of forward compatibility for this service. @@ -428,6 +472,42 @@ func _ProxyService_ValidateTunnelPeer_Handler(srv interface{}, ctx context.Conte return interceptor(ctx, in, info, handler) } +func _ProxyService_CheckLLMPolicyLimits_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CheckLLMPolicyLimitsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ProxyServiceServer).CheckLLMPolicyLimits(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/management.ProxyService/CheckLLMPolicyLimits", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ProxyServiceServer).CheckLLMPolicyLimits(ctx, req.(*CheckLLMPolicyLimitsRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ProxyService_RecordLLMUsage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RecordLLMUsageRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ProxyServiceServer).RecordLLMUsage(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/management.ProxyService/RecordLLMUsage", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ProxyServiceServer).RecordLLMUsage(ctx, req.(*RecordLLMUsageRequest)) + } + return interceptor(ctx, in, info, handler) +} + // ProxyService_ServiceDesc is the grpc.ServiceDesc for ProxyService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -463,6 +543,14 @@ var ProxyService_ServiceDesc = grpc.ServiceDesc{ MethodName: "ValidateTunnelPeer", Handler: _ProxyService_ValidateTunnelPeer_Handler, }, + { + MethodName: "CheckLLMPolicyLimits", + Handler: _ProxyService_CheckLLMPolicyLimits_Handler, + }, + { + MethodName: "RecordLLMUsage", + Handler: _ProxyService_RecordLLMUsage_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/shared/management/status/error.go b/shared/management/status/error.go index 1957c5591..e31663450 100644 --- a/shared/management/status/error.go +++ b/shared/management/status/error.go @@ -219,6 +219,26 @@ func NewNetworkResourceNotFoundError(resourceID string) error { return Errorf(NotFound, "network resource: %s not found", resourceID) } +// NewAgentNetworkProviderNotFoundError creates a new Error with NotFound type for a missing Agent Network provider. +func NewAgentNetworkProviderNotFoundError(providerID string) error { + return Errorf(NotFound, "agent network provider: %s not found", providerID) +} + +// NewAgentNetworkPolicyNotFoundError creates a new Error with NotFound type for a missing Agent Network policy. +func NewAgentNetworkPolicyNotFoundError(policyID string) error { + return Errorf(NotFound, "agent network policy: %s not found", policyID) +} + +// NewAgentNetworkGuardrailNotFoundError creates a new Error with NotFound type for a missing Agent Network guardrail. +func NewAgentNetworkGuardrailNotFoundError(guardrailID string) error { + return Errorf(NotFound, "agent network guardrail: %s not found", guardrailID) +} + +// NewAgentNetworkBudgetRuleNotFoundError creates a new Error with NotFound type for a missing Agent Network budget rule. +func NewAgentNetworkBudgetRuleNotFoundError(ruleID string) error { + return Errorf(NotFound, "agent network budget rule: %s not found", ruleID) +} + // NewPermissionDeniedError creates a new Error with PermissionDenied type for a permission denied error. func NewPermissionDeniedError() error { return Errorf(PermissionDenied, "permission denied") From 980598ed4ab21d4ba3ec5bb45b88c7471a050864 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 1 Jul 2026 14:50:25 +0200 Subject: [PATCH 03/19] [management, client] Add management-controlled client metrics push (#5886) * [management, client] Add management-controlled client metrics push Allow enabling/disabling client metrics push from the dashboard via account settings instead of requiring env vars on every client. - Add MetricsConfig proto message to NetbirdConfig - Add MetricsPushEnabled to account Settings (DB-persisted) - Expose metrics_push_enabled in OpenAPI and dashboard API handler - Populate MetricsConfig in sync and login responses - Client dynamically starts/stops push based on management config - NB_METRICS_PUSH_ENABLED env var overrides management when explicitly set - Add activity events for metrics push enable/disable * Remove log line * [management] Fix peer update test for MetricsConfig in NetbirdConfig Update TestUpdateAccountPeers assertions: NetbirdConfig is no longer nil in peer update responses since it now carries MetricsConfig even when STUN/TURN config is absent. * Regenerate proto files with protoc v7.34.1 * [management] Read metrics push setting in Postgres account query getAccountPgx omitted settings_metrics_push_enabled from its hand-written SELECT and Scan, so the toggle was always read back as false on Postgres and never reached clients. * [client] Fix metrics push getting stuck off after engine restart Engine restarts (backoff retries within the same login session) cancel e.ctx, which the push goroutine's lifetime was tied to. The goroutine died silently but ClientMetrics.push stayed non-nil since only an explicit stop clears it, so the next UpdatePushFromMgm call saw a "push already running" state and never restarted it. Give the Engine its own metricsCtx sourced from ConnectClient.ctx, which outlives engine restarts, so handleMetricsUpdate stops tying the push to the wrong-scoped context. Additionally make ClientMetrics.push an atomic.Pointer that the push goroutine clears via CompareAndSwap on exit, so the tracked state can never drift from the goroutine's actual lifetime regardless of which context a future caller passes in. * [management] Regenerate OpenAPI types with oapi-codegen v2.7.1 types.gen.go was regenerated with a stale local v2.6.0 binary, causing the CI git-diff check against generate.sh's pinned v2.7.1 to fail. --- client/internal/connect.go | 5 + client/internal/engine.go | 13 + client/internal/metrics/env.go | 7 + client/internal/metrics/metrics.go | 74 +- .../internals/shared/grpc/conversion.go | 19 +- management/internals/shared/grpc/server.go | 2 +- management/server/account.go | 14 +- management/server/activity/codes.go | 8 + .../handlers/accounts/accounts_handler.go | 4 + .../accounts/accounts_handler_test.go | 6 + management/server/peer_test.go | 6 +- management/server/store/sql_store.go | 8 +- management/server/types/settings.go | 4 + shared/management/http/api/openapi.yml | 4 + shared/management/http/api/types.gen.go | 3 + shared/management/proto/management.pb.go | 659 ++++++++++-------- shared/management/proto/management.proto | 6 + 17 files changed, 526 insertions(+), 316 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index 7cd2bab22..eff2c9489 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -314,6 +314,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), true) c.statusRecorder.MarkManagementConnected() + if metricsConfig := loginResp.GetNetbirdConfig().GetMetrics(); metricsConfig != nil { + c.clientMetrics.UpdatePushFromMgm(c.ctx, metricsConfig.GetEnabled()) + } + localPeerState := peer.LocalPeerState{ IP: loginResp.GetPeerConfig().GetAddress(), PubKey: myPrivateKey.PublicKey().String(), @@ -399,6 +403,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan StateManager: stateManager, UpdateManager: c.updateManager, ClientMetrics: c.clientMetrics, + MetricsCtx: c.ctx, }, mobileDependency) engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engine = engine diff --git a/client/internal/engine.go b/client/internal/engine.go index de151592d..fb1d08f5e 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -172,6 +172,7 @@ type EngineServices struct { StateManager *statemanager.Manager UpdateManager *updater.Manager ClientMetrics *metrics.ClientMetrics + MetricsCtx context.Context } // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. @@ -264,6 +265,7 @@ type Engine struct { // clientMetrics collects and pushes metrics clientMetrics *metrics.ClientMetrics + metricsCtx context.Context jobExecutor *jobexec.Executor jobExecutorWG sync.WaitGroup @@ -316,6 +318,7 @@ func NewEngine( probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), jobExecutor: jobexec.NewExecutor(), clientMetrics: services.ClientMetrics, + metricsCtx: services.MetricsCtx, updateManager: services.UpdateManager, syncStoreDir: config.StateDir, } @@ -997,6 +1000,8 @@ func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error { return fmt.Errorf("handle the flow configuration: %w", err) } + e.handleMetricsUpdate(wCfg.GetMetrics()) + if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil { log.Warnf("Failed to update DNS server config: %v", err) } @@ -1066,6 +1071,14 @@ func (e *Engine) handleFlowUpdate(config *mgmProto.FlowConfig) error { return e.flowManager.Update(flowConfig) } +func (e *Engine) handleMetricsUpdate(config *mgmProto.MetricsConfig) { + if config == nil { + return + } + log.Infof("received metrics configuration from management: enabled=%v", config.GetEnabled()) + e.clientMetrics.UpdatePushFromMgm(e.metricsCtx, config.GetEnabled()) +} + func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) { if config.GetInterval() == nil { return nil, errors.New("flow interval is nil") diff --git a/client/internal/metrics/env.go b/client/internal/metrics/env.go index 1f06ce484..c19dcc7f1 100644 --- a/client/internal/metrics/env.go +++ b/client/internal/metrics/env.go @@ -60,6 +60,13 @@ func getMetricsInterval() time.Duration { return interval } +// isMetricsPushEnvSet returns true if NB_METRICS_PUSH_ENABLED is explicitly set (to any value). +// When set, the env var takes full precedence over management server configuration. +func isMetricsPushEnvSet() bool { + _, set := os.LookupEnv(EnvMetricsPushEnabled) + return set +} + func isForceSending() bool { force, _ := strconv.ParseBool(os.Getenv(EnvMetricsForceSending)) return force diff --git a/client/internal/metrics/metrics.go b/client/internal/metrics/metrics.go index f18082995..cfe477107 100644 --- a/client/internal/metrics/metrics.go +++ b/client/internal/metrics/metrics.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "sync" + "sync/atomic" "time" log "github.com/sirupsen/logrus" @@ -75,7 +76,7 @@ type ClientMetrics struct { agentInfo AgentInfo mu sync.RWMutex - push *Push + push atomic.Pointer[Push] pushMu sync.Mutex wg sync.WaitGroup pushCancel context.CancelFunc @@ -167,10 +168,7 @@ func (c *ClientMetrics) UpdateAgentInfo(agentInfo AgentInfo, publicKey string) { c.agentInfo = agentInfo c.mu.Unlock() - c.pushMu.Lock() - push := c.push - c.pushMu.Unlock() - if push != nil { + if push := c.push.Load(); push != nil { push.SetPeerID(agentInfo.peerID) } } @@ -184,7 +182,7 @@ func (c *ClientMetrics) Export(w io.Writer) error { return c.impl.Export(w) } -// StartPush starts periodic pushing of metrics with the given configuration +// StartPush starts periodic pushing of metrics with the given configuration. // Precedence: PushConfig.ServerAddress > remote config server_url func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) { if c == nil { @@ -194,11 +192,58 @@ func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) { c.pushMu.Lock() defer c.pushMu.Unlock() - if c.push != nil { + if c.push.Load() != nil { log.Warnf("metrics push already running") return } + c.startPushLocked(ctx, config) +} + +// StopPush stops the periodic metrics push. +func (c *ClientMetrics) StopPush() { + if c == nil { + return + } + c.pushMu.Lock() + defer c.pushMu.Unlock() + + c.stopPushLocked() +} + +// UpdatePushFromMgm updates metrics push based on management server configuration. +// If NB_METRICS_PUSH_ENABLED is explicitly set (true or false), management config is ignored. +// When unset, management controls whether push is enabled. +func (c *ClientMetrics) UpdatePushFromMgm(ctx context.Context, enabled bool) { + if c == nil { + return + } + + if isMetricsPushEnvSet() { + log.Debugf("ignoring management config, env var is explicitly set: %s", EnvMetricsPushEnabled) + return + } + + c.pushMu.Lock() + defer c.pushMu.Unlock() + + if enabled { + if c.push.Load() != nil { + return + } + log.Infof("enabled metrics push by management") + c.startPushLocked(ctx, PushConfigFromEnv()) + } else { + if c.push.Load() == nil { + return + } + log.Infof("disabled metrics push by management") + c.stopPushLocked() + } +} + +// startPushLocked starts push. Caller must hold pushMu. +func (c *ClientMetrics) startPushLocked(ctx context.Context, config PushConfig) { c.mu.RLock() agentVersion := c.agentInfo.Version peerID := c.agentInfo.peerID @@ -214,26 +259,23 @@ func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) { ctx, cancel := context.WithCancel(ctx) c.pushCancel = cancel + c.push.Store(push) c.wg.Add(1) go func() { defer c.wg.Done() push.Start(ctx) + c.push.CompareAndSwap(push, nil) }() - c.push = push } -func (c *ClientMetrics) StopPush() { - if c == nil { - return - } - c.pushMu.Lock() - defer c.pushMu.Unlock() - if c.push == nil { +// stopPushLocked stops push. Caller must hold pushMu. +func (c *ClientMetrics) stopPushLocked() { + if c.push.Load() == nil { return } c.pushCancel() c.wg.Wait() - c.push = nil + c.push.Store(nil) } diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index ced982a30..973749eb0 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -47,9 +47,16 @@ func init() { precomputedDeprecatedRemotePeersConstraint = constraint } -func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { +func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings, settings *types.Settings) *proto.NetbirdConfig { if config == nil { - return nil + if settings == nil { + return nil + } + return &proto.NetbirdConfig{ + Metrics: &proto.MetricsConfig{ + Enabled: settings.MetricsPushEnabled, + }, + } } var stuns []*proto.HostConfig @@ -110,6 +117,12 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken Relay: relayCfg, } + if settings != nil { + nbConfig.Metrics = &proto.MetricsConfig{ + Enabled: settings.MetricsPushEnabled, + } + } + return nbConfig } @@ -166,7 +179,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb Checks: toProtocolChecks(ctx, checks), } - nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings) + nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings, settings) extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings) response.NetbirdConfig = extendedConfig diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 476aaa9d6..fa06687d0 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -917,7 +917,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ - NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), + NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil, settings), PeerConfig: toPeerConfig(peer, network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, enableSSH), Checks: toProtocolChecks(ctx, postureChecks), } diff --git a/management/server/account.go b/management/server/account.go index 2c57c4637..94335cf27 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -358,7 +358,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion || oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways || oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled || - oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { + oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration || + oldSettings.MetricsPushEnabled != newSettings.MetricsPushEnabled { // Session deadline is derived from LastLogin + PeerLoginExpiration // on every Login/Sync response. Without a fan-out push, connected // peers keep the deadline they received at login time and only see @@ -409,6 +410,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleAutoUpdateAlwaysSettings(ctx, oldSettings, newSettings, userID, accountID) am.handlePeerExposeSettings(ctx, oldSettings, newSettings, userID, accountID) + am.handleMetricsPushSettings(ctx, oldSettings, newSettings, userID, accountID) if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil { return nil, err } @@ -563,6 +565,16 @@ func (am *DefaultAccountManager) handleLazyConnectionSettings(ctx context.Contex } } +func (am *DefaultAccountManager) handleMetricsPushSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { + if oldSettings.MetricsPushEnabled != newSettings.MetricsPushEnabled { + if newSettings.MetricsPushEnabled { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountMetricsPushEnabled, nil) + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountMetricsPushDisabled, nil) + } + } +} + func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled { event := activity.AccountPeerLoginExpirationEnabled diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index cfd809871..eaa638fac 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -276,6 +276,11 @@ const ( // AgentNetworkSettingsUpdated indicates that a user updated Agent Network account settings AgentNetworkSettingsUpdated Activity = 139 + // AccountMetricsPushEnabled indicates that a user enabled metrics push for the account + AccountMetricsPushEnabled Activity = 140 + // AccountMetricsPushDisabled indicates that a user disabled metrics push for the account + AccountMetricsPushDisabled Activity = 141 + AccountDeleted Activity = 99999 ) @@ -449,6 +454,9 @@ var activityMap = map[Activity]Code{ AgentNetworkSettingsUpdated: {"Agent Network settings updated", "agent_network.settings.update"}, + AccountMetricsPushEnabled: {"Account metrics push enabled", "account.setting.metrics.push.enable"}, + AccountMetricsPushDisabled: {"Account metrics push disabled", "account.setting.metrics.push.disable"}, + DomainAdded: {"Domain added", "domain.add"}, DomainDeleted: {"Domain deleted", "domain.delete"}, DomainValidated: {"Domain validated", "domain.validate"}, diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index 209d593bd..d4342bf57 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -283,6 +283,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS if req.Settings.Ipv6EnabledGroups != nil { returnSettings.IPv6EnabledGroups = *req.Settings.Ipv6EnabledGroups } + if req.Settings.MetricsPushEnabled != nil { + returnSettings.MetricsPushEnabled = *req.Settings.MetricsPushEnabled + } return returnSettings, nil } @@ -413,6 +416,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A AutoUpdateVersion: &settings.AutoUpdateVersion, AutoUpdateAlways: &settings.AutoUpdateAlways, Ipv6EnabledGroups: &settings.IPv6EnabledGroups, + MetricsPushEnabled: &settings.MetricsPushEnabled, EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled, LocalAuthDisabled: &settings.LocalAuthDisabled, LocalMfaEnabled: &settings.LocalMfaEnabled, diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 8db76719c..df89fde9a 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -129,6 +129,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), + MetricsPushEnabled: br(false), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), LocalMfaEnabled: br(false), @@ -156,6 +157,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), + MetricsPushEnabled: br(false), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), LocalMfaEnabled: br(false), @@ -183,6 +185,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateAlways: br(false), AutoUpdateVersion: sr("latest"), + MetricsPushEnabled: br(false), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), LocalMfaEnabled: br(false), @@ -210,6 +213,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), + MetricsPushEnabled: br(false), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), LocalMfaEnabled: br(false), @@ -237,6 +241,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), + MetricsPushEnabled: br(false), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), LocalMfaEnabled: br(false), @@ -264,6 +269,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), + MetricsPushEnabled: br(false), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), LocalMfaEnabled: br(false), diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 6f139e43f..6c243c4c7 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1048,7 +1048,11 @@ func testUpdateAccountPeers(t *testing.T) { for _, channel := range peerChannels { update := <-channel - assert.Nil(t, update.Update.NetbirdConfig) + assert.NotNil(t, update.Update.NetbirdConfig) + assert.Nil(t, update.Update.NetbirdConfig.Stuns) + assert.Nil(t, update.Update.NetbirdConfig.Turns) + assert.Nil(t, update.Update.NetbirdConfig.Signal) + assert.Nil(t, update.Update.NetbirdConfig.Relay) assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers)) assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules)) } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 18be1b6ed..69efe65f1 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -1605,7 +1605,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, settings_network_range_v6, settings_ipv6_enabled_groups, settings_lazy_connection_enabled, - settings_local_mfa_enabled, + settings_local_mfa_enabled, settings_metrics_push_enabled, -- Embedded ExtraSettings settings_extra_peer_approval_enabled, settings_extra_user_approval_required, settings_extra_integrated_validator, settings_extra_integrated_validator_groups @@ -1628,6 +1628,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc sIPv6EnabledGroups sql.NullString sLazyConnectionEnabled sql.NullBool sLocalMFAEnabled sql.NullBool + sMetricsPushEnabled sql.NullBool sExtraPeerApprovalEnabled sql.NullBool sExtraUserApprovalRequired sql.NullBool sExtraIntegratedValidator sql.NullString @@ -1650,7 +1651,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, &sNetworkRangeV6, &sIPv6EnabledGroups, &sLazyConnectionEnabled, - &sLocalMFAEnabled, + &sLocalMFAEnabled, &sMetricsPushEnabled, &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, ) @@ -1716,6 +1717,9 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc if sLocalMFAEnabled.Valid { account.Settings.LocalMfaEnabled = sLocalMFAEnabled.Bool } + if sMetricsPushEnabled.Valid { + account.Settings.MetricsPushEnabled = sMetricsPushEnabled.Bool + } if sJWTAllowGroups.Valid { _ = json.Unmarshal([]byte(sJWTAllowGroups.String), &account.Settings.JWTAllowGroups) } diff --git a/management/server/types/settings.go b/management/server/types/settings.go index 97ffa5e76..d17d0ef2b 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -73,6 +73,9 @@ type Settings struct { // For new accounts this defaults to the All group. IPv6EnabledGroups []string `gorm:"serializer:json"` + // MetricsPushEnabled globally enables or disables client metrics push for the account + MetricsPushEnabled bool `gorm:"default:false"` + // EmbeddedIdpEnabled indicates if the embedded identity provider is enabled. // This is a runtime-only field, not stored in the database. EmbeddedIdpEnabled bool `gorm:"-"` @@ -110,6 +113,7 @@ func (s *Settings) Copy() *Settings { AutoUpdateVersion: s.AutoUpdateVersion, AutoUpdateAlways: s.AutoUpdateAlways, IPv6EnabledGroups: slices.Clone(s.IPv6EnabledGroups), + MetricsPushEnabled: s.MetricsPushEnabled, EmbeddedIdpEnabled: s.EmbeddedIdpEnabled, LocalAuthDisabled: s.LocalAuthDisabled, LocalMfaEnabled: s.LocalMfaEnabled, diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index dffb7d7de..f11eb2c0a 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -371,6 +371,10 @@ components: description: When true, updates are installed automatically in the background. When false, updates require user interaction from the UI. type: boolean example: false + metrics_push_enabled: + description: Enables or disables client metrics push for all peers in the account + type: boolean + example: false embedded_idp_enabled: description: Indicates whether the embedded identity provider (Dex) is enabled for this account. This is a read-only field. type: boolean diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 7ea9514c0..2a766b845 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1684,6 +1684,9 @@ type AccountSettings struct { // LocalMfaEnabled Enables or disables TOTP multi-factor authentication for local users. Only applicable when the embedded identity provider is enabled. LocalMfaEnabled *bool `json:"local_mfa_enabled,omitempty"` + // MetricsPushEnabled Enables or disables client metrics push for all peers in the account + MetricsPushEnabled *bool `json:"metrics_push_enabled,omitempty"` + // NetworkRange Allows to define a custom network range for the account in CIDR format NetworkRange *string `json:"network_range,omitempty"` diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 8027c0db6..faf21e60f 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -424,7 +424,7 @@ func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { // Deprecated: Use DeviceAuthorizationFlowProvider.Descriptor instead. func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{33, 0} + return file_management_proto_rawDescGZIP(), []int{34, 0} } type EncryptedMessage struct { @@ -1907,9 +1907,10 @@ type NetbirdConfig struct { // a list of TURN servers Turns []*ProtectedHostConfig `protobuf:"bytes,2,rep,name=turns,proto3" json:"turns,omitempty"` // a Signal server config - Signal *HostConfig `protobuf:"bytes,3,opt,name=signal,proto3" json:"signal,omitempty"` - Relay *RelayConfig `protobuf:"bytes,4,opt,name=relay,proto3" json:"relay,omitempty"` - Flow *FlowConfig `protobuf:"bytes,5,opt,name=flow,proto3" json:"flow,omitempty"` + Signal *HostConfig `protobuf:"bytes,3,opt,name=signal,proto3" json:"signal,omitempty"` + Relay *RelayConfig `protobuf:"bytes,4,opt,name=relay,proto3" json:"relay,omitempty"` + Flow *FlowConfig `protobuf:"bytes,5,opt,name=flow,proto3" json:"flow,omitempty"` + Metrics *MetricsConfig `protobuf:"bytes,6,opt,name=metrics,proto3" json:"metrics,omitempty"` } func (x *NetbirdConfig) Reset() { @@ -1979,6 +1980,13 @@ func (x *NetbirdConfig) GetFlow() *FlowConfig { return nil } +func (x *NetbirdConfig) GetMetrics() *MetricsConfig { + if x != nil { + return x.Metrics + } + return nil +} + // HostConfig describes connection properties of some server (e.g. STUN, Signal, Management) type HostConfig struct { state protoimpl.MessageState @@ -2205,6 +2213,53 @@ func (x *FlowConfig) GetDnsCollection() bool { return false } +type MetricsConfig struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"` +} + +func (x *MetricsConfig) Reset() { + *x = MetricsConfig{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[23] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MetricsConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MetricsConfig) ProtoMessage() {} + +func (x *MetricsConfig) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[23] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MetricsConfig.ProtoReflect.Descriptor instead. +func (*MetricsConfig) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{23} +} + +func (x *MetricsConfig) GetEnabled() bool { + if x != nil { + return x.Enabled + } + return false +} + // JWTConfig represents JWT authentication configuration for validating tokens. type JWTConfig struct { state protoimpl.MessageState @@ -2224,7 +2279,7 @@ type JWTConfig struct { func (x *JWTConfig) Reset() { *x = JWTConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2237,7 +2292,7 @@ func (x *JWTConfig) String() string { func (*JWTConfig) ProtoMessage() {} func (x *JWTConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2250,7 +2305,7 @@ func (x *JWTConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use JWTConfig.ProtoReflect.Descriptor instead. func (*JWTConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{23} + return file_management_proto_rawDescGZIP(), []int{24} } func (x *JWTConfig) GetIssuer() string { @@ -2303,7 +2358,7 @@ type ProtectedHostConfig struct { func (x *ProtectedHostConfig) Reset() { *x = ProtectedHostConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2316,7 +2371,7 @@ func (x *ProtectedHostConfig) String() string { func (*ProtectedHostConfig) ProtoMessage() {} func (x *ProtectedHostConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[25] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2329,7 +2384,7 @@ func (x *ProtectedHostConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ProtectedHostConfig.ProtoReflect.Descriptor instead. func (*ProtectedHostConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{24} + return file_management_proto_rawDescGZIP(), []int{25} } func (x *ProtectedHostConfig) GetHostConfig() *HostConfig { @@ -2380,7 +2435,7 @@ type PeerConfig struct { func (x *PeerConfig) Reset() { *x = PeerConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2393,7 +2448,7 @@ func (x *PeerConfig) String() string { func (*PeerConfig) ProtoMessage() {} func (x *PeerConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[26] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2406,7 +2461,7 @@ func (x *PeerConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use PeerConfig.ProtoReflect.Descriptor instead. func (*PeerConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{25} + return file_management_proto_rawDescGZIP(), []int{26} } func (x *PeerConfig) GetAddress() string { @@ -2486,7 +2541,7 @@ type AutoUpdateSettings struct { func (x *AutoUpdateSettings) Reset() { *x = AutoUpdateSettings{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2499,7 +2554,7 @@ func (x *AutoUpdateSettings) String() string { func (*AutoUpdateSettings) ProtoMessage() {} func (x *AutoUpdateSettings) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[27] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2512,7 +2567,7 @@ func (x *AutoUpdateSettings) ProtoReflect() protoreflect.Message { // Deprecated: Use AutoUpdateSettings.ProtoReflect.Descriptor instead. func (*AutoUpdateSettings) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{26} + return file_management_proto_rawDescGZIP(), []int{27} } func (x *AutoUpdateSettings) GetVersion() string { @@ -2567,7 +2622,7 @@ type NetworkMap struct { func (x *NetworkMap) Reset() { *x = NetworkMap{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2580,7 +2635,7 @@ func (x *NetworkMap) String() string { func (*NetworkMap) ProtoMessage() {} func (x *NetworkMap) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[28] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2593,7 +2648,7 @@ func (x *NetworkMap) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkMap.ProtoReflect.Descriptor instead. func (*NetworkMap) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{27} + return file_management_proto_rawDescGZIP(), []int{28} } func (x *NetworkMap) GetSerial() uint64 { @@ -2703,7 +2758,7 @@ type SSHAuth struct { func (x *SSHAuth) Reset() { *x = SSHAuth{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2716,7 +2771,7 @@ func (x *SSHAuth) String() string { func (*SSHAuth) ProtoMessage() {} func (x *SSHAuth) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[29] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2729,7 +2784,7 @@ func (x *SSHAuth) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHAuth.ProtoReflect.Descriptor instead. func (*SSHAuth) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{28} + return file_management_proto_rawDescGZIP(), []int{29} } func (x *SSHAuth) GetUserIDClaim() string { @@ -2764,7 +2819,7 @@ type MachineUserIndexes struct { func (x *MachineUserIndexes) Reset() { *x = MachineUserIndexes{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2777,7 +2832,7 @@ func (x *MachineUserIndexes) String() string { func (*MachineUserIndexes) ProtoMessage() {} func (x *MachineUserIndexes) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[30] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2790,7 +2845,7 @@ func (x *MachineUserIndexes) ProtoReflect() protoreflect.Message { // Deprecated: Use MachineUserIndexes.ProtoReflect.Descriptor instead. func (*MachineUserIndexes) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{29} + return file_management_proto_rawDescGZIP(), []int{30} } func (x *MachineUserIndexes) GetIndexes() []uint32 { @@ -2821,7 +2876,7 @@ type RemotePeerConfig struct { func (x *RemotePeerConfig) Reset() { *x = RemotePeerConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2834,7 +2889,7 @@ func (x *RemotePeerConfig) String() string { func (*RemotePeerConfig) ProtoMessage() {} func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[31] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2847,7 +2902,7 @@ func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use RemotePeerConfig.ProtoReflect.Descriptor instead. func (*RemotePeerConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{30} + return file_management_proto_rawDescGZIP(), []int{31} } func (x *RemotePeerConfig) GetWgPubKey() string { @@ -2902,7 +2957,7 @@ type SSHConfig struct { func (x *SSHConfig) Reset() { *x = SSHConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2915,7 +2970,7 @@ func (x *SSHConfig) String() string { func (*SSHConfig) ProtoMessage() {} func (x *SSHConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[32] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2928,7 +2983,7 @@ func (x *SSHConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHConfig.ProtoReflect.Descriptor instead. func (*SSHConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31} + return file_management_proto_rawDescGZIP(), []int{32} } func (x *SSHConfig) GetSshEnabled() bool { @@ -2962,7 +3017,7 @@ type DeviceAuthorizationFlowRequest struct { func (x *DeviceAuthorizationFlowRequest) Reset() { *x = DeviceAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2975,7 +3030,7 @@ func (x *DeviceAuthorizationFlowRequest) String() string { func (*DeviceAuthorizationFlowRequest) ProtoMessage() {} func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[33] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2988,7 +3043,7 @@ func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{32} + return file_management_proto_rawDescGZIP(), []int{33} } // DeviceAuthorizationFlow represents Device Authorization Flow information @@ -3007,7 +3062,7 @@ type DeviceAuthorizationFlow struct { func (x *DeviceAuthorizationFlow) Reset() { *x = DeviceAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3020,7 +3075,7 @@ func (x *DeviceAuthorizationFlow) String() string { func (*DeviceAuthorizationFlow) ProtoMessage() {} func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[34] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3033,7 +3088,7 @@ func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlow.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{33} + return file_management_proto_rawDescGZIP(), []int{34} } func (x *DeviceAuthorizationFlow) GetProvider() DeviceAuthorizationFlowProvider { @@ -3060,7 +3115,7 @@ type PKCEAuthorizationFlowRequest struct { func (x *PKCEAuthorizationFlowRequest) Reset() { *x = PKCEAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3073,7 +3128,7 @@ func (x *PKCEAuthorizationFlowRequest) String() string { func (*PKCEAuthorizationFlowRequest) ProtoMessage() {} func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[35] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3086,7 +3141,7 @@ func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{34} + return file_management_proto_rawDescGZIP(), []int{35} } // PKCEAuthorizationFlow represents Authorization Code Flow information @@ -3103,7 +3158,7 @@ type PKCEAuthorizationFlow struct { func (x *PKCEAuthorizationFlow) Reset() { *x = PKCEAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3116,7 +3171,7 @@ func (x *PKCEAuthorizationFlow) String() string { func (*PKCEAuthorizationFlow) ProtoMessage() {} func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[36] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3129,7 +3184,7 @@ func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlow.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{35} + return file_management_proto_rawDescGZIP(), []int{36} } func (x *PKCEAuthorizationFlow) GetProviderConfig() *ProviderConfig { @@ -3177,7 +3232,7 @@ type ProviderConfig struct { func (x *ProviderConfig) Reset() { *x = ProviderConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3190,7 +3245,7 @@ func (x *ProviderConfig) String() string { func (*ProviderConfig) ProtoMessage() {} func (x *ProviderConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[37] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3203,7 +3258,7 @@ func (x *ProviderConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ProviderConfig.ProtoReflect.Descriptor instead. func (*ProviderConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{36} + return file_management_proto_rawDescGZIP(), []int{37} } func (x *ProviderConfig) GetClientID() string { @@ -3312,7 +3367,7 @@ type Route struct { func (x *Route) Reset() { *x = Route{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3325,7 +3380,7 @@ func (x *Route) String() string { func (*Route) ProtoMessage() {} func (x *Route) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[38] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3338,7 +3393,7 @@ func (x *Route) ProtoReflect() protoreflect.Message { // Deprecated: Use Route.ProtoReflect.Descriptor instead. func (*Route) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{37} + return file_management_proto_rawDescGZIP(), []int{38} } func (x *Route) GetID() string { @@ -3427,7 +3482,7 @@ type DNSConfig struct { func (x *DNSConfig) Reset() { *x = DNSConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3440,7 +3495,7 @@ func (x *DNSConfig) String() string { func (*DNSConfig) ProtoMessage() {} func (x *DNSConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[39] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3453,7 +3508,7 @@ func (x *DNSConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use DNSConfig.ProtoReflect.Descriptor instead. func (*DNSConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{38} + return file_management_proto_rawDescGZIP(), []int{39} } func (x *DNSConfig) GetServiceEnable() bool { @@ -3502,7 +3557,7 @@ type CustomZone struct { func (x *CustomZone) Reset() { *x = CustomZone{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3515,7 +3570,7 @@ func (x *CustomZone) String() string { func (*CustomZone) ProtoMessage() {} func (x *CustomZone) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[40] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3528,7 +3583,7 @@ func (x *CustomZone) ProtoReflect() protoreflect.Message { // Deprecated: Use CustomZone.ProtoReflect.Descriptor instead. func (*CustomZone) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{39} + return file_management_proto_rawDescGZIP(), []int{40} } func (x *CustomZone) GetDomain() string { @@ -3575,7 +3630,7 @@ type SimpleRecord struct { func (x *SimpleRecord) Reset() { *x = SimpleRecord{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[40] + mi := &file_management_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3588,7 +3643,7 @@ func (x *SimpleRecord) String() string { func (*SimpleRecord) ProtoMessage() {} func (x *SimpleRecord) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[40] + mi := &file_management_proto_msgTypes[41] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3601,7 +3656,7 @@ func (x *SimpleRecord) ProtoReflect() protoreflect.Message { // Deprecated: Use SimpleRecord.ProtoReflect.Descriptor instead. func (*SimpleRecord) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{40} + return file_management_proto_rawDescGZIP(), []int{41} } func (x *SimpleRecord) GetName() string { @@ -3654,7 +3709,7 @@ type NameServerGroup struct { func (x *NameServerGroup) Reset() { *x = NameServerGroup{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[41] + mi := &file_management_proto_msgTypes[42] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3667,7 +3722,7 @@ func (x *NameServerGroup) String() string { func (*NameServerGroup) ProtoMessage() {} func (x *NameServerGroup) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[41] + mi := &file_management_proto_msgTypes[42] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3680,7 +3735,7 @@ func (x *NameServerGroup) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServerGroup.ProtoReflect.Descriptor instead. func (*NameServerGroup) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{41} + return file_management_proto_rawDescGZIP(), []int{42} } func (x *NameServerGroup) GetNameServers() []*NameServer { @@ -3725,7 +3780,7 @@ type NameServer struct { func (x *NameServer) Reset() { *x = NameServer{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[42] + mi := &file_management_proto_msgTypes[43] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3738,7 +3793,7 @@ func (x *NameServer) String() string { func (*NameServer) ProtoMessage() {} func (x *NameServer) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[42] + mi := &file_management_proto_msgTypes[43] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3751,7 +3806,7 @@ func (x *NameServer) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServer.ProtoReflect.Descriptor instead. func (*NameServer) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{42} + return file_management_proto_rawDescGZIP(), []int{43} } func (x *NameServer) GetIP() string { @@ -3802,7 +3857,7 @@ type FirewallRule struct { func (x *FirewallRule) Reset() { *x = FirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[43] + mi := &file_management_proto_msgTypes[44] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3815,7 +3870,7 @@ func (x *FirewallRule) String() string { func (*FirewallRule) ProtoMessage() {} func (x *FirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[43] + mi := &file_management_proto_msgTypes[44] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3828,7 +3883,7 @@ func (x *FirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use FirewallRule.ProtoReflect.Descriptor instead. func (*FirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{43} + return file_management_proto_rawDescGZIP(), []int{44} } // Deprecated: Do not use. @@ -3907,7 +3962,7 @@ type NetworkAddress struct { func (x *NetworkAddress) Reset() { *x = NetworkAddress{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[44] + mi := &file_management_proto_msgTypes[45] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3920,7 +3975,7 @@ func (x *NetworkAddress) String() string { func (*NetworkAddress) ProtoMessage() {} func (x *NetworkAddress) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[44] + mi := &file_management_proto_msgTypes[45] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3933,7 +3988,7 @@ func (x *NetworkAddress) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkAddress.ProtoReflect.Descriptor instead. func (*NetworkAddress) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{44} + return file_management_proto_rawDescGZIP(), []int{45} } func (x *NetworkAddress) GetNetIP() string { @@ -3961,7 +4016,7 @@ type Checks struct { func (x *Checks) Reset() { *x = Checks{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[45] + mi := &file_management_proto_msgTypes[46] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3974,7 +4029,7 @@ func (x *Checks) String() string { func (*Checks) ProtoMessage() {} func (x *Checks) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[45] + mi := &file_management_proto_msgTypes[46] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3987,7 +4042,7 @@ func (x *Checks) ProtoReflect() protoreflect.Message { // Deprecated: Use Checks.ProtoReflect.Descriptor instead. func (*Checks) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{45} + return file_management_proto_rawDescGZIP(), []int{46} } func (x *Checks) GetFiles() []string { @@ -4012,7 +4067,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[46] + mi := &file_management_proto_msgTypes[47] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4025,7 +4080,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[46] + mi := &file_management_proto_msgTypes[47] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4038,7 +4093,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. func (*PortInfo) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{46} + return file_management_proto_rawDescGZIP(), []int{47} } func (m *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -4109,7 +4164,7 @@ type RouteFirewallRule struct { func (x *RouteFirewallRule) Reset() { *x = RouteFirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[47] + mi := &file_management_proto_msgTypes[48] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4122,7 +4177,7 @@ func (x *RouteFirewallRule) String() string { func (*RouteFirewallRule) ProtoMessage() {} func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[47] + mi := &file_management_proto_msgTypes[48] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4135,7 +4190,7 @@ func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use RouteFirewallRule.ProtoReflect.Descriptor instead. func (*RouteFirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{47} + return file_management_proto_rawDescGZIP(), []int{48} } func (x *RouteFirewallRule) GetSourceRanges() []string { @@ -4226,7 +4281,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[48] + mi := &file_management_proto_msgTypes[49] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4239,7 +4294,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[48] + mi := &file_management_proto_msgTypes[49] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4252,7 +4307,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{48} + return file_management_proto_rawDescGZIP(), []int{49} } func (x *ForwardingRule) GetProtocol() RuleProtocol { @@ -4301,7 +4356,7 @@ type ExposeServiceRequest struct { func (x *ExposeServiceRequest) Reset() { *x = ExposeServiceRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[49] + mi := &file_management_proto_msgTypes[50] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4314,7 +4369,7 @@ func (x *ExposeServiceRequest) String() string { func (*ExposeServiceRequest) ProtoMessage() {} func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[49] + mi := &file_management_proto_msgTypes[50] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4327,7 +4382,7 @@ func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead. func (*ExposeServiceRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{49} + return file_management_proto_rawDescGZIP(), []int{50} } func (x *ExposeServiceRequest) GetPort() uint32 { @@ -4400,7 +4455,7 @@ type ExposeServiceResponse struct { func (x *ExposeServiceResponse) Reset() { *x = ExposeServiceResponse{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[50] + mi := &file_management_proto_msgTypes[51] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4413,7 +4468,7 @@ func (x *ExposeServiceResponse) String() string { func (*ExposeServiceResponse) ProtoMessage() {} func (x *ExposeServiceResponse) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[50] + mi := &file_management_proto_msgTypes[51] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4426,7 +4481,7 @@ func (x *ExposeServiceResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceResponse.ProtoReflect.Descriptor instead. func (*ExposeServiceResponse) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{50} + return file_management_proto_rawDescGZIP(), []int{51} } func (x *ExposeServiceResponse) GetServiceName() string { @@ -4468,7 +4523,7 @@ type RenewExposeRequest struct { func (x *RenewExposeRequest) Reset() { *x = RenewExposeRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[51] + mi := &file_management_proto_msgTypes[52] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4481,7 +4536,7 @@ func (x *RenewExposeRequest) String() string { func (*RenewExposeRequest) ProtoMessage() {} func (x *RenewExposeRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[51] + mi := &file_management_proto_msgTypes[52] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4494,7 +4549,7 @@ func (x *RenewExposeRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RenewExposeRequest.ProtoReflect.Descriptor instead. func (*RenewExposeRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{51} + return file_management_proto_rawDescGZIP(), []int{52} } func (x *RenewExposeRequest) GetDomain() string { @@ -4513,7 +4568,7 @@ type RenewExposeResponse struct { func (x *RenewExposeResponse) Reset() { *x = RenewExposeResponse{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[52] + mi := &file_management_proto_msgTypes[53] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4526,7 +4581,7 @@ func (x *RenewExposeResponse) String() string { func (*RenewExposeResponse) ProtoMessage() {} func (x *RenewExposeResponse) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[52] + mi := &file_management_proto_msgTypes[53] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4539,7 +4594,7 @@ func (x *RenewExposeResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RenewExposeResponse.ProtoReflect.Descriptor instead. func (*RenewExposeResponse) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{52} + return file_management_proto_rawDescGZIP(), []int{53} } type StopExposeRequest struct { @@ -4553,7 +4608,7 @@ type StopExposeRequest struct { func (x *StopExposeRequest) Reset() { *x = StopExposeRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[53] + mi := &file_management_proto_msgTypes[54] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4566,7 +4621,7 @@ func (x *StopExposeRequest) String() string { func (*StopExposeRequest) ProtoMessage() {} func (x *StopExposeRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[53] + mi := &file_management_proto_msgTypes[54] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4579,7 +4634,7 @@ func (x *StopExposeRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StopExposeRequest.ProtoReflect.Descriptor instead. func (*StopExposeRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{53} + return file_management_proto_rawDescGZIP(), []int{54} } func (x *StopExposeRequest) GetDomain() string { @@ -4598,7 +4653,7 @@ type StopExposeResponse struct { func (x *StopExposeResponse) Reset() { *x = StopExposeResponse{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[54] + mi := &file_management_proto_msgTypes[55] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4611,7 +4666,7 @@ func (x *StopExposeResponse) String() string { func (*StopExposeResponse) ProtoMessage() {} func (x *StopExposeResponse) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[54] + mi := &file_management_proto_msgTypes[55] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4624,7 +4679,7 @@ func (x *StopExposeResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StopExposeResponse.ProtoReflect.Descriptor instead. func (*StopExposeResponse) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{54} + return file_management_proto_rawDescGZIP(), []int{55} } type PortInfo_Range struct { @@ -4639,7 +4694,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[56] + mi := &file_management_proto_msgTypes[57] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4652,7 +4707,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[56] + mi := &file_management_proto_msgTypes[57] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4665,7 +4720,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{46, 0} + return file_management_proto_rawDescGZIP(), []int{47, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -4915,7 +4970,7 @@ var file_management_proto_rawDesc = []byte{ 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, - 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xff, 0x01, 0x0a, 0x0d, 0x4e, 0x65, 0x74, 0x62, + 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xb4, 0x02, 0x0a, 0x0d, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2c, 0x0a, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, @@ -4931,43 +4986,49 @@ var file_management_proto_rawDesc = []byte{ 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x2a, 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x22, 0x98, 0x01, 0x0a, 0x0a, 0x48, 0x6f, - 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, 0x3b, 0x0a, 0x08, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, - 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x02, 0x12, - 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x54, - 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, - 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, - 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, - 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, - 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, - 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, - 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, - 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, - 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x69, - 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, - 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, - 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, - 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, - 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x08, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x22, 0xa3, 0x01, 0x0a, 0x09, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x12, 0x33, 0x0a, 0x07, 0x6d, 0x65, 0x74, + 0x72, 0x69, 0x63, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x22, 0x98, + 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, + 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, + 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, + 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, 0x08, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, + 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, + 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, 0x12, + 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, 0x6c, + 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, 0x0c, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, + 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, + 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, + 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, + 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, + 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, + 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, + 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, + 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a, 0x07, + 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, + 0x72, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, + 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, + 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, + 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, + 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x29, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x72, + 0x69, 0x63, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x64, 0x22, 0xa3, 0x01, 0x0a, 0x09, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x16, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x75, 0x64, @@ -5435,7 +5496,7 @@ func file_management_proto_rawDescGZIP() []byte { } var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 8) -var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 57) +var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 58) var file_management_proto_goTypes = []interface{}{ (JobStatus)(0), // 0: management.JobStatus (PeerCapability)(0), // 1: management.PeerCapability @@ -5468,42 +5529,43 @@ var file_management_proto_goTypes = []interface{}{ (*HostConfig)(nil), // 28: management.HostConfig (*RelayConfig)(nil), // 29: management.RelayConfig (*FlowConfig)(nil), // 30: management.FlowConfig - (*JWTConfig)(nil), // 31: management.JWTConfig - (*ProtectedHostConfig)(nil), // 32: management.ProtectedHostConfig - (*PeerConfig)(nil), // 33: management.PeerConfig - (*AutoUpdateSettings)(nil), // 34: management.AutoUpdateSettings - (*NetworkMap)(nil), // 35: management.NetworkMap - (*SSHAuth)(nil), // 36: management.SSHAuth - (*MachineUserIndexes)(nil), // 37: management.MachineUserIndexes - (*RemotePeerConfig)(nil), // 38: management.RemotePeerConfig - (*SSHConfig)(nil), // 39: management.SSHConfig - (*DeviceAuthorizationFlowRequest)(nil), // 40: management.DeviceAuthorizationFlowRequest - (*DeviceAuthorizationFlow)(nil), // 41: management.DeviceAuthorizationFlow - (*PKCEAuthorizationFlowRequest)(nil), // 42: management.PKCEAuthorizationFlowRequest - (*PKCEAuthorizationFlow)(nil), // 43: management.PKCEAuthorizationFlow - (*ProviderConfig)(nil), // 44: management.ProviderConfig - (*Route)(nil), // 45: management.Route - (*DNSConfig)(nil), // 46: management.DNSConfig - (*CustomZone)(nil), // 47: management.CustomZone - (*SimpleRecord)(nil), // 48: management.SimpleRecord - (*NameServerGroup)(nil), // 49: management.NameServerGroup - (*NameServer)(nil), // 50: management.NameServer - (*FirewallRule)(nil), // 51: management.FirewallRule - (*NetworkAddress)(nil), // 52: management.NetworkAddress - (*Checks)(nil), // 53: management.Checks - (*PortInfo)(nil), // 54: management.PortInfo - (*RouteFirewallRule)(nil), // 55: management.RouteFirewallRule - (*ForwardingRule)(nil), // 56: management.ForwardingRule - (*ExposeServiceRequest)(nil), // 57: management.ExposeServiceRequest - (*ExposeServiceResponse)(nil), // 58: management.ExposeServiceResponse - (*RenewExposeRequest)(nil), // 59: management.RenewExposeRequest - (*RenewExposeResponse)(nil), // 60: management.RenewExposeResponse - (*StopExposeRequest)(nil), // 61: management.StopExposeRequest - (*StopExposeResponse)(nil), // 62: management.StopExposeResponse - nil, // 63: management.SSHAuth.MachineUsersEntry - (*PortInfo_Range)(nil), // 64: management.PortInfo.Range - (*timestamppb.Timestamp)(nil), // 65: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 66: google.protobuf.Duration + (*MetricsConfig)(nil), // 31: management.MetricsConfig + (*JWTConfig)(nil), // 32: management.JWTConfig + (*ProtectedHostConfig)(nil), // 33: management.ProtectedHostConfig + (*PeerConfig)(nil), // 34: management.PeerConfig + (*AutoUpdateSettings)(nil), // 35: management.AutoUpdateSettings + (*NetworkMap)(nil), // 36: management.NetworkMap + (*SSHAuth)(nil), // 37: management.SSHAuth + (*MachineUserIndexes)(nil), // 38: management.MachineUserIndexes + (*RemotePeerConfig)(nil), // 39: management.RemotePeerConfig + (*SSHConfig)(nil), // 40: management.SSHConfig + (*DeviceAuthorizationFlowRequest)(nil), // 41: management.DeviceAuthorizationFlowRequest + (*DeviceAuthorizationFlow)(nil), // 42: management.DeviceAuthorizationFlow + (*PKCEAuthorizationFlowRequest)(nil), // 43: management.PKCEAuthorizationFlowRequest + (*PKCEAuthorizationFlow)(nil), // 44: management.PKCEAuthorizationFlow + (*ProviderConfig)(nil), // 45: management.ProviderConfig + (*Route)(nil), // 46: management.Route + (*DNSConfig)(nil), // 47: management.DNSConfig + (*CustomZone)(nil), // 48: management.CustomZone + (*SimpleRecord)(nil), // 49: management.SimpleRecord + (*NameServerGroup)(nil), // 50: management.NameServerGroup + (*NameServer)(nil), // 51: management.NameServer + (*FirewallRule)(nil), // 52: management.FirewallRule + (*NetworkAddress)(nil), // 53: management.NetworkAddress + (*Checks)(nil), // 54: management.Checks + (*PortInfo)(nil), // 55: management.PortInfo + (*RouteFirewallRule)(nil), // 56: management.RouteFirewallRule + (*ForwardingRule)(nil), // 57: management.ForwardingRule + (*ExposeServiceRequest)(nil), // 58: management.ExposeServiceRequest + (*ExposeServiceResponse)(nil), // 59: management.ExposeServiceResponse + (*RenewExposeRequest)(nil), // 60: management.RenewExposeRequest + (*RenewExposeResponse)(nil), // 61: management.RenewExposeResponse + (*StopExposeRequest)(nil), // 62: management.StopExposeRequest + (*StopExposeResponse)(nil), // 63: management.StopExposeResponse + nil, // 64: management.SSHAuth.MachineUsersEntry + (*PortInfo_Range)(nil), // 65: management.PortInfo.Range + (*timestamppb.Timestamp)(nil), // 66: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 67: google.protobuf.Duration } var file_management_proto_depIdxs = []int32{ 11, // 0: management.JobRequest.bundle:type_name -> management.BundleParameters @@ -5511,99 +5573,100 @@ var file_management_proto_depIdxs = []int32{ 12, // 2: management.JobResponse.bundle:type_name -> management.BundleResult 21, // 3: management.SyncRequest.meta:type_name -> management.PeerSystemMeta 27, // 4: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig - 33, // 5: management.SyncResponse.peerConfig:type_name -> management.PeerConfig - 38, // 6: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig - 35, // 7: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap - 53, // 8: management.SyncResponse.Checks:type_name -> management.Checks - 65, // 9: management.SyncResponse.sessionExpiresAt:type_name -> google.protobuf.Timestamp + 34, // 5: management.SyncResponse.peerConfig:type_name -> management.PeerConfig + 39, // 6: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig + 36, // 7: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap + 54, // 8: management.SyncResponse.Checks:type_name -> management.Checks + 66, // 9: management.SyncResponse.sessionExpiresAt:type_name -> google.protobuf.Timestamp 21, // 10: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta 21, // 11: management.LoginRequest.meta:type_name -> management.PeerSystemMeta 17, // 12: management.LoginRequest.peerKeys:type_name -> management.PeerKeys - 52, // 13: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress + 53, // 13: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress 18, // 14: management.PeerSystemMeta.environment:type_name -> management.Environment 19, // 15: management.PeerSystemMeta.files:type_name -> management.File 20, // 16: management.PeerSystemMeta.flags:type_name -> management.Flags 1, // 17: management.PeerSystemMeta.capabilities:type_name -> management.PeerCapability 27, // 18: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig - 33, // 19: management.LoginResponse.peerConfig:type_name -> management.PeerConfig - 53, // 20: management.LoginResponse.Checks:type_name -> management.Checks - 65, // 21: management.LoginResponse.sessionExpiresAt:type_name -> google.protobuf.Timestamp + 34, // 19: management.LoginResponse.peerConfig:type_name -> management.PeerConfig + 54, // 20: management.LoginResponse.Checks:type_name -> management.Checks + 66, // 21: management.LoginResponse.sessionExpiresAt:type_name -> google.protobuf.Timestamp 21, // 22: management.ExtendAuthSessionRequest.meta:type_name -> management.PeerSystemMeta - 65, // 23: management.ExtendAuthSessionResponse.sessionExpiresAt:type_name -> google.protobuf.Timestamp - 65, // 24: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 66, // 23: management.ExtendAuthSessionResponse.sessionExpiresAt:type_name -> google.protobuf.Timestamp + 66, // 24: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp 28, // 25: management.NetbirdConfig.stuns:type_name -> management.HostConfig - 32, // 26: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig + 33, // 26: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig 28, // 27: management.NetbirdConfig.signal:type_name -> management.HostConfig 29, // 28: management.NetbirdConfig.relay:type_name -> management.RelayConfig 30, // 29: management.NetbirdConfig.flow:type_name -> management.FlowConfig - 6, // 30: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 66, // 31: management.FlowConfig.interval:type_name -> google.protobuf.Duration - 28, // 32: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 39, // 33: management.PeerConfig.sshConfig:type_name -> management.SSHConfig - 34, // 34: management.PeerConfig.autoUpdate:type_name -> management.AutoUpdateSettings - 33, // 35: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 38, // 36: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 45, // 37: management.NetworkMap.Routes:type_name -> management.Route - 46, // 38: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 38, // 39: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 51, // 40: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 55, // 41: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule - 56, // 42: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule - 36, // 43: management.NetworkMap.sshAuth:type_name -> management.SSHAuth - 63, // 44: management.SSHAuth.machine_users:type_name -> management.SSHAuth.MachineUsersEntry - 39, // 45: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 31, // 46: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig - 7, // 47: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 44, // 48: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 44, // 49: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 49, // 50: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 47, // 51: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 48, // 52: management.CustomZone.Records:type_name -> management.SimpleRecord - 50, // 53: management.NameServerGroup.NameServers:type_name -> management.NameServer - 3, // 54: management.FirewallRule.Direction:type_name -> management.RuleDirection - 4, // 55: management.FirewallRule.Action:type_name -> management.RuleAction - 2, // 56: management.FirewallRule.Protocol:type_name -> management.RuleProtocol - 54, // 57: management.FirewallRule.PortInfo:type_name -> management.PortInfo - 64, // 58: management.PortInfo.range:type_name -> management.PortInfo.Range - 4, // 59: management.RouteFirewallRule.action:type_name -> management.RuleAction - 2, // 60: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol - 54, // 61: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo - 2, // 62: management.ForwardingRule.protocol:type_name -> management.RuleProtocol - 54, // 63: management.ForwardingRule.destinationPort:type_name -> management.PortInfo - 54, // 64: management.ForwardingRule.translatedPort:type_name -> management.PortInfo - 5, // 65: management.ExposeServiceRequest.protocol:type_name -> management.ExposeProtocol - 37, // 66: management.SSHAuth.MachineUsersEntry.value:type_name -> management.MachineUserIndexes - 8, // 67: management.ManagementService.Login:input_type -> management.EncryptedMessage - 8, // 68: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 26, // 69: management.ManagementService.GetServerKey:input_type -> management.Empty - 26, // 70: management.ManagementService.isHealthy:input_type -> management.Empty - 8, // 71: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 8, // 72: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 8, // 73: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 8, // 74: management.ManagementService.Logout:input_type -> management.EncryptedMessage - 8, // 75: management.ManagementService.Job:input_type -> management.EncryptedMessage - 8, // 76: management.ManagementService.ExtendAuthSession:input_type -> management.EncryptedMessage - 8, // 77: management.ManagementService.CreateExpose:input_type -> management.EncryptedMessage - 8, // 78: management.ManagementService.RenewExpose:input_type -> management.EncryptedMessage - 8, // 79: management.ManagementService.StopExpose:input_type -> management.EncryptedMessage - 8, // 80: management.ManagementService.Login:output_type -> management.EncryptedMessage - 8, // 81: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 25, // 82: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 26, // 83: management.ManagementService.isHealthy:output_type -> management.Empty - 8, // 84: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 8, // 85: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 26, // 86: management.ManagementService.SyncMeta:output_type -> management.Empty - 26, // 87: management.ManagementService.Logout:output_type -> management.Empty - 8, // 88: management.ManagementService.Job:output_type -> management.EncryptedMessage - 8, // 89: management.ManagementService.ExtendAuthSession:output_type -> management.EncryptedMessage - 8, // 90: management.ManagementService.CreateExpose:output_type -> management.EncryptedMessage - 8, // 91: management.ManagementService.RenewExpose:output_type -> management.EncryptedMessage - 8, // 92: management.ManagementService.StopExpose:output_type -> management.EncryptedMessage - 80, // [80:93] is the sub-list for method output_type - 67, // [67:80] is the sub-list for method input_type - 67, // [67:67] is the sub-list for extension type_name - 67, // [67:67] is the sub-list for extension extendee - 0, // [0:67] is the sub-list for field type_name + 31, // 30: management.NetbirdConfig.metrics:type_name -> management.MetricsConfig + 6, // 31: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol + 67, // 32: management.FlowConfig.interval:type_name -> google.protobuf.Duration + 28, // 33: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig + 40, // 34: management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 35, // 35: management.PeerConfig.autoUpdate:type_name -> management.AutoUpdateSettings + 34, // 36: management.NetworkMap.peerConfig:type_name -> management.PeerConfig + 39, // 37: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 46, // 38: management.NetworkMap.Routes:type_name -> management.Route + 47, // 39: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 39, // 40: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 52, // 41: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 56, // 42: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 57, // 43: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule + 37, // 44: management.NetworkMap.sshAuth:type_name -> management.SSHAuth + 64, // 45: management.SSHAuth.machine_users:type_name -> management.SSHAuth.MachineUsersEntry + 40, // 46: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 32, // 47: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig + 7, // 48: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 45, // 49: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 45, // 50: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 50, // 51: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 48, // 52: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 49, // 53: management.CustomZone.Records:type_name -> management.SimpleRecord + 51, // 54: management.NameServerGroup.NameServers:type_name -> management.NameServer + 3, // 55: management.FirewallRule.Direction:type_name -> management.RuleDirection + 4, // 56: management.FirewallRule.Action:type_name -> management.RuleAction + 2, // 57: management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 55, // 58: management.FirewallRule.PortInfo:type_name -> management.PortInfo + 65, // 59: management.PortInfo.range:type_name -> management.PortInfo.Range + 4, // 60: management.RouteFirewallRule.action:type_name -> management.RuleAction + 2, // 61: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 55, // 62: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 2, // 63: management.ForwardingRule.protocol:type_name -> management.RuleProtocol + 55, // 64: management.ForwardingRule.destinationPort:type_name -> management.PortInfo + 55, // 65: management.ForwardingRule.translatedPort:type_name -> management.PortInfo + 5, // 66: management.ExposeServiceRequest.protocol:type_name -> management.ExposeProtocol + 38, // 67: management.SSHAuth.MachineUsersEntry.value:type_name -> management.MachineUserIndexes + 8, // 68: management.ManagementService.Login:input_type -> management.EncryptedMessage + 8, // 69: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 26, // 70: management.ManagementService.GetServerKey:input_type -> management.Empty + 26, // 71: management.ManagementService.isHealthy:input_type -> management.Empty + 8, // 72: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 8, // 73: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 8, // 74: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 8, // 75: management.ManagementService.Logout:input_type -> management.EncryptedMessage + 8, // 76: management.ManagementService.Job:input_type -> management.EncryptedMessage + 8, // 77: management.ManagementService.ExtendAuthSession:input_type -> management.EncryptedMessage + 8, // 78: management.ManagementService.CreateExpose:input_type -> management.EncryptedMessage + 8, // 79: management.ManagementService.RenewExpose:input_type -> management.EncryptedMessage + 8, // 80: management.ManagementService.StopExpose:input_type -> management.EncryptedMessage + 8, // 81: management.ManagementService.Login:output_type -> management.EncryptedMessage + 8, // 82: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 25, // 83: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 26, // 84: management.ManagementService.isHealthy:output_type -> management.Empty + 8, // 85: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 8, // 86: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 26, // 87: management.ManagementService.SyncMeta:output_type -> management.Empty + 26, // 88: management.ManagementService.Logout:output_type -> management.Empty + 8, // 89: management.ManagementService.Job:output_type -> management.EncryptedMessage + 8, // 90: management.ManagementService.ExtendAuthSession:output_type -> management.EncryptedMessage + 8, // 91: management.ManagementService.CreateExpose:output_type -> management.EncryptedMessage + 8, // 92: management.ManagementService.RenewExpose:output_type -> management.EncryptedMessage + 8, // 93: management.ManagementService.StopExpose:output_type -> management.EncryptedMessage + 81, // [81:94] is the sub-list for method output_type + 68, // [68:81] is the sub-list for method input_type + 68, // [68:68] is the sub-list for extension type_name + 68, // [68:68] is the sub-list for extension extendee + 0, // [0:68] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -5889,7 +5952,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*JWTConfig); i { + switch v := v.(*MetricsConfig); i { case 0: return &v.state case 1: @@ -5901,7 +5964,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProtectedHostConfig); i { + switch v := v.(*JWTConfig); i { case 0: return &v.state case 1: @@ -5913,7 +5976,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PeerConfig); i { + switch v := v.(*ProtectedHostConfig); i { case 0: return &v.state case 1: @@ -5925,7 +5988,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AutoUpdateSettings); i { + switch v := v.(*PeerConfig); i { case 0: return &v.state case 1: @@ -5937,7 +6000,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkMap); i { + switch v := v.(*AutoUpdateSettings); i { case 0: return &v.state case 1: @@ -5949,7 +6012,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SSHAuth); i { + switch v := v.(*NetworkMap); i { case 0: return &v.state case 1: @@ -5961,7 +6024,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*MachineUserIndexes); i { + switch v := v.(*SSHAuth); i { case 0: return &v.state case 1: @@ -5973,7 +6036,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RemotePeerConfig); i { + switch v := v.(*MachineUserIndexes); i { case 0: return &v.state case 1: @@ -5985,7 +6048,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SSHConfig); i { + switch v := v.(*RemotePeerConfig); i { case 0: return &v.state case 1: @@ -5997,7 +6060,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlowRequest); i { + switch v := v.(*SSHConfig); i { case 0: return &v.state case 1: @@ -6009,7 +6072,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlow); i { + switch v := v.(*DeviceAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -6021,7 +6084,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlowRequest); i { + switch v := v.(*DeviceAuthorizationFlow); i { case 0: return &v.state case 1: @@ -6033,7 +6096,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlow); i { + switch v := v.(*PKCEAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -6045,7 +6108,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProviderConfig); i { + switch v := v.(*PKCEAuthorizationFlow); i { case 0: return &v.state case 1: @@ -6057,7 +6120,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*ProviderConfig); i { case 0: return &v.state case 1: @@ -6069,7 +6132,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DNSConfig); i { + switch v := v.(*Route); i { case 0: return &v.state case 1: @@ -6081,7 +6144,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CustomZone); i { + switch v := v.(*DNSConfig); i { case 0: return &v.state case 1: @@ -6093,7 +6156,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[40].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SimpleRecord); i { + switch v := v.(*CustomZone); i { case 0: return &v.state case 1: @@ -6105,7 +6168,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[41].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServerGroup); i { + switch v := v.(*SimpleRecord); i { case 0: return &v.state case 1: @@ -6117,7 +6180,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[42].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServer); i { + switch v := v.(*NameServerGroup); i { case 0: return &v.state case 1: @@ -6129,7 +6192,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[43].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*FirewallRule); i { + switch v := v.(*NameServer); i { case 0: return &v.state case 1: @@ -6141,7 +6204,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[44].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkAddress); i { + switch v := v.(*FirewallRule); i { case 0: return &v.state case 1: @@ -6153,7 +6216,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[45].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Checks); i { + switch v := v.(*NetworkAddress); i { case 0: return &v.state case 1: @@ -6165,7 +6228,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[46].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PortInfo); i { + switch v := v.(*Checks); i { case 0: return &v.state case 1: @@ -6177,7 +6240,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[47].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RouteFirewallRule); i { + switch v := v.(*PortInfo); i { case 0: return &v.state case 1: @@ -6189,7 +6252,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[48].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ForwardingRule); i { + switch v := v.(*RouteFirewallRule); i { case 0: return &v.state case 1: @@ -6201,7 +6264,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[49].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ExposeServiceRequest); i { + switch v := v.(*ForwardingRule); i { case 0: return &v.state case 1: @@ -6213,7 +6276,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[50].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ExposeServiceResponse); i { + switch v := v.(*ExposeServiceRequest); i { case 0: return &v.state case 1: @@ -6225,7 +6288,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[51].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RenewExposeRequest); i { + switch v := v.(*ExposeServiceResponse); i { case 0: return &v.state case 1: @@ -6237,7 +6300,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[52].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RenewExposeResponse); i { + switch v := v.(*RenewExposeRequest); i { case 0: return &v.state case 1: @@ -6249,7 +6312,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[53].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*StopExposeRequest); i { + switch v := v.(*RenewExposeResponse); i { case 0: return &v.state case 1: @@ -6261,6 +6324,18 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[54].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*StopExposeRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[55].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*StopExposeResponse); i { case 0: return &v.state @@ -6272,7 +6347,7 @@ func file_management_proto_init() { return nil } } - file_management_proto_msgTypes[56].Exporter = func(v interface{}, i int) interface{} { + file_management_proto_msgTypes[57].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PortInfo_Range); i { case 0: return &v.state @@ -6291,7 +6366,7 @@ func file_management_proto_init() { file_management_proto_msgTypes[2].OneofWrappers = []interface{}{ (*JobResponse_Bundle)(nil), } - file_management_proto_msgTypes[46].OneofWrappers = []interface{}{ + file_management_proto_msgTypes[47].OneofWrappers = []interface{}{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } @@ -6301,7 +6376,7 @@ func file_management_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, NumEnums: 8, - NumMessages: 57, + NumMessages: 58, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 990a72a63..6b41a78d0 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -312,6 +312,8 @@ message NetbirdConfig { RelayConfig relay = 4; FlowConfig flow = 5; + + MetricsConfig metrics = 6; } // HostConfig describes connection properties of some server (e.g. STUN, Signal, Management) @@ -350,6 +352,10 @@ message FlowConfig { bool dnsCollection = 8; } +message MetricsConfig { + bool enabled = 1; +} + // JWTConfig represents JWT authentication configuration for validating tokens. message JWTConfig { string issuer = 1; From ff04ffb534cfb308e70c2b3eb3c0595db8ea490f Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 1 Jul 2026 14:51:06 +0200 Subject: [PATCH 04/19] [client] Fix pointer comparisons in profile config apply (#6622) apply() compared several *bool/*int ConfigInput fields against the Config fields by pointer identity instead of by value, so any non-nil input always looked "changed" and triggered a spurious log line plus an unconditional config rewrite even when the value was unchanged. --- client/internal/profilemanager/config.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 5a71a981e..8ffcb16f2 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -386,7 +386,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } - if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor { + if input.NetworkMonitor != nil && (config.NetworkMonitor == nil || *input.NetworkMonitor != *config.NetworkMonitor) { log.Infof("switching Network Monitor to %t", *input.NetworkMonitor) config.NetworkMonitor = input.NetworkMonitor updated = true @@ -454,7 +454,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } - if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot { + if input.EnableSSHRoot != nil && (config.EnableSSHRoot == nil || *input.EnableSSHRoot != *config.EnableSSHRoot) { if *input.EnableSSHRoot { log.Infof("enabling SSH root login") } else { @@ -464,7 +464,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } - if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP { + if input.EnableSSHSFTP != nil && (config.EnableSSHSFTP == nil || *input.EnableSSHSFTP != *config.EnableSSHSFTP) { if *input.EnableSSHSFTP { log.Infof("enabling SSH SFTP subsystem") } else { @@ -474,7 +474,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } - if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding { + if input.EnableSSHLocalPortForwarding != nil && (config.EnableSSHLocalPortForwarding == nil || *input.EnableSSHLocalPortForwarding != *config.EnableSSHLocalPortForwarding) { if *input.EnableSSHLocalPortForwarding { log.Infof("enabling SSH local port forwarding") } else { @@ -484,7 +484,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } - if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding { + if input.EnableSSHRemotePortForwarding != nil && (config.EnableSSHRemotePortForwarding == nil || *input.EnableSSHRemotePortForwarding != *config.EnableSSHRemotePortForwarding) { if *input.EnableSSHRemotePortForwarding { log.Infof("enabling SSH remote port forwarding") } else { @@ -494,7 +494,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } - if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth { + if input.DisableSSHAuth != nil && (config.DisableSSHAuth == nil || *input.DisableSSHAuth != *config.DisableSSHAuth) { if *input.DisableSSHAuth { log.Infof("disabling SSH authentication") } else { @@ -504,7 +504,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } - if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL { + if input.SSHJWTCacheTTL != nil && (config.SSHJWTCacheTTL == nil || *input.SSHJWTCacheTTL != *config.SSHJWTCacheTTL) { log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL) config.SSHJWTCacheTTL = input.SSHJWTCacheTTL updated = true @@ -587,7 +587,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } - if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications { + if input.DisableNotifications != nil && (config.DisableNotifications == nil || *input.DisableNotifications != *config.DisableNotifications) { if *input.DisableNotifications { log.Infof("disabling notifications") } else { From 2ab99eefa627ce92d9117cdece532f68a9b13cf4 Mon Sep 17 00:00:00 2001 From: Denis Ivanov <74763652+den-dw@users.noreply.github.com> Date: Wed, 1 Jul 2026 15:53:13 +0300 Subject: [PATCH 05/19] [management] detach JWT group sync write from request cancellation (#6621) The HTTP auth middleware runs syncUserJWTGroups in the request context. The dashboard SPA routinely aborts in-flight requests on re-render or navigation, which cancels the request context mid-write and rolls back the group-sync DB transaction. The error is logged but swallowed, so the synced groups silently never persist (users.auto_groups stays empty) while the failing log line repeats on every request. Detach the sync from the request's cancellation with context.WithoutCancel so the write can commit regardless of the client connection; the store already bounds the transaction with its own timeout. Add a regression test asserting the sync receives a non-cancelled context even when the originating request is cancelled. --- .../server/http/middleware/auth_middleware.go | 6 +- .../http/middleware/auth_middleware_test.go | 60 +++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 34df0de23..ba8c66241 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -152,7 +152,11 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] return err } - err = m.syncUserJWTGroups(ctx, userAuth) + // Detach the group-sync write from the request's cancellation: the dashboard + // SPA aborts in-flight requests on re-render, which would otherwise cancel the + // DB transaction mid-write and silently drop the synced groups. Context values + // (request id, logger) are preserved; the store bounds the tx with its own timeout. + err = m.syncUserJWTGroups(context.WithoutCancel(ctx), userAuth) if err != nil { log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err) } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 24cf8fce5..a34554660 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -241,6 +241,66 @@ func TestAuthMiddleware_Handler(t *testing.T) { } } +// TestAuthMiddleware_SyncUserJWTGroupsDetachedFromRequestCancellation ensures the +// JWT group sync write is not bound to the request context. The dashboard SPA +// routinely aborts in-flight requests on re-render/navigation; if the sync ran in +// the request context, the cancellation would roll back the DB transaction and the +// synced groups would silently never persist. The sync must receive a context that +// is not cancelled even when the originating request is. +func TestAuthMiddleware_SyncUserJWTGroupsDetachedFromRequestCancellation(t *testing.T) { + var ( + syncCalled bool + syncCtxErr error + ) + + mockAuth := &auth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups, + MarkPATUsedFunc: mockMarkPATUsed, + GetPATInfoFunc: mockGetAccountInfoFromPAT, + } + + disabledLimiter := NewAPIRateLimiter(nil) + disabledLimiter.SetEnabled(false) + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) error { + syncCalled = true + syncCtxErr = ctx.Err() + return nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + disabledLimiter, + nil, + func(_ context.Context, _, _, _ string) bool { return false }, + ) + + handlerToTest := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + // Simulate the dashboard aborting the request: it arrives already cancelled. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + req := httptest.NewRequest("GET", "http://testing/test", nil).WithContext(ctx) + req.Header.Set("Authorization", "Bearer "+JWT) + rec := httptest.NewRecorder() + + handlerToTest.ServeHTTP(rec, req) + + if !syncCalled { + t.Fatal("syncUserJWTGroups was not called") + } + if syncCtxErr != nil { + t.Fatalf("syncUserJWTGroups received a cancelled context (%v); the group-sync write must be detached from request cancellation", syncCtxErr) + } +} + func TestAuthMiddleware_RateLimiting(t *testing.T) { mockAuth := &auth.MockManager{ ValidateAndParseTokenFunc: mockValidateAndParseToken, From 7c0d8cbae06ba2d30444af022f6727ae91a36a85 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 1 Jul 2026 17:23:50 +0200 Subject: [PATCH 06/19] [misc] Run agent-network e2e nightly + on manual dispatch (#6629) The suite builds combined/proxy/client from source and drives live provider traffic, so running it per push/PR is too costly. Switch to a nightly schedule (03:00 UTC) plus workflow_dispatch, and drop the now-unneeded fork guard that only mattered for pull_request runs. --- .github/workflows/agent-network-e2e.yml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/agent-network-e2e.yml b/.github/workflows/agent-network-e2e.yml index c041bfbfa..d78e3bbd3 100644 --- a/.github/workflows/agent-network-e2e.yml +++ b/.github/workflows/agent-network-e2e.yml @@ -1,10 +1,10 @@ name: Agent Network E2E on: - push: - branches: - - main - pull_request: + # Nightly at 03:00 UTC, plus on demand from the Actions tab. + schedule: + - cron: "0 3 * * *" + workflow_dispatch: concurrency: group: ${{ github.workflow }}-${{ github.ref }} @@ -13,7 +13,6 @@ concurrency: jobs: e2e: name: Agent Network E2E - if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository runs-on: ubuntu-latest timeout-minutes: 45 steps: From 0aa0f7c76b58d9bec8655b5558190419227dbd43 Mon Sep 17 00:00:00 2001 From: Riccardo Manfrin <3090891+riccardomanfrin@users.noreply.github.com> Date: Wed, 1 Jul 2026 19:10:50 +0200 Subject: [PATCH 07/19] [client] wire client -> mgmt is healthy check to proper gRPC API (#6421) --- shared/management/client/grpc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 6f5172376..781e66a3e 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -536,7 +536,7 @@ func (c *GrpcClient) IsHealthy() bool { ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout) defer cancel() - _, err := c.realClient.GetServerKey(ctx, &proto.Empty{}) + _, err := c.realClient.IsHealthy(ctx, &proto.Empty{}) if err != nil { c.notifyDisconnected(err) log.Warnf("health check returned: %s", err) From eb422a5cd3c50a401873e704c673a353ea59abd8 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 1 Jul 2026 20:43:15 +0200 Subject: [PATCH 08/19] [management,proxy] Add per-provider skip_tls_verification for agent-network (#6630) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [management,proxy] Add per-provider skip_tls_verification for agent-network Let agent-network providers opt into skipping upstream TLS verification for self-hosted / internal gateways behind a private or self-signed cert. - provider: add SkipTLSVerification (persisted via AutoMigrate) with request/response mapping (nil on update preserves, explicit false clears). - openapi: skip_tls_verification on the provider request + response; types regenerated. - synthesizer: carry the flag into the llm_router route config so it reaches the proxy. - proxy: llm_router sets it on the UpstreamRewrite mutation, and the reverse proxy applies roundtrip.WithSkipTLSVerify per selected route when forwarding upstream (the router dials per provider, so a per-target flag alone wouldn't cover it). - tests: synthesizer route config carries the flag, router rewrite propagates it, and the request/response round-trip incl. update semantics. * [e2e] Validate per-provider skip_tls_verification end to end Add a self-signed HTTPS upstream (nginx) to the harness and a test that provisions two providers on that same upstream — one with skip_tls_verification=true, one false — behind one proxy + client. The skip=true provider's chat reaches the upstream (200); the skip=false provider's fails the TLS handshake (5xx). Same upstream, opposite outcome, which proves the flag is honoured per provider (a single target-level flag could not, since all of an account's providers share one synthesised target). * [e2e] WaitProxyPeer: require >=1 connected peer, not exact 1/1 Each proxy container registers a fresh WireGuard key and its peer is not removed on teardown, so proxy peers from earlier tests linger in the account as disconnected. WaitProxyPeer matched the exact string "1/1 Connected", which failed once a second proxy-using test ran in the same package (status "1/2"). Parse the "Peers count: X/Y Connected" line and wait for X>=1 instead: only the live proxy can be connected, and the caller's subsequent chat is the real end-to-end assertion. Fixes the CI failure of TestProviderSkipTLSVerification (runs after TestProvidersMatrix). --- e2e/agentnetwork/skiptls_test.go | 140 ++++++++++++++++++ e2e/harness/client.go | 44 +++++- e2e/harness/upstream.go | 107 +++++++++++++ .../modules/agentnetwork/synthesizer.go | 5 + .../modules/agentnetwork/synthesizer_test.go | 35 +++++ .../modules/agentnetwork/types/provider.go | 25 +++- .../agentnetwork/types/provider_test.go | 44 ++++++ .../middleware/builtin/llm_router/factory.go | 4 + .../builtin/llm_router/middleware.go | 5 +- .../builtin/llm_router/middleware_test.go | 35 +++++ proxy/internal/middleware/types.go | 4 + proxy/internal/proxy/reverseproxy.go | 5 + shared/management/http/api/openapi.yml | 9 ++ shared/management/http/api/types.gen.go | 6 + 14 files changed, 456 insertions(+), 12 deletions(-) create mode 100644 e2e/agentnetwork/skiptls_test.go create mode 100644 e2e/harness/upstream.go create mode 100644 management/internals/modules/agentnetwork/types/provider_test.go diff --git a/e2e/agentnetwork/skiptls_test.go b/e2e/agentnetwork/skiptls_test.go new file mode 100644 index 000000000..077fd6005 --- /dev/null +++ b/e2e/agentnetwork/skiptls_test.go @@ -0,0 +1,140 @@ +//go:build e2e + +package agentnetwork + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/e2e/harness" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// TestProviderSkipTLSVerification proves skip_tls_verification is per-provider: +// two providers share one self-signed upstream, one skipping TLS verification +// and one not. The skip=true provider's chat reaches the upstream and returns +// 200; the skip=false provider's chat fails at the TLS handshake — same +// upstream, opposite outcome. This is the behaviour a target-level flag could +// not give, since all of an account's providers share one synthesised target. +func TestProviderSkipTLSVerification(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + up, err := harness.StartFakeUpstream(ctx, srv) + require.NoError(t, err, "start self-signed upstream") + t.Cleanup(func() { _ = up.Terminate(context.Background()) }) + + grp, err := srv.API().Groups.Create(ctx, api.PostApiGroupsJSONRequestBody{Name: "e2e-skiptls"}) + require.NoError(t, err, "create group") + t.Cleanup(func() { _ = srv.API().Groups.Delete(context.Background(), grp.Id) }) + + ephemeral := false + sk, err := srv.API().SetupKeys.Create(ctx, api.PostApiSetupKeysJSONRequestBody{ + Name: "e2e-skiptls-client", + Type: "reusable", + ExpiresIn: 86400, + UsageLimit: 0, + AutoGroups: []string{grp.Id}, + Ephemeral: &ephemeral, + }) + require.NoError(t, err, "mint setup key") + require.NotEmpty(t, sk.Key, "setup key plaintext") + + const ( + insecureModel = "insecure-model" + secureModel = "secure-model" + ) + + // Two providers on the SAME self-signed upstream, distinguished only by their + // skip_tls_verification and a unique model string so the router picks each + // unambiguously. + newReq := func(name, model string, skip bool) api.AgentNetworkProviderRequest { + key := "sk-dummy-e2e" + return api.AgentNetworkProviderRequest{ + Name: name, + ProviderId: "openai_api", + UpstreamUrl: up.URL, + ApiKey: &key, + Enabled: ptr(true), + SkipTlsVerification: ptr(skip), + Models: &[]api.AgentNetworkProviderModel{ + {Id: model, InputPer1k: 0.001, OutputPer1k: 0.002}, + }, + } + } + + // First create bootstraps the account cluster. + insecureReq := newReq("skip-tls", insecureModel, true) + insecureReq.BootstrapCluster = ptr(harness.AgentNetworkCluster) + insecureProv, err := srv.CreateProvider(ctx, insecureReq) + require.NoError(t, err, "create skip-tls provider") + t.Cleanup(func() { _ = srv.DeleteProvider(context.Background(), insecureProv.Id) }) + require.True(t, insecureProv.SkipTlsVerification, "response must echo skip_tls_verification=true") + + secureProv, err := srv.CreateProvider(ctx, newReq("verify-tls", secureModel, false)) + require.NoError(t, err, "create verify-tls provider") + t.Cleanup(func() { _ = srv.DeleteProvider(context.Background(), secureProv.Id) }) + require.False(t, secureProv.SkipTlsVerification, "response must echo skip_tls_verification=false") + + enabled := true + pol, err := srv.CreatePolicy(ctx, api.AgentNetworkPolicyRequest{ + Name: "e2e-skiptls-allow", + Enabled: &enabled, + SourceGroups: []string{grp.Id}, + DestinationProviderIds: []string{insecureProv.Id, secureProv.Id}, + }) + require.NoError(t, err, "create policy") + t.Cleanup(func() { _ = srv.DeletePolicy(context.Background(), pol.Id) }) + + settings, err := srv.GetSettings(ctx) + require.NoError(t, err, "read settings") + require.NotEmpty(t, settings.Endpoint, "endpoint must be assigned") + + proxyToken, err := srv.CreateProxyTokenCLI(ctx, "e2e-skiptls-proxy") + require.NoError(t, err, "mint proxy token") + px, err := harness.StartProxy(ctx, srv, proxyToken) + require.NoError(t, err, "start proxy") + t.Cleanup(func() { _ = px.Terminate(context.Background()) }) + + cl, err := harness.StartClient(ctx, srv, sk.Key) + require.NoError(t, err, "start client") + t.Cleanup(func() { _ = cl.Terminate(context.Background()) }) + + require.NoError(t, cl.WaitConnected(ctx, 90*time.Second), "client must connect to management") + if err := cl.WaitProxyPeer(ctx, 180*time.Second); err != nil { + t.Fatalf("client did not see the proxy peer: %v\n=== proxy logs ===\n%s", err, px.Logs(context.Background())) + } + proxyIP, err := cl.ResolveProxyIP(ctx, settings.Endpoint) + require.NoError(t, err, "resolve endpoint to proxy IP") + + // Positive: skip=true reaches the self-signed upstream. Retry to absorb + // tunnel/DNS jitter on the first call; success also proves the path works. + var code int + var body string + deadline := time.Now().Add(90 * time.Second) + for time.Now().Before(deadline) { + c, b, cerr := cl.Chat(ctx, settings.Endpoint, proxyIP, harness.WireChat, insecureModel, "Reply with exactly: pong", "e2e-skiptls-insecure") + if cerr == nil { + code, body = c, b + if code == 200 { + break + } + } + time.Sleep(5 * time.Second) + } + require.Equal(t, 200, code, + "skip_tls_verification=true must reach the self-signed upstream; body: %s\n=== upstream logs ===\n%s\n=== proxy logs ===\n%s", + body, up.Logs(context.Background()), px.Logs(context.Background())) + + // Negative: skip=false must fail the TLS handshake to the SAME upstream. The + // path is already proven working, so a non-200 here is the cert rejection. + secureCode, secureBody, cerr := cl.Chat(ctx, settings.Endpoint, proxyIP, harness.WireChat, secureModel, "Reply with exactly: pong", "e2e-skiptls-secure") + require.NoError(t, cerr, "the chat call itself must complete (proxy returns an error status, not a transport error)") + require.NotEqual(t, 200, secureCode, + "skip_tls_verification=false must NOT reach the self-signed upstream; got %d, body: %s", secureCode, secureBody) + require.GreaterOrEqual(t, secureCode, 500, + "a TLS verification failure should surface as a 5xx from the proxy; got %d, body: %s", secureCode, secureBody) +} diff --git a/e2e/harness/client.go b/e2e/harness/client.go index cf7ef8945..1ce8c0f6e 100644 --- a/e2e/harness/client.go +++ b/e2e/harness/client.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "os/exec" + "strconv" "strings" "time" @@ -108,9 +109,48 @@ func (cl *Client) WaitConnected(ctx context.Context, timeout time.Duration) erro return cl.pollStatus(ctx, timeout, "Management: Connected") } -// WaitProxyPeer polls until the client sees the proxy peer connected (1/1). +// WaitProxyPeer polls until the client sees at least one connected peer — the +// proxy serving the agent-network endpoint. It requires ">=1 connected" rather +// than an exact "1/1" because proxy peers from earlier tests linger in the +// account as disconnected (each proxy container registers a fresh WireGuard key +// and the peer is not removed on teardown), so the count is e.g. "1/2". Only the +// live proxy can be connected, and the caller's subsequent chat is the real +// end-to-end assertion. func (cl *Client) WaitProxyPeer(ctx context.Context, timeout time.Duration) error { - return cl.pollStatus(ctx, timeout, "1/1 Connected") + deadline := time.Now().Add(timeout) + var last string + for time.Now().Before(deadline) { + out, _ := cl.Status(ctx) + last = out + if connectedPeers(out) >= 1 { + return nil + } + time.Sleep(3 * time.Second) + } + return fmt.Errorf("timed out waiting for a connected proxy peer; last status:\n%s", last) +} + +// connectedPeers parses the "Peers count: X/Y Connected" line from `netbird +// status` and returns X (the connected count), or 0 when absent/unparseable. +func connectedPeers(status string) int { + for _, line := range strings.Split(status, "\n") { + line = strings.TrimSpace(line) + rest, ok := strings.CutPrefix(line, "Peers count:") + if !ok { + continue + } + rest = strings.TrimSpace(rest) + slash := strings.IndexByte(rest, '/') + if slash <= 0 { + return 0 + } + n, err := strconv.Atoi(strings.TrimSpace(rest[:slash])) + if err != nil { + return 0 + } + return n + } + return 0 } func (cl *Client) pollStatus(ctx context.Context, timeout time.Duration, want string) error { diff --git a/e2e/harness/upstream.go b/e2e/harness/upstream.go new file mode 100644 index 000000000..cdffe63b9 --- /dev/null +++ b/e2e/harness/upstream.go @@ -0,0 +1,107 @@ +//go:build e2e + +package harness + +import ( + "context" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/docker/docker/api/types/container" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + fakeUpstreamImage = "nginx:alpine" + fakeUpstreamAlias = "fakeupstream" + fakeUpstreamPort = "443/tcp" +) + +// fakeUpstreamNginxConf serves a canned OpenAI-shaped chat completion for any +// path over a self-signed certificate, so the proxy reaches it only when the +// provider opts into skipping TLS verification. +const fakeUpstreamNginxConf = `pid /tmp/nginx.pid; +events {} +http { + server { + listen 443 ssl; + ssl_certificate /certs/tls.crt; + ssl_certificate_key /certs/tls.key; + location / { + default_type application/json; + return 200 '{"id":"chatcmpl-e2e","object":"chat.completion","choices":[{"index":0,"message":{"role":"assistant","content":"pong"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}'; + } + } +} +` + +// FakeUpstream is a self-signed HTTPS server on the combined server's network, +// used to exercise provider skip_tls_verification: a proxy that verifies the +// certificate rejects it, one that skips verification reaches it. +type FakeUpstream struct { + container testcontainers.Container + workDir string + // URL is the upstream URL providers point at (https://). + URL string +} + +// StartFakeUpstream runs the self-signed upstream on the shared network. +func StartFakeUpstream(ctx context.Context, c *Combined) (*FakeUpstream, error) { + workDir, err := os.MkdirTemp("/tmp", "nb-e2e-upstream-*") + if err != nil { + return nil, fmt.Errorf("create upstream work dir: %w", err) + } + // Widen so the (non-root worker) nginx container can traverse the bind mount. + if err := os.Chmod(workDir, 0o755); err != nil { //nolint:gosec // throwaway e2e cert dir + return nil, fmt.Errorf("chmod upstream dir: %w", err) + } + if err := writeSelfSignedCert(workDir, []string{fakeUpstreamAlias}); err != nil { + return nil, err + } + if err := os.WriteFile(filepath.Join(workDir, "nginx.conf"), []byte(fakeUpstreamNginxConf), 0o644); err != nil { //nolint:gosec // non-secret e2e config + return nil, fmt.Errorf("write nginx conf: %w", err) + } + + req := testcontainers.ContainerRequest{ + Image: fakeUpstreamImage, + ExposedPorts: []string{fakeUpstreamPort}, + Networks: []string{c.network.Name}, + NetworkAliases: map[string][]string{c.network.Name: {fakeUpstreamAlias}}, + Cmd: []string{"nginx", "-c", "/certs/nginx.conf", "-g", "daemon off;"}, + HostConfigModifier: func(hc *container.HostConfig) { + hc.Binds = append(hc.Binds, workDir+":/certs:ro") + }, + WaitingFor: wait.ForListeningPort(fakeUpstreamPort).WithStartupTimeout(60 * time.Second), + } + + ctr, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + if err != nil { + _ = os.RemoveAll(workDir) + return nil, fmt.Errorf("start fake upstream container: %w", err) + } + + return &FakeUpstream{container: ctr, workDir: workDir, URL: "https://" + fakeUpstreamAlias}, nil +} + +// Logs returns the upstream container logs, for diagnostics on failure. +func (u *FakeUpstream) Logs(ctx context.Context) string { + return containerLogs(ctx, u.container) +} + +// Terminate stops the upstream container and cleans its work dir. +func (u *FakeUpstream) Terminate(ctx context.Context) error { + var err error + if u.container != nil { + err = u.container.Terminate(ctx) + } + if u.workDir != "" { + _ = os.RemoveAll(u.workDir) + } + return err +} diff --git a/management/internals/modules/agentnetwork/synthesizer.go b/management/internals/modules/agentnetwork/synthesizer.go index 9814d1a11..74ac91845 100644 --- a/management/internals/modules/agentnetwork/synthesizer.go +++ b/management/internals/modules/agentnetwork/synthesizer.go @@ -366,6 +366,10 @@ type routerProviderRoute struct { // + refreshes the OAuth token at request time instead of injecting a static // AuthHeaderValue. GCPServiceAccountKeyB64 string `json:"gcp_sa_key_b64,omitempty"` + // SkipTLSVerify disables upstream TLS certificate verification when the + // proxy dials this provider's upstream. For self-hosted / internal gateways + // behind a private or self-signed certificate. + SkipTLSVerify bool `json:"skip_tls_verify,omitempty"` } // indexProviderGroups walks the enabled policies and returns, per @@ -450,6 +454,7 @@ func buildRouterConfigJSON(providers []*types.Provider, groupIndex map[string][] Vertex: catalog.IsVertexPathStyle(p.ProviderID), Bedrock: catalog.IsBedrockPathStyle(p.ProviderID), GCPServiceAccountKeyB64: gcpSAKeyB64, + SkipTLSVerify: p.SkipTLSVerification, }) } out, err := json.Marshal(cfg) diff --git a/management/internals/modules/agentnetwork/synthesizer_test.go b/management/internals/modules/agentnetwork/synthesizer_test.go index 0b07f27b3..9d55bddf1 100644 --- a/management/internals/modules/agentnetwork/synthesizer_test.go +++ b/management/internals/modules/agentnetwork/synthesizer_test.go @@ -1057,6 +1057,41 @@ func TestSynthesizeServices_UpstreamURLPath_FlowsToRouter(t *testing.T) { "upstream path must be carried so the router can disambiguate same-model providers; trailing slash trimmed for stable string-prefix matching") } +func TestSynthesizeServices_SkipTLSVerification_FlowsToRouter(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStore := store.NewMockStore(ctrl) + + // A provider fronting a self-hosted / internal gateway opts into skipping + // upstream TLS verification; the synthesiser must carry it into the router + // route so the proxy dials that upstream insecurely. + provider := newSynthTestProvider() + provider.SkipTLSVerification = true + policy := newSynthTestPolicy(provider.ID, "grp-eng", "") + + expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(), + []*types.Provider{provider}, + []*types.Policy{policy}, + []*types.Guardrail{}) + + services, err := SynthesizeServices(ctx, mockStore, testAccountID) + require.NoError(t, err) + require.Len(t, services, 1) + + mws := services[0].Targets[0].Options.Middlewares + var routerCfg routerConfig + for _, m := range mws { + if m.ID == middlewareIDLLMRouter { + require.NoError(t, json.Unmarshal(m.ConfigJSON, &routerCfg)) + break + } + } + require.Len(t, routerCfg.Providers, 1) + assert.True(t, routerCfg.Providers[0].SkipTLSVerify, + "provider skip_tls_verification must flow into the router route") +} + func TestSynthesizeServices_UnknownProviderID_FailsClosed(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) diff --git a/management/internals/modules/agentnetwork/types/provider.go b/management/internals/modules/agentnetwork/types/provider.go index 28c8a94e2..2e3195481 100644 --- a/management/internals/modules/agentnetwork/types/provider.go +++ b/management/internals/modules/agentnetwork/types/provider.go @@ -46,6 +46,11 @@ type Provider struct { // Empty means all catalog models are allowed at catalog prices. Models []ProviderModel `gorm:"serializer:json"` Enabled bool + // SkipTLSVerification disables upstream TLS certificate verification for + // this provider's URL. For self-hosted / internal gateways fronted by a + // private or self-signed certificate. The synthesiser propagates it into + // the router route so the proxy dials that provider's upstream insecurely. + SkipTLSVerification bool `gorm:"column:skip_tls_verification"` // SessionPrivateKey + SessionPublicKey are the ed25519 keypair the // synthesised reverse-proxy service uses to sign / verify session // JWTs after a successful OIDC handshake. Generated once on @@ -129,6 +134,9 @@ func (p *Provider) FromAPIRequest(req *api.AgentNetworkProviderRequest) { if req.Enabled != nil { p.Enabled = *req.Enabled } + if req.SkipTlsVerification != nil { + p.SkipTLSVerification = *req.SkipTlsVerification + } // Identity-header overrides for catalogs flagged Customizable. // nil pointer = "field omitted on the wire" → leave the stored // value untouched (per the openapi description). Empty string is @@ -155,14 +163,15 @@ func (p *Provider) ToAPIResponse() *api.AgentNetworkProvider { created := p.CreatedAt updated := p.UpdatedAt resp := &api.AgentNetworkProvider{ - Id: p.ID, - ProviderId: p.ProviderID, - Name: p.Name, - UpstreamUrl: p.UpstreamURL, - Models: models, - Enabled: p.Enabled, - CreatedAt: &created, - UpdatedAt: &updated, + Id: p.ID, + ProviderId: p.ProviderID, + Name: p.Name, + UpstreamUrl: p.UpstreamURL, + Models: models, + Enabled: p.Enabled, + SkipTlsVerification: p.SkipTLSVerification, + CreatedAt: &created, + UpdatedAt: &updated, } if len(p.ExtraValues) > 0 { out := make(map[string]string, len(p.ExtraValues)) diff --git a/management/internals/modules/agentnetwork/types/provider_test.go b/management/internals/modules/agentnetwork/types/provider_test.go new file mode 100644 index 000000000..1195499e7 --- /dev/null +++ b/management/internals/modules/agentnetwork/types/provider_test.go @@ -0,0 +1,44 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// TestProvider_SkipTLSVerification_RoundTrip covers the request→provider→ +// response mapping of skip_tls_verification, including the update semantics +// (nil pointer preserves, explicit false clears). +func TestProvider_SkipTLSVerification_RoundTrip(t *testing.T) { + enable := true + disable := false + + base := func() *api.AgentNetworkProviderRequest { + return &api.AgentNetworkProviderRequest{ + ProviderId: "openai_api", + Name: "internal", + UpstreamUrl: "https://gw.internal", + } + } + + p := NewProvider("acc-1") + + req := base() + req.SkipTlsVerification = &enable + p.FromAPIRequest(req) + assert.True(t, p.SkipTLSVerification, "create with skip_tls_verification=true must set the field") + assert.True(t, p.ToAPIResponse().SkipTlsVerification, "response must surface skip_tls_verification") + + // Omitting the field on update leaves the stored value untouched. + p.FromAPIRequest(base()) + assert.True(t, p.SkipTLSVerification, "omitting skip_tls_verification on update must preserve it") + + // Explicit false clears it. + req = base() + req.SkipTlsVerification = &disable + p.FromAPIRequest(req) + assert.False(t, p.SkipTLSVerification, "explicit false must clear skip_tls_verification") + assert.False(t, p.ToAPIResponse().SkipTlsVerification, "response must reflect the cleared value") +} diff --git a/proxy/internal/middleware/builtin/llm_router/factory.go b/proxy/internal/middleware/builtin/llm_router/factory.go index 3c3b607ac..938a23ebe 100644 --- a/proxy/internal/middleware/builtin/llm_router/factory.go +++ b/proxy/internal/middleware/builtin/llm_router/factory.go @@ -59,6 +59,10 @@ type ProviderRoute struct { // (instead of the static AuthHeaderValue) — so the gateway holds a durable // Vertex credential rather than a 1-hour token. GCPServiceAccountKeyB64 string `json:"gcp_sa_key_b64,omitempty"` + // SkipTLSVerify disables upstream TLS certificate verification when dialing + // this route's upstream. For self-hosted / internal gateways behind a + // private or self-signed certificate. + SkipTLSVerify bool `json:"skip_tls_verify,omitempty"` } // Config is the on-wire configuration accepted by the factory. An diff --git a/proxy/internal/middleware/builtin/llm_router/middleware.go b/proxy/internal/middleware/builtin/llm_router/middleware.go index 73cc59c95..2aaeb1089 100644 --- a/proxy/internal/middleware/builtin/llm_router/middleware.go +++ b/proxy/internal/middleware/builtin/llm_router/middleware.go @@ -615,8 +615,9 @@ func (m *Middleware) allowWithRoute(route ProviderRoute, userGroups []string) *m // path is silently dropped and the gateway returns a 4xx for // the malformed URL. Empty value leaves the original // target's path untouched. - Path: route.UpstreamPath, - StripHeaders: append([]string(nil), strippedAuthHeaders...), + Path: route.UpstreamPath, + StripHeaders: append([]string(nil), strippedAuthHeaders...), + SkipTLSVerify: route.SkipTLSVerify, } authValue := route.AuthHeaderValue if route.GCPServiceAccountKeyB64 != "" { diff --git a/proxy/internal/middleware/builtin/llm_router/middleware_test.go b/proxy/internal/middleware/builtin/llm_router/middleware_test.go index 8ae03c5ba..425c383c1 100644 --- a/proxy/internal/middleware/builtin/llm_router/middleware_test.go +++ b/proxy/internal/middleware/builtin/llm_router/middleware_test.go @@ -107,6 +107,41 @@ func TestRouter_HappyPath(t *testing.T) { assert.Equal(t, "allow", dec, "decision metadata must be allow on a match") } +func TestRouter_SkipTLSVerifyPropagates(t *testing.T) { + base := ProviderRoute{ + ID: "internal-gw", + Models: []string{"gpt-4o"}, + AllowedGroupIDs: []string{defaultTestGroup}, + UpstreamScheme: "https", + UpstreamHost: "gateway.internal", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer sk-test-123", + } + + t.Run("enabled", func(t *testing.T) { + route := base + route.SkipTLSVerify = true + mw := New(Config{Providers: []ProviderRoute{route}}) + + out, err := mw.Invoke(context.Background(), newInputWithModel("gpt-4o")) + require.NoError(t, err) + require.NotNil(t, out.Mutations, "matched route must emit mutations") + require.NotNil(t, out.Mutations.RewriteUpstream, "matched route must emit upstream rewrite") + assert.True(t, out.Mutations.RewriteUpstream.SkipTLSVerify, + "skip_tls_verify on the route must ride on the upstream rewrite") + }) + + t.Run("default off", func(t *testing.T) { + mw := New(Config{Providers: []ProviderRoute{base}}) + + out, err := mw.Invoke(context.Background(), newInputWithModel("gpt-4o")) + require.NoError(t, err) + require.NotNil(t, out.Mutations.RewriteUpstream, "matched route must emit upstream rewrite") + assert.False(t, out.Mutations.RewriteUpstream.SkipTLSVerify, + "skip_tls_verify must default to false when the route does not set it") + }) +} + func TestRouter_MissingModel(t *testing.T) { mw := New(Config{Providers: []ProviderRoute{{ ID: "openai-prod", diff --git a/proxy/internal/middleware/types.go b/proxy/internal/middleware/types.go index 1b49e6159..1ed5c9d88 100644 --- a/proxy/internal/middleware/types.go +++ b/proxy/internal/middleware/types.go @@ -243,6 +243,10 @@ type UpstreamRewrite struct { StripPathPrefix string AuthHeader *AuthHeader StripHeaders []string + // SkipTLSVerify, when true, makes the proxy dial the rewritten upstream + // without verifying its TLS certificate. Set by llm_router from the + // provider's skip_tls_verification for self-hosted / internal gateways. + SkipTLSVerify bool } // AuthHeader is a single name/value pair the proxy injects on the diff --git a/proxy/internal/proxy/reverseproxy.go b/proxy/internal/proxy/reverseproxy.go index 2c0304ecd..835a1c0b2 100644 --- a/proxy/internal/proxy/reverseproxy.go +++ b/proxy/internal/proxy/reverseproxy.go @@ -346,6 +346,11 @@ func (p *ReverseProxy) forwardUpstream(respWriter http.ResponseWriter, r *http.R r.Host = effectiveURL.Host applyUpstreamHeaders(r, upstreamRewrite) stripUpstreamPathPrefix(r, upstreamRewrite.StripPathPrefix) + // A router-selected route (e.g. agent-network provider) can opt into + // skipping upstream TLS verification per its provider config. + if upstreamRewrite.SkipTLSVerify { + ctx = roundtrip.WithSkipTLSVerify(ctx) + } } rp := &httputil.ReverseProxy{ diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index f11eb2c0a..f746b31f4 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -5119,6 +5119,10 @@ components: type: boolean description: Whether the provider is enabled. example: true + skip_tls_verification: + type: boolean + description: Whether upstream TLS certificate verification is skipped when the proxy dials this provider's URL. Intended for self-hosted / internal gateways behind a private or self-signed certificate. + example: false created_at: type: string format: date-time @@ -5138,6 +5142,7 @@ components: - upstream_url - models - enabled + - skip_tls_verification - created_at - updated_at AgentNetworkProviderRequest: @@ -5190,6 +5195,10 @@ components: type: boolean description: Whether the provider is enabled. Defaults to true on create. example: true + skip_tls_verification: + type: boolean + description: Skip upstream TLS certificate verification when the proxy dials this provider's URL. For self-hosted / internal gateways behind a private or self-signed certificate. Defaults to false. When omitted on update, the stored value is left unchanged. + example: false required: - provider_id - name diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 2a766b845..3b587c4bf 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -2224,6 +2224,9 @@ type AgentNetworkProvider struct { // ProviderId Catalog identifier for the upstream AI provider (e.g. openai_api, anthropic_api, azure_openai_api, bedrock_api, vertex_ai_api, mistral_api, custom). ProviderId string `json:"provider_id"` + // SkipTlsVerification Whether upstream TLS certificate verification is skipped when the proxy dials this provider's URL. Intended for self-hosted / internal gateways behind a private or self-signed certificate. + SkipTlsVerification bool `json:"skip_tls_verification"` + // UpdatedAt Timestamp when the provider was last updated. UpdatedAt *time.Time `json:"updated_at,omitempty"` @@ -2272,6 +2275,9 @@ type AgentNetworkProviderRequest struct { // ProviderId Catalog identifier for the upstream AI provider (e.g. openai_api, anthropic_api, azure_openai_api, bedrock_api, vertex_ai_api, mistral_api, custom). ProviderId string `json:"provider_id"` + // SkipTlsVerification Skip upstream TLS certificate verification when the proxy dials this provider's URL. For self-hosted / internal gateways behind a private or self-signed certificate. Defaults to false. When omitted on update, the stored value is left unchanged. + SkipTlsVerification *bool `json:"skip_tls_verification,omitempty"` + // UpstreamUrl Full upstream URL (with scheme) that NetBird forwards traffic to. UpstreamUrl string `json:"upstream_url"` } From 06839a4731a64fea2ef7fb0d4857c356ca0eb60f Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 1 Jul 2026 22:08:23 +0200 Subject: [PATCH 09/19] [client] Fix race between WG watcher initial handshake read and endpoint creation (#6626) * [client] Fix race between WG watcher initial handshake read and endpoint config The watcher's initial handshake read ran in a separate goroutine with no ordering guarantee relative to the WireGuard endpoint configuration, so it would sometimes race with the peer being added to the interface. Split enabling into a synchronous PrepareInitialHandshake, called before the endpoint is configured, and an EnableWgWatcher that only runs the monitoring loop, making the baseline read deterministic and keeping it correct for reconnects where the peer's WireGuard entry survives. * [client] Skip WG watcher disconnect callback when context is cancelled A superseded or cancelled watcher whose handshake-check timer fires before it observes ctx.Done() would still invoke onDisconnectedFn, tearing down a now-healthy connection. Re-check ctx before firing the disconnect and handshake-success callbacks and stand down silently if it was cancelled. --- client/internal/peer/conn.go | 18 ++++++----- client/internal/peer/wg_watcher.go | 41 ++++++++++++++----------- client/internal/peer/wg_watcher_test.go | 10 ++++++ 3 files changed, 43 insertions(+), 26 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 85e54ba5f..fb468696f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -803,15 +803,17 @@ func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) { } func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) { - if !conn.wgWatcher.IsEnabled() { - wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx) - conn.wgWatcherCancel = wgWatcherCancel - conn.wgWatcherWg.Add(1) - go func() { - defer conn.wgWatcherWg.Done() - conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess) - }() + if !conn.wgWatcher.PrepareInitialHandshake() { + return } + + wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx) + conn.wgWatcherCancel = wgWatcherCancel + conn.wgWatcherWg.Add(1) + go func() { + defer conn.wgWatcherWg.Done() + conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess) + }() } func (conn *Conn) disableWgWatcherIfNeeded() { diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index 805a6f24a..4fc883d17 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -31,7 +31,9 @@ type WGWatcher struct { stateDump *stateDump enabled bool - muEnabled sync.RWMutex + muEnabled sync.Mutex + // initialHandshake is not thread-safe; never call PrepareInitialHandshake and EnableWgWatcher concurrently. + initialHandshake time.Time resetCh chan struct{} } @@ -46,38 +48,38 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin } } -// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing. -// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management. -func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) { +// PrepareInitialHandshake reserves the watcher and reads the peer's current WireGuard +// handshake time. It must be called before the peer is (re)configured on the WireGuard +// interface, so the captured baseline reflects the state prior to this connection attempt +// instead of racing with that configuration. Returns ok=false if the watcher is already +// running, in which case EnableWgWatcher must not be called. +func (w *WGWatcher) PrepareInitialHandshake() (ok bool) { w.muEnabled.Lock() if w.enabled { w.muEnabled.Unlock() - return + return false } w.log.Debugf("enable WireGuard watcher") w.enabled = true w.muEnabled.Unlock() - initialHandshake, err := w.wgState() - if err != nil { - w.log.Warnf("failed to read initial wg stats: %v", err) - } + handshake, _ := w.wgState() + w.initialHandshake = handshake + return true +} - w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, initialHandshake) +// EnableWgWatcher runs the WireGuard watcher loop using the handshake baseline captured by +// PrepareInitialHandshake. The watcher runs until ctx is cancelled. Caller is responsible +// for context lifecycle management. +func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) { + w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, w.initialHandshake) w.muEnabled.Lock() w.enabled = false w.muEnabled.Unlock() } -// IsEnabled returns true if the WireGuard watcher is currently enabled -func (w *WGWatcher) IsEnabled() bool { - w.muEnabled.RLock() - defer w.muEnabled.RUnlock() - return w.enabled -} - // Reset signals the watcher that the WireGuard peer has been reset and a new // handshake is expected. This restarts the handshake timeout from scratch. func (w *WGWatcher) Reset() { @@ -101,13 +103,16 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn case <-timer.C: handshake, ok := w.handshakeCheck(lastHandshake) if !ok { + if ctx.Err() != nil { + return + } onDisconnectedFn() return } if lastHandshake.IsZero() { elapsed := calcElapsed(enabledTime, *handshake) w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake) - if onHandshakeSuccessFn != nil { + if onHandshakeSuccessFn != nil && ctx.Err() == nil { onHandshakeSuccessFn(*handshake) } } diff --git a/client/internal/peer/wg_watcher_test.go b/client/internal/peer/wg_watcher_test.go index 3ce91cd46..634d7974f 100644 --- a/client/internal/peer/wg_watcher_test.go +++ b/client/internal/peer/wg_watcher_test.go @@ -7,6 +7,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/iface/configurer" ) @@ -34,6 +35,9 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + ok := watcher.PrepareInitialHandshake() + require.True(t, ok, "watcher should not be enabled yet") + onDisconnected := make(chan struct{}, 1) go watcher.EnableWgWatcher(ctx, time.Now(), func() { mlog.Infof("onDisconnectedFn") @@ -62,6 +66,9 @@ func TestWGWatcher_ReEnable(t *testing.T) { watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{})) ctx, cancel := context.WithCancel(context.Background()) + ok := watcher.PrepareInitialHandshake() + require.True(t, ok, "watcher should not be enabled yet") + wg := &sync.WaitGroup{} wg.Add(1) go func() { @@ -76,6 +83,9 @@ func TestWGWatcher_ReEnable(t *testing.T) { ctx, cancel = context.WithCancel(context.Background()) defer cancel() + ok = watcher.PrepareInitialHandshake() + require.True(t, ok, "watcher should be re-enabled after the previous run stopped") + onDisconnected := make(chan struct{}, 1) go watcher.EnableWgWatcher(ctx, time.Now(), func() { onDisconnected <- struct{}{} From 7d4736de5579fecbef172e203ec7f6a424ac2885 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 1 Jul 2026 22:08:43 +0200 Subject: [PATCH 10/19] [management] Enable lazy connections by default on new accounts (#6571) With improvements in userspace lazy connection handling, we should be able to enable it for new accounts with less impact on users. These connections are cheaper and only target traffic that should go through the tunnels, leaving all other tunnels in an idle state. --- management/server/account.go | 1 + 1 file changed, 1 insertion(+) diff --git a/management/server/account.go b/management/server/account.go index 94335cf27..9d2759cb7 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2057,6 +2057,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain, email, nam Extra: &types.ExtraSettings{ UserApprovalRequired: true, }, + LazyConnectionEnabled: true, }, Onboarding: types.AccountOnboarding{ OnboardingFlowPending: true, From 1d8b5f6e5cf08290b02af1f1300b33d0de723c4f Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 2 Jul 2026 17:58:16 +0900 Subject: [PATCH 11/19] [client] Make lazy connections opt-out via NB_LAZY_CONN (#6617) --- client/android/env_list.go | 2 +- client/cmd/root.go | 17 ++-- client/cmd/up.go | 10 --- client/internal/auth/auth.go | 1 - client/internal/conn_mgr.go | 55 ++++++++++--- client/internal/conn_mgr_test.go | 40 +++++++++ client/internal/connect.go | 4 +- client/internal/debug/debug.go | 2 +- client/internal/debug/debug_test.go | 2 +- client/internal/engine.go | 7 +- client/internal/lazyconn/env.go | 57 ++++++++++--- client/internal/lazyconn/env_test.go | 45 +++++++++++ client/internal/profilemanager/config.go | 21 ++--- .../profilemanager/config_mdm_test.go | 31 +++++++ client/ios/NetBirdSDK/env_list.go | 2 +- client/mdm/canonical_loaders.go | 1 + client/mdm/policy.go | 14 +++- client/mdm/policy_test.go | 13 ++- client/server/mdm.go | 3 - client/server/server.go | 3 - client/server/setconfig_test.go | 81 +++++++++---------- client/system/info.go | 6 +- client/ui/client_ui.go | 35 +++----- client/ui/const.go | 1 - client/ui/event_handler.go | 11 --- shared/management/client/grpc.go | 2 - 26 files changed, 312 insertions(+), 154 deletions(-) create mode 100644 client/internal/conn_mgr_test.go create mode 100644 client/internal/lazyconn/env_test.go diff --git a/client/android/env_list.go b/client/android/env_list.go index a0a4d7040..d0e0a1e78 100644 --- a/client/android/env_list.go +++ b/client/android/env_list.go @@ -10,7 +10,7 @@ var ( EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay // EnvKeyNBLazyConn Exported for Android java client to configure lazy connection - EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn + EnvKeyNBLazyConn = lazyconn.EnvLazyConn // EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold diff --git a/client/cmd/root.go b/client/cmd/root.go index f3fde2f1c..f1ef32717 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -71,12 +71,14 @@ var ( extraIFaceBlackList []string anonymizeFlag bool dnsRouteInterval time.Duration - lazyConnEnabled bool - mtu uint16 - profilesDisabled bool - updateSettingsDisabled bool - captureEnabled bool - networksDisabled bool + // lazyConnEnabled is the parse target for the deprecated --enable-lazy-connection + // flag. The flag is inert; the value is no longer read (use NB_LAZY_CONN instead). + lazyConnEnabled bool + mtu uint16 + profilesDisabled bool + updateSettingsDisabled bool + captureEnabled bool + networksDisabled bool rootCmd = &cobra.Command{ Use: "netbird", @@ -210,7 +212,8 @@ func init() { upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.") upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") - upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.") + upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "Deprecated: no longer used. Lazy connections are controlled by the server and the NB_LAZY_CONN environment variable.") + _ = upCmd.PersistentFlags().MarkDeprecated(enableLazyConnectionFlag, "no longer used; lazy connections are controlled by the server and the NB_LAZY_CONN environment variable") } diff --git a/client/cmd/up.go b/client/cmd/up.go index 0506bc65b..8b3de3c66 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -479,10 +479,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro req.DisableIpv6 = &disableIPv6 } - if cmd.Flag(enableLazyConnectionFlag).Changed { - req.LazyConnectionEnabled = &lazyConnEnabled - } - return &req } @@ -600,9 +596,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil ic.DisableIPv6 = &disableIPv6 } - if cmd.Flag(enableLazyConnectionFlag).Changed { - ic.LazyConnectionEnabled = &lazyConnEnabled - } return &ic, nil } @@ -718,9 +711,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte loginRequest.DisableIpv6 = &disableIPv6 } - if cmd.Flag(enableLazyConnectionFlag).Changed { - loginRequest.LazyConnectionEnabled = &lazyConnEnabled - } return &loginRequest, nil } diff --git a/client/internal/auth/auth.go b/client/internal/auth/auth.go index afc8ee77f..850e0706d 100644 --- a/client/internal/auth/auth.go +++ b/client/internal/auth/auth.go @@ -322,7 +322,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) { a.config.BlockLANAccess, a.config.BlockInbound, a.config.DisableIPv6, - a.config.LazyConnectionEnabled, a.config.EnableSSHRoot, a.config.EnableSSHSFTP, a.config.EnableSSHLocalPortForwarding, diff --git a/client/internal/conn_mgr.go b/client/internal/conn_mgr.go index 112559132..a82a4ca8b 100644 --- a/client/internal/conn_mgr.go +++ b/client/internal/conn_mgr.go @@ -16,6 +16,16 @@ import ( "github.com/netbirdio/netbird/route" ) +// lazyForce is the resolved local decision for lazy connections, layered above the +// management feature flag. lazyForceNone defers to management. +type lazyForce int + +const ( + lazyForceNone lazyForce = iota + lazyForceOn + lazyForceOff +) + // ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections. // // The connection manager is responsible for: @@ -28,7 +38,7 @@ type ConnMgr struct { peerStore *peerstore.Store statusRecorder *peer.Status iface lazyconn.WGIface - enabledLocally bool + force lazyForce rosenpassEnabled bool lazyConnMgr *manager.Manager @@ -43,28 +53,34 @@ func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerSto peerStore: peerStore, statusRecorder: statusRecorder, iface: iface, + force: resolveLazyForce(engineConfig.LazyConnection), rosenpassEnabled: engineConfig.RosenpassEnabled, } - if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() { - e.enabledLocally = true - } return e } -// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option. +// Start initializes the connection manager. It starts the lazy connection manager when a +// local override forces it on; with no local override it waits for the management feature flag. func (e *ConnMgr) Start(ctx context.Context) { if e.lazyConnMgr != nil { log.Errorf("lazy connection manager is already started") return } - if !e.enabledLocally { - log.Infof("lazy connection manager is disabled") + switch e.force { + case lazyForceOff: + log.Infof("lazy connection manager is disabled by local override (%s or MDM policy)", lazyconn.EnvLazyConn) + e.statusRecorder.UpdateLazyConnection(false) + return + case lazyForceNone: + log.Infof("lazy connection manager is managed by the management feature flag") + e.statusRecorder.UpdateLazyConnection(false) return } if e.rosenpassEnabled { log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started") + e.statusRecorder.UpdateLazyConnection(false) return } @@ -76,8 +92,8 @@ func (e *ConnMgr) Start(ctx context.Context) { // If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again. // If disabled, then it closes the lazy connection manager and open the connections to all peers. func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error { - // do not disable lazy connection manager if it was enabled by env var - if e.enabledLocally { + // a local override (NB_LAZY_CONN or local config) takes precedence over management + if e.force != lazyForceNone { return nil } @@ -89,6 +105,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er if e.rosenpassEnabled { log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started") + e.statusRecorder.UpdateLazyConnection(false) return nil } @@ -98,6 +115,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er return e.addPeersToLazyConnManager() } else { if e.lazyConnMgr == nil { + e.statusRecorder.UpdateLazyConnection(false) return nil } log.Infof("lazy connection manager is disabled by management feature flag") @@ -309,6 +327,25 @@ func (e *ConnMgr) isStartedWithLazyMgr() bool { return e.lazyConnMgr != nil && e.lazyCtxCancel != nil } +// resolveLazyForce determines the local override. NB_LAZY_CONN takes precedence; when it +// is unset the MDM policy override (mdmState) applies. Either wins in both directions over +// the management feature flag; StateUnset for both defers to management. +func resolveLazyForce(mdmState lazyconn.State) lazyForce { + state := lazyconn.EnvState() + if state == lazyconn.StateUnset { + state = mdmState + } + + switch state { + case lazyconn.StateOn: + return lazyForceOn + case lazyconn.StateOff: + return lazyForceOff + default: + return lazyForceNone + } +} + func inactivityThresholdEnv() *time.Duration { envValue := os.Getenv(lazyconn.EnvInactivityThreshold) if envValue == "" { diff --git a/client/internal/conn_mgr_test.go b/client/internal/conn_mgr_test.go new file mode 100644 index 000000000..5e2c53e35 --- /dev/null +++ b/client/internal/conn_mgr_test.go @@ -0,0 +1,40 @@ +package internal + +import ( + "os" + "testing" + + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +func TestResolveLazyForce(t *testing.T) { + tests := []struct { + name string + env string + envSet bool + mdm lazyconn.State + want lazyForce + }{ + {name: "env unset, mdm unset -> defer to management", mdm: lazyconn.StateUnset, want: lazyForceNone}, + {name: "env on -> force on", env: "on", envSet: true, mdm: lazyconn.StateUnset, want: lazyForceOn}, + {name: "env off -> force off", env: "off", envSet: true, mdm: lazyconn.StateUnset, want: lazyForceOff}, + {name: "env unset, mdm on -> force on", mdm: lazyconn.StateOn, want: lazyForceOn}, + {name: "env unset, mdm off -> force off", mdm: lazyconn.StateOff, want: lazyForceOff}, + {name: "env on beats mdm off", env: "on", envSet: true, mdm: lazyconn.StateOff, want: lazyForceOn}, + {name: "env off beats mdm on", env: "off", envSet: true, mdm: lazyconn.StateOn, want: lazyForceOff}, + {name: "unrecognized env, mdm on -> mdm wins", env: "auto", envSet: true, mdm: lazyconn.StateOn, want: lazyForceOn}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv(lazyconn.EnvLazyConn, tt.env) + if !tt.envSet { + os.Unsetenv(lazyconn.EnvLazyConn) + } + + if got := resolveLazyForce(tt.mdm); got != tt.want { + t.Fatalf("resolveLazyForce(%v) = %v, want %v", tt.mdm, got, tt.want) + } + }) + } +} diff --git a/client/internal/connect.go b/client/internal/connect.go index eff2c9489..93467b09a 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -27,6 +27,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/metrics" "github.com/netbirdio/netbird/client/internal/peer" @@ -601,7 +602,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf BlockInbound: config.BlockInbound, DisableIPv6: config.DisableIPv6, - LazyConnectionEnabled: config.LazyConnectionEnabled, + LazyConnection: lazyconn.ParseState(config.LazyConnection), MTU: selectMTU(config.MTU, peerConfig.Mtu), LogPath: logPath, @@ -675,7 +676,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config.BlockLANAccess, config.BlockInbound, config.DisableIPv6, - config.LazyConnectionEnabled, config.EnableSSHRoot, config.EnableSSHSFTP, config.EnableSSHLocalPortForwarding, diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index a65d8bd05..5700b05de 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -681,7 +681,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) configContent.WriteString(fmt.Sprintf("ClientCertKeyPath: %s\n", g.internalConfig.ClientCertKeyPath)) } - configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled)) + configContent.WriteString(fmt.Sprintf("LazyConnection: %q\n", g.internalConfig.LazyConnection)) configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU)) } diff --git a/client/internal/debug/debug_test.go b/client/internal/debug/debug_test.go index ca7785d35..8286f6852 100644 --- a/client/internal/debug/debug_test.go +++ b/client/internal/debug/debug_test.go @@ -885,7 +885,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) { DNSRouteInterval: 5 * time.Second, ClientCertPath: "/tmp/cert", ClientCertKeyPath: "/tmp/key", - LazyConnectionEnabled: true, + LazyConnection: "on", MTU: 1280, } diff --git a/client/internal/engine.go b/client/internal/engine.go index fb1d08f5e..a08bea31b 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -40,6 +40,7 @@ import ( "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/expose" "github.com/netbirdio/netbird/client/internal/ingressgw" + "github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/metrics" "github.com/netbirdio/netbird/client/internal/netflow" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" @@ -147,7 +148,9 @@ type EngineConfig struct { BlockInbound bool DisableIPv6 bool - LazyConnectionEnabled bool + // LazyConnection is the MDM-sourced lazy-connection override; StateUnset defers to + // the env var and management feature flag. + LazyConnection lazyconn.State MTU uint16 @@ -1130,7 +1133,6 @@ func (e *Engine) applyInfoFlags(info *system.Info) { e.config.BlockLANAccess, e.config.BlockInbound, e.config.DisableIPv6, - e.config.LazyConnectionEnabled, e.config.EnableSSHRoot, e.config.EnableSSHSFTP, e.config.EnableSSHLocalPortForwarding, @@ -1999,7 +2001,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err e.config.BlockLANAccess, e.config.BlockInbound, e.config.DisableIPv6, - e.config.LazyConnectionEnabled, e.config.EnableSSHRoot, e.config.EnableSSHSFTP, e.config.EnableSSHLocalPortForwarding, diff --git a/client/internal/lazyconn/env.go b/client/internal/lazyconn/env.go index 649d1cd65..d408083e7 100644 --- a/client/internal/lazyconn/env.go +++ b/client/internal/lazyconn/env.go @@ -3,24 +3,57 @@ package lazyconn import ( "os" "strconv" + "strings" log "github.com/sirupsen/logrus" ) const ( - EnvEnableLazyConn = "NB_ENABLE_EXPERIMENTAL_LAZY_CONN" + EnvLazyConn = "NB_LAZY_CONN" EnvInactivityThreshold = "NB_LAZY_CONN_INACTIVITY_THRESHOLD" ) -func IsLazyConnEnabledByEnv() bool { - val := os.Getenv(EnvEnableLazyConn) - if val == "" { - return false - } - enabled, err := strconv.ParseBool(val) - if err != nil { - log.Warnf("failed to parse %s: %v", EnvEnableLazyConn, err) - return false - } - return enabled +// State is the tri-state local override for lazy connections read from the environment. +type State int + +const ( + // StateUnset means no local override; defer to the management feature flag. + StateUnset State = iota + // StateOn forces lazy connections on, overriding management. + StateOn + // StateOff forces lazy connections off, overriding management. + StateOff +) + +// EnvState reads NB_LAZY_CONN and returns the local override state. +func EnvState() State { + return ParseState(os.Getenv(EnvLazyConn)) +} + +// ParseState interprets a lazy-connection override value (from the environment or an MDM +// policy). It accepts the on/off aliases plus any value strconv.ParseBool understands +// (true/false/1/0). An empty or unrecognized value returns StateUnset so that the +// management feature flag remains in control. +func ParseState(raw string) State { + if raw == "" { + return StateUnset + } + + normalized := strings.ToLower(strings.TrimSpace(raw)) + switch normalized { + case "on": + return StateOn + case "off": + return StateOff + } + + enabled, err := strconv.ParseBool(normalized) + if err != nil { + log.Warnf("failed to parse lazy connection value %q (from %s env or MDM policy): %v", raw, EnvLazyConn, err) + return StateUnset + } + if enabled { + return StateOn + } + return StateOff } diff --git a/client/internal/lazyconn/env_test.go b/client/internal/lazyconn/env_test.go new file mode 100644 index 000000000..59ee40c4b --- /dev/null +++ b/client/internal/lazyconn/env_test.go @@ -0,0 +1,45 @@ +package lazyconn + +import ( + "os" + "testing" +) + +func TestEnvState(t *testing.T) { + tests := []struct { + value string + set bool + want State + }{ + {set: false, want: StateUnset}, + {value: "", set: true, want: StateUnset}, + {value: "on", set: true, want: StateOn}, + {value: "ON", set: true, want: StateOn}, + {value: "true", set: true, want: StateOn}, + {value: "1", set: true, want: StateOn}, + {value: " on ", set: true, want: StateOn}, + {value: "off", set: true, want: StateOff}, + {value: "OFF", set: true, want: StateOff}, + {value: "false", set: true, want: StateOff}, + {value: "0", set: true, want: StateOff}, + {value: "auto", set: true, want: StateUnset}, + {value: "garbage", set: true, want: StateUnset}, + } + + for _, tt := range tests { + name := tt.value + if !tt.set { + name = "unset" + } + t.Run(name, func(t *testing.T) { + t.Setenv(EnvLazyConn, tt.value) + if !tt.set { + os.Unsetenv(EnvLazyConn) + } + + if got := EnvState(); got != tt.want { + t.Fatalf("EnvState() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 8ffcb16f2..ed2f21999 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -101,8 +101,6 @@ type ConfigInput struct { DNSLabels domain.List - LazyConnectionEnabled *bool - MTU *uint16 } @@ -180,7 +178,9 @@ type Config struct { ClientCertKeyPair *tls.Certificate `json:"-"` - LazyConnectionEnabled bool + // LazyConnection is the MDM-managed lazy-connection override ("on"/"off"/""). + // Runtime-only: re-derived from MDM policy on each load, never persisted. + LazyConnection string `json:"-"` MTU uint16 @@ -632,12 +632,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } - if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled { - log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled) - config.LazyConnectionEnabled = *input.LazyConnectionEnabled - updated = true - } - if input.MTU != nil && *input.MTU != config.MTU { log.Infof("updating MTU to %d (old value %d)", *input.MTU, config.MTU) config.MTU = *input.MTU @@ -728,6 +722,15 @@ func (config *Config) applyMDMPolicy(policy *mdm.Policy) { log.Warnf("MDM wireguard port %d out of range [1,65535]; keeping previous value", v) } } + + if v, ok := policy.GetBool(mdm.KeyLazyConnection); ok { + state := "off" + if v { + state = "on" + } + config.LazyConnection = state + logApplied(mdm.KeyLazyConnection, state) + } } // parseURL parses and validates the URL for the named service. The URL diff --git a/client/internal/profilemanager/config_mdm_test.go b/client/internal/profilemanager/config_mdm_test.go index 6a201235e..c6a688ab2 100644 --- a/client/internal/profilemanager/config_mdm_test.go +++ b/client/internal/profilemanager/config_mdm_test.go @@ -130,6 +130,37 @@ func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) { assert.True(t, cfg.Policy().HasKey(mdm.KeyRosenpassEnabled)) } +func TestApply_MDMLazyConnection(t *testing.T) { + cases := []struct { + name string + raw any + want string + }{ + {"native true", true, "on"}, + {"native false", false, "off"}, + {"string on", "on", "on"}, + {"string off", "off", "off"}, + {"string yes", "yes", "on"}, + {"string no", "no", "off"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + withMDMPolicy(t, mdm.NewPolicy(map[string]any{ + mdm.KeyLazyConnection: c.raw, + })) + + cfg, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: filepath.Join(t.TempDir(), "config.json"), + }) + require.NoError(t, err) + require.NotNil(t, cfg) + + assert.Equal(t, c.want, cfg.LazyConnection) + assert.True(t, cfg.Policy().HasKey(mdm.KeyLazyConnection)) + }) + } +} + func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) { const maskSentinel = "**********" diff --git a/client/ios/NetBirdSDK/env_list.go b/client/ios/NetBirdSDK/env_list.go index 88ac97957..a3ffa0ebe 100644 --- a/client/ios/NetBirdSDK/env_list.go +++ b/client/ios/NetBirdSDK/env_list.go @@ -38,7 +38,7 @@ func GetEnvKeyNBForceRelay() string { // GetEnvKeyNBLazyConn Exports the environment variable for the iOS client func GetEnvKeyNBLazyConn() string { - return lazyconn.EnvEnableLazyConn + return lazyconn.EnvLazyConn } // GetEnvKeyNBInactivityThreshold Exports the environment variable for the iOS client diff --git a/client/mdm/canonical_loaders.go b/client/mdm/canonical_loaders.go index 6e7ab19cb..b20a823fb 100644 --- a/client/mdm/canonical_loaders.go +++ b/client/mdm/canonical_loaders.go @@ -27,6 +27,7 @@ var allKeys = []string{ KeyWireguardPort, KeySplitTunnelMode, KeySplitTunnelApps, + KeyLazyConnection, } // canonicalKey maps the lowercase form of a managed-config value name to diff --git a/client/mdm/policy.go b/client/mdm/policy.go index 109fb322e..67b126101 100644 --- a/client/mdm/policy.go +++ b/client/mdm/policy.go @@ -11,6 +11,7 @@ package mdm import ( "sort" "strconv" + "strings" log "github.com/sirupsen/logrus" ) @@ -41,6 +42,11 @@ const ( // construction — only one mode can be set at a time. KeySplitTunnelMode = "splitTunnelMode" KeySplitTunnelApps = "splitTunnelApps" + + // KeyLazyConnection forces the lazy-connection feature on or off, overriding + // the management feature flag. Read as a bool (native bool, or on/off, + // true/false, 1/0, yes/no); absent = defer to management. + KeyLazyConnection = "lazyConnection" ) // Split-tunnel mode literals (KeySplitTunnelMode values). @@ -62,12 +68,13 @@ var boolStringLiterals = map[string]bool{ "true": true, "1": true, "yes": true, + "on": true, "false": false, "0": false, "no": false, + "off": false, } - // Policy holds MDM-managed settings read from the platform source. A nil or // empty Policy means no enforcement is active. type Policy struct { @@ -150,7 +157,8 @@ func (p *Policy) GetString(key string) (string, bool) { } // GetBool returns the managed value for key coerced to bool, and whether the -// key was set. Accepts native bool and string literals "true"/"false"/"1"/"0". +// key was set. Accepts native bool and string literals (true/false, 1/0, +// yes/no, on/off), case-insensitively and trimmed of surrounding whitespace. func (p *Policy) GetBool(key string) (bool, bool) { if p == nil { return false, false @@ -163,7 +171,7 @@ func (p *Policy) GetBool(key string) (bool, bool) { case bool: return t, true case string: - b, known := boolStringLiterals[t] + b, known := boolStringLiterals[strings.ToLower(strings.TrimSpace(t))] return b, known case int: return t != 0, true diff --git a/client/mdm/policy_test.go b/client/mdm/policy_test.go index 47a6ed2c9..6cbe69776 100644 --- a/client/mdm/policy_test.go +++ b/client/mdm/policy_test.go @@ -31,8 +31,8 @@ func TestPolicy_Empty(t *testing.T) { func TestPolicy_HasKey(t *testing.T) { p := NewPolicy(map[string]any{ - KeyManagementURL: "https://corp.example.com", - KeyDisableProfiles: true, + KeyManagementURL: "https://corp.example.com", + KeyDisableProfiles: true, }) assert.False(t, p.IsEmpty()) assert.True(t, p.HasKey(KeyManagementURL)) @@ -53,8 +53,8 @@ func TestPolicy_ManagedKeysSorted(t *testing.T) { func TestPolicy_GetString(t *testing.T) { p := NewPolicy(map[string]any{ KeyManagementURL: "https://corp.example.com", - KeyDisableProfiles: true, // wrong type for GetString - KeyPreSharedKey: "", // empty rejected + KeyDisableProfiles: true, // wrong type for GetString + KeyPreSharedKey: "", // empty rejected }) v, ok := p.GetString(KeyManagementURL) assert.True(t, ok) @@ -85,6 +85,11 @@ func TestPolicy_GetBool(t *testing.T) { {"string 0", "0", false, true}, {"string yes", "yes", true, true}, {"string no", "no", false, true}, + {"string on", "on", true, true}, + {"string off", "off", false, true}, + {"mixed case On", "On", true, true}, + {"upper TRUE", "TRUE", true, true}, + {"padded yes", " yes ", true, true}, {"int nonzero", 1, true, true}, {"int zero", 0, false, true}, {"int64 nonzero", int64(2), true, true}, diff --git a/client/server/mdm.go b/client/server/mdm.go index 0da0ec5d1..db7db2759 100644 --- a/client/server/mdm.go +++ b/client/server/mdm.go @@ -152,7 +152,6 @@ func (s *Server) restartEngineForMDMLocked() error { s.config = config s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive) - s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled) ctx, cancel := context.WithCancel(s.rootCtx) s.actCancel = cancel @@ -305,7 +304,6 @@ func setConfigRequestHasConfigOverrides(msg *proto.SetConfigRequest) bool { msg.DisableFirewall != nil || msg.BlockLanAccess != nil || msg.DisableNotifications != nil || - msg.LazyConnectionEnabled != nil || msg.BlockInbound != nil || msg.DisableIpv6 != nil || msg.EnableSSHRoot != nil || @@ -348,7 +346,6 @@ func loginRequestHasConfigOverrides(msg *proto.LoginRequest) bool { msg.BlockLanAccess != nil || msg.DisableNotifications != nil || len(msg.DnsLabels) > 0 || msg.CleanDNSLabels || - msg.LazyConnectionEnabled != nil || msg.BlockInbound != nil } diff --git a/client/server/server.go b/client/server/server.go index 3f6dabc56..e8ef2f96e 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -214,7 +214,6 @@ func (s *Server) Start() error { s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive) - s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled) if s.sessionWatcher == nil { s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder) @@ -463,7 +462,6 @@ func (s *Server) setConfigInputFromRequest(msg *proto.SetConfigRequest) (profile config.DisableFirewall = msg.DisableFirewall config.BlockLANAccess = msg.BlockLanAccess config.DisableNotifications = msg.DisableNotifications - config.LazyConnectionEnabled = msg.LazyConnectionEnabled config.BlockInbound = msg.BlockInbound config.DisableIPv6 = msg.DisableIpv6 config.EnableSSHRoot = msg.EnableSSHRoot @@ -1647,7 +1645,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p ServerSSHAllowed: *cfg.ServerSSHAllowed, RosenpassEnabled: cfg.RosenpassEnabled, RosenpassPermissive: cfg.RosenpassPermissive, - LazyConnectionEnabled: cfg.LazyConnectionEnabled, BlockInbound: cfg.BlockInbound, DisableNotifications: disableNotifications, NetworkMonitor: networkMonitor, diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go index 7c85d16ce..0e55257a9 100644 --- a/client/server/setconfig_test.go +++ b/client/server/setconfig_test.go @@ -69,43 +69,41 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { disableFirewall := true blockLANAccess := true disableNotifications := true - lazyConnectionEnabled := true blockInbound := true disableIPv6 := true mtu := int64(1280) sshJWTCacheTTL := int32(300) req := &proto.SetConfigRequest{ - ProfileName: profName, - Username: currUser.Username, - ManagementUrl: "https://new-api.netbird.io:443", - AdminURL: "https://new-admin.netbird.io", - RosenpassEnabled: &rosenpassEnabled, - RosenpassPermissive: &rosenpassPermissive, - ServerSSHAllowed: &serverSSHAllowed, - InterfaceName: &interfaceName, - WireguardPort: &wireguardPort, - OptionalPreSharedKey: &preSharedKey, - DisableAutoConnect: &disableAutoConnect, - NetworkMonitor: &networkMonitor, - DisableClientRoutes: &disableClientRoutes, - DisableServerRoutes: &disableServerRoutes, - DisableDns: &disableDNS, - DisableFirewall: &disableFirewall, - BlockLanAccess: &blockLANAccess, - DisableNotifications: &disableNotifications, - LazyConnectionEnabled: &lazyConnectionEnabled, - BlockInbound: &blockInbound, - DisableIpv6: &disableIPv6, - NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"}, - CleanNATExternalIPs: false, - CustomDNSAddress: []byte("1.1.1.1:53"), - ExtraIFaceBlacklist: []string{"eth1", "eth2"}, - DnsLabels: []string{"label1", "label2"}, - CleanDNSLabels: false, - DnsRouteInterval: durationpb.New(2 * time.Minute), - Mtu: &mtu, - SshJWTCacheTTL: &sshJWTCacheTTL, + ProfileName: profName, + Username: currUser.Username, + ManagementUrl: "https://new-api.netbird.io:443", + AdminURL: "https://new-admin.netbird.io", + RosenpassEnabled: &rosenpassEnabled, + RosenpassPermissive: &rosenpassPermissive, + ServerSSHAllowed: &serverSSHAllowed, + InterfaceName: &interfaceName, + WireguardPort: &wireguardPort, + OptionalPreSharedKey: &preSharedKey, + DisableAutoConnect: &disableAutoConnect, + NetworkMonitor: &networkMonitor, + DisableClientRoutes: &disableClientRoutes, + DisableServerRoutes: &disableServerRoutes, + DisableDns: &disableDNS, + DisableFirewall: &disableFirewall, + BlockLanAccess: &blockLANAccess, + DisableNotifications: &disableNotifications, + BlockInbound: &blockInbound, + DisableIpv6: &disableIPv6, + NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"}, + CleanNATExternalIPs: false, + CustomDNSAddress: []byte("1.1.1.1:53"), + ExtraIFaceBlacklist: []string{"eth1", "eth2"}, + DnsLabels: []string{"label1", "label2"}, + CleanDNSLabels: false, + DnsRouteInterval: durationpb.New(2 * time.Minute), + Mtu: &mtu, + SshJWTCacheTTL: &sshJWTCacheTTL, } _, err = s.SetConfig(ctx, req) @@ -140,7 +138,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { require.Equal(t, blockLANAccess, cfg.BlockLANAccess) require.NotNil(t, cfg.DisableNotifications) require.Equal(t, disableNotifications, *cfg.DisableNotifications) - require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled) require.Equal(t, blockInbound, cfg.BlockInbound) require.Equal(t, disableIPv6, cfg.DisableIPv6) require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs) @@ -164,13 +161,14 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) { t.Helper() metadataFields := map[string]bool{ - "state": true, // protobuf internal - "sizeCache": true, // protobuf internal - "unknownFields": true, // protobuf internal - "Username": true, // metadata - "ProfileName": true, // metadata - "CleanNATExternalIPs": true, // control flag for clearing - "CleanDNSLabels": true, // control flag for clearing + "state": true, // protobuf internal + "sizeCache": true, // protobuf internal + "unknownFields": true, // protobuf internal + "Username": true, // metadata + "ProfileName": true, // metadata + "CleanNATExternalIPs": true, // control flag for clearing + "CleanDNSLabels": true, // control flag for clearing + "LazyConnectionEnabled": true, // deprecated: proto field retained for compat, no longer applied } expectedFields := map[string]bool{ @@ -190,7 +188,6 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) { "DisableFirewall": true, "BlockLanAccess": true, "DisableNotifications": true, - "LazyConnectionEnabled": true, "BlockInbound": true, "DisableIpv6": true, "NatExternalIPs": true, @@ -252,7 +249,6 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) { "block-lan-access": "BlockLanAccess", "block-inbound": "BlockInbound", "disable-ipv6": "DisableIpv6", - "enable-lazy-connection": "LazyConnectionEnabled", "external-ip-map": "NatExternalIPs", "dns-resolver-address": "CustomDNSAddress", "extra-iface-blacklist": "ExtraIFaceBlacklist", @@ -269,7 +265,8 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) { // SetConfigRequest fields that don't have CLI flags (settable only via UI or other means). fieldsWithoutCLIFlags := map[string]bool{ - "DisableNotifications": true, // Only settable via UI + "DisableNotifications": true, // Only settable via UI + "LazyConnectionEnabled": true, // deprecated: no longer settable (managed by server + NB_LAZY_CONN) } // Get all SetConfigRequest fields to verify our map is complete. diff --git a/client/system/info.go b/client/system/info.go index 496b478a3..1838204b8 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -74,8 +74,6 @@ type Info struct { BlockInbound bool DisableIPv6 bool - LazyConnectionEnabled bool - EnableSSHRoot bool EnableSSHSFTP bool EnableSSHLocalPortForwarding bool @@ -87,7 +85,7 @@ func (i *Info) SetFlags( rosenpassEnabled, rosenpassPermissive bool, serverSSHAllowed *bool, disableClientRoutes, disableServerRoutes, - disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool, + disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6 bool, enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool, disableSSHAuth *bool, ) { @@ -105,8 +103,6 @@ func (i *Info) SetFlags( i.BlockInbound = blockInbound i.DisableIPv6 = disableIPv6 - i.LazyConnectionEnabled = lazyConnectionEnabled - if enableSSHRoot != nil { i.EnableSSHRoot = *enableSSHRoot } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 40fb4169d..2b19c2bf5 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -266,7 +266,6 @@ type serviceClient struct { mAllowSSH *systray.MenuItem mAutoConnect *systray.MenuItem mEnableRosenpass *systray.MenuItem - mLazyConnEnabled *systray.MenuItem mBlockInbound *systray.MenuItem mNotifications *systray.MenuItem mAdvancedSettings *systray.MenuItem @@ -336,11 +335,11 @@ type serviceClient struct { // mNetworks + mExitNode submenu items. Combines features.DisableNetworks // AND s.connected — both must be true for the menus to be active. // Zero value (false) matches the Disable() call at AddMenuItem time. - networksMenuEnabled bool - showNetworks bool - wNetworks fyne.Window - wProfiles fyne.Window - wQuickActions fyne.Window + networksMenuEnabled bool + showNetworks bool + wNetworks fyne.Window + wProfiles fyne.Window + wQuickActions fyne.Window eventManager *event.Manager @@ -1094,7 +1093,6 @@ func (s *serviceClient) onTrayReady() { s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false) s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false) s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false) - s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false) s.mBlockInbound = s.mSettings.AddSubMenuItemCheckbox("Block Inbound Connections", blockInboundMenuDescr, false) s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false) s.mSettings.AddSeparator() @@ -1578,7 +1576,6 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config { config.RosenpassEnabled = cfg.RosenpassEnabled config.RosenpassPermissive = cfg.RosenpassPermissive config.DisableNotifications = &cfg.DisableNotifications - config.LazyConnectionEnabled = cfg.LazyConnectionEnabled config.BlockInbound = cfg.BlockInbound config.NetworkMonitor = &cfg.NetworkMonitor config.DisableDNS = cfg.DisableDns @@ -1682,12 +1679,6 @@ func (s *serviceClient) loadSettings() { s.mEnableRosenpass.Uncheck() } - if cfg.LazyConnectionEnabled { - s.mLazyConnEnabled.Check() - } else { - s.mLazyConnEnabled.Uncheck() - } - if cfg.BlockInbound { s.mBlockInbound.Check() } else { @@ -1833,7 +1824,6 @@ func (s *serviceClient) updateConfig() error { disableAutoStart := !s.mAutoConnect.Checked() sshAllowed := s.mAllowSSH.Checked() rosenpassEnabled := s.mEnableRosenpass.Checked() - lazyConnectionEnabled := s.mLazyConnEnabled.Checked() blockInbound := s.mBlockInbound.Checked() notificationsDisabled := !s.mNotifications.Checked() @@ -1856,14 +1846,13 @@ func (s *serviceClient) updateConfig() error { } req := proto.SetConfigRequest{ - ProfileName: activeProf.ID.String(), - Username: currUser.Username, - DisableAutoConnect: &disableAutoStart, - ServerSSHAllowed: &sshAllowed, - RosenpassEnabled: &rosenpassEnabled, - LazyConnectionEnabled: &lazyConnectionEnabled, - BlockInbound: &blockInbound, - DisableNotifications: ¬ificationsDisabled, + ProfileName: activeProf.ID.String(), + Username: currUser.Username, + DisableAutoConnect: &disableAutoStart, + ServerSSHAllowed: &sshAllowed, + RosenpassEnabled: &rosenpassEnabled, + BlockInbound: &blockInbound, + DisableNotifications: ¬ificationsDisabled, } if _, err := conn.SetConfig(s.ctx, &req); err != nil { diff --git a/client/ui/const.go b/client/ui/const.go index 48619be75..ce7a9a294 100644 --- a/client/ui/const.go +++ b/client/ui/const.go @@ -4,7 +4,6 @@ const ( allowSSHMenuDescr = "Allow SSH connections" autoConnectMenuDescr = "Connect automatically when the service starts" quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass" - lazyConnMenuDescr = "[Experimental] Enable lazy connections" blockInboundMenuDescr = "Block inbound connections to the local machine and routed networks" notificationsMenuDescr = "Enable notifications" advancedSettingsMenuDescr = "Advanced settings of the application" diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index 876fcef5f..902082308 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -43,8 +43,6 @@ func (h *eventHandler) listen(ctx context.Context) { h.handleAutoConnectClick() case <-h.client.mEnableRosenpass.ClickedCh: h.handleRosenpassClick() - case <-h.client.mLazyConnEnabled.ClickedCh: - h.handleLazyConnectionClick() case <-h.client.mBlockInbound.ClickedCh: h.handleBlockInboundClick() case <-h.client.mAdvancedSettings.ClickedCh: @@ -152,15 +150,6 @@ func (h *eventHandler) handleRosenpassClick() { } } -func (h *eventHandler) handleLazyConnectionClick() { - h.toggleCheckbox(h.client.mLazyConnEnabled) - if err := h.updateConfigWithErr(); err != nil { - h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error - log.Errorf("failed to update config: %v", err) - h.client.notifier.Send("Error", "Failed to update lazy connection settings") - } -} - func (h *eventHandler) handleBlockInboundClick() { h.toggleCheckbox(h.client.mBlockInbound) if err := h.updateConfigWithErr(); err != nil { diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 781e66a3e..bd4585455 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -1030,8 +1030,6 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta { BlockLANAccess: info.BlockLANAccess, BlockInbound: info.BlockInbound, DisableIPv6: info.DisableIPv6, - - LazyConnectionEnabled: info.LazyConnectionEnabled, }, Capabilities: peerCapabilities(*info), From 167be3a30fb4b90f5f22e1f1986fa15df58c1a8c Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 2 Jul 2026 12:15:57 +0200 Subject: [PATCH 12/19] [ci] Run privileged client tests natively with sudo on Linux (#6635) Restore the pre-split native, sudo-based run for the Linux Client / Unit job: build with the privileged tag and run under sudo, matching the darwin job. Excludes the dockertest harness (client/testutil/privileged) so it does not recurse into a container spawn. The Docker privileged job is kept as-is. --- .github/workflows/golang-test-linux.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 34b215c60..ce53261a4 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -158,7 +158,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags 'devcert privileged' -exec 'sudo --preserve-env=CI,CGO_ENABLED' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/testutil/privileged) - name: Upload coverage reports to Codecov if: matrix.arch == 'amd64' From e203e0f42a14855fa30be0b79eb70397839552d5 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 2 Jul 2026 14:20:23 +0200 Subject: [PATCH 13/19] [self-hosted] Remove unused server/proxy image override logic in getting-started.sh (#6636) --- infrastructure_files/getting-started.sh | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/infrastructure_files/getting-started.sh b/infrastructure_files/getting-started.sh index 46bef5a1f..837cc42e6 100755 --- a/infrastructure_files/getting-started.sh +++ b/infrastructure_files/getting-started.sh @@ -351,11 +351,6 @@ initialize_default_values() { NETBIRD_STUN_PORT=3478 # Docker images - # Record whether the operator explicitly pinned the server/proxy images via - # env vars, so the agent-network preset can pick its own defaults without - # clobbering an explicit override. - NETBIRD_SERVER_IMAGE_EXPLICIT=${NETBIRD_SERVER_IMAGE:+true} - NETBIRD_PROXY_IMAGE_EXPLICIT=${NETBIRD_PROXY_IMAGE:+true} DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"} # Combined server replaces separate signal, relay, and management containers NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"} @@ -415,15 +410,6 @@ apply_agent_network_preset() { ENABLE_PROXY="true" ENABLE_CROWDSEC="false" - # Agent-network ships dedicated server/proxy images. Honor an explicit - # env override; otherwise pin the agent-network builds. - if [[ "${NETBIRD_SERVER_IMAGE_EXPLICIT}" != "true" ]]; then - NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:0.74.0-rc.2" - fi - if [[ "${NETBIRD_PROXY_IMAGE_EXPLICIT}" != "true" ]]; then - NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:0.74.0-rc.2" - fi - if [[ -n "${NETBIRD_LETSENCRYPT_EMAIL}" ]]; then TRAEFIK_ACME_EMAIL="${NETBIRD_LETSENCRYPT_EMAIL}" else From e40cb294f6fe62b34ba5105f6ac726853dfbae6b Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Thu, 2 Jul 2026 14:45:24 +0200 Subject: [PATCH 14/19] [management] Add vLLM to Agent Network (#6643) --- .../modules/agentnetwork/catalog/catalog.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/management/internals/modules/agentnetwork/catalog/catalog.go b/management/internals/modules/agentnetwork/catalog/catalog.go index baf622778..962b30250 100644 --- a/management/internals/modules/agentnetwork/catalog/catalog.go +++ b/management/internals/modules/agentnetwork/catalog/catalog.go @@ -627,6 +627,21 @@ var providers = []Provider{ }, Models: []Model{}, }, + { + // vLLM is an OpenAI-compatible self-hosted server. It behaves like + // the generic custom entry; it gets its own catalog id purely so it + // surfaces as a named "vLLM" choice in the provider picker. + ID: "vllm", + Kind: KindCustom, + Name: "vLLM", + Description: "Self-hosted vLLM (OpenAI-compatible)", + DefaultHost: "", + AuthHeaderName: "Authorization", + AuthHeaderTemplate: "Bearer ${API_KEY}", + DefaultContentType: "application/json", + BrandColor: "#30A2FF", + Models: []Model{}, + }, { ID: "custom", Kind: KindCustom, From 859fe19fff6661ed0ba7904b42ed8b12d72c36f5 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 2 Jul 2026 14:55:55 +0200 Subject: [PATCH 15/19] [management] return nil when config is not set (#6642) * [management] return nil when config is not set * [management] add relay invariant test and enforce config behavior --- .../internals/shared/grpc/conversion.go | 13 +++---- .../internals/shared/grpc/conversion_test.go | 38 +++++++++++++++++++ management/server/peer_test.go | 6 +-- 3 files changed, 44 insertions(+), 13 deletions(-) diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index 973749eb0..bdb4c8cf4 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -47,16 +47,13 @@ func init() { precomputedDeprecatedRemotePeersConstraint = constraint } +// toNetbirdConfig converts the server configuration to the wire representation. It returns +// nil when no server config is set (the fan-out network-map path) because clients treat any +// non-nil config as authoritative: a config without a relay section is interpreted as relay +// disabled and wipes the clients' relay URLs. func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings, settings *types.Settings) *proto.NetbirdConfig { if config == nil { - if settings == nil { - return nil - } - return &proto.NetbirdConfig{ - Metrics: &proto.MetricsConfig{ - Enabled: settings.MetricsPushEnabled, - }, - } + return nil } var stuns []*proto.HostConfig diff --git a/management/internals/shared/grpc/conversion_test.go b/management/internals/shared/grpc/conversion_test.go index 01a67e4fa..c81bef25c 100644 --- a/management/internals/shared/grpc/conversion_test.go +++ b/management/internals/shared/grpc/conversion_test.go @@ -8,11 +8,13 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/types" ) func TestToProtocolDNSConfigWithCache(t *testing.T) { @@ -263,3 +265,39 @@ func TestEncodeSessionExpiresAt(t *testing.T) { assert.True(t, got.AsTime().Equal(deadline)) }) } + +// TestToNetbirdConfig_RelayInvariant guards against the v0.74.0 relay-wipe regression. +// Clients treat any non-nil NetbirdConfig as authoritative and interpret a missing relay +// section as relay disabled, wiping their relay URLs. toNetbirdConfig must therefore +// return nil when no server config is set (the fan-out network-map path) instead of a +// partial config, and a result built from a relay-enabled config must carry the relay +// section. +func TestToNetbirdConfig_RelayInvariant(t *testing.T) { + settings := &types.Settings{MetricsPushEnabled: true} + + t.Run("nil server config returns nil config", func(t *testing.T) { + nbCfg := toNetbirdConfig(nil, nil, nil, nil, settings) + assert.Nil(t, nbCfg, "fan-out updates must not carry a partial NetbirdConfig even when settings are present") + }) + + t.Run("relay-enabled config carries relay section", func(t *testing.T) { + cfg := &nbconfig.Config{ + Stuns: []*nbconfig.Host{{Proto: nbconfig.UDP, URI: "stun:stun.example.com:3478"}}, + TURNConfig: &nbconfig.TURNConfig{ + Turns: []*nbconfig.Host{{Proto: nbconfig.UDP, URI: "turn:turn.example.com:3478", Username: "user", Password: "pass"}}, + }, + Relay: &nbconfig.Relay{Addresses: []string{"rels://relay.example.com:443"}}, + Signal: &nbconfig.Host{Proto: nbconfig.HTTP, URI: "signal.example.com:10000"}, + } + relayToken := &Token{Payload: "token-payload", Signature: "token-signature"} + + nbCfg := toNetbirdConfig(cfg, nil, relayToken, nil, settings) + require.NotNil(t, nbCfg) + require.NotNil(t, nbCfg.Relay, "non-nil NetbirdConfig must include the relay section") + assert.Equal(t, cfg.Relay.Addresses, nbCfg.Relay.Urls, "relay URLs should match the server config") + assert.Equal(t, relayToken.Payload, nbCfg.Relay.TokenPayload, "relay token payload should be set") + assert.Equal(t, relayToken.Signature, nbCfg.Relay.TokenSignature, "relay token signature should be set") + require.NotNil(t, nbCfg.Metrics) + assert.True(t, nbCfg.Metrics.Enabled, "metrics flag should carry the settings value") + }) +} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 6c243c4c7..d471a1302 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1048,11 +1048,7 @@ func testUpdateAccountPeers(t *testing.T) { for _, channel := range peerChannels { update := <-channel - assert.NotNil(t, update.Update.NetbirdConfig) - assert.Nil(t, update.Update.NetbirdConfig.Stuns) - assert.Nil(t, update.Update.NetbirdConfig.Turns) - assert.Nil(t, update.Update.NetbirdConfig.Signal) - assert.Nil(t, update.Update.NetbirdConfig.Relay) + assert.Nil(t, update.Update.NetbirdConfig, "fan-out updates must not carry a NetbirdConfig; clients treat a config without relay as relay disabled and wipe their relay URLs") assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers)) assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules)) } From 1dfa85a917eb47e23e4dc4c142056e67966a1c03 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Thu, 2 Jul 2026 15:36:51 +0200 Subject: [PATCH 16/19] [management] Add vLLM e2e test (#6649) * Add vLLM to Agent Network * Add vllm e2e test --- e2e/agentnetwork/vllm_test.go | 171 ++++++++++++++++++++++++++++++++++ e2e/harness/vllm.go | 113 ++++++++++++++++++++++ 2 files changed, 284 insertions(+) create mode 100644 e2e/agentnetwork/vllm_test.go create mode 100644 e2e/harness/vllm.go diff --git a/e2e/agentnetwork/vllm_test.go b/e2e/agentnetwork/vllm_test.go new file mode 100644 index 000000000..329994ca9 --- /dev/null +++ b/e2e/agentnetwork/vllm_test.go @@ -0,0 +1,171 @@ +//go:build e2e + +package agentnetwork + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/e2e/harness" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// TestVLLMProvider proves the proxy supports a self-hosted vLLM backend. vLLM is +// OpenAI-compatible, so it uses the "vllm" catalog entry (KindCustom) and is +// reached over plain HTTP — no TLS anywhere on the path: +// +// client --tunnel--> netbird proxy --http--> vllm (:8000, OpenAI-compatible) +// +// The mock vLLM server answers /v1/chat/completions with an OpenAI-shaped +// completion carrying a non-zero usage block. The test asserts the chat returns +// 200 with the completion, that the request is recorded in the access log by its +// session id, and that vLLM's usage block is metered into a consumption row — +// which together prove request routing, response parsing, and token accounting +// all work for a self-hosted OpenAI-compatible provider. +// +// It needs no external credentials (the mock ignores auth), so it always runs. +func TestVLLMProvider(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + vllm, err := harness.StartVLLM(ctx, srv) + require.NoError(t, err, "start mock vLLM server") + t.Cleanup(func() { _ = vllm.Terminate(context.Background()) }) + + grp, err := srv.API().Groups.Create(ctx, api.PostApiGroupsJSONRequestBody{Name: "e2e-vllm"}) + require.NoError(t, err, "create group") + t.Cleanup(func() { _ = srv.API().Groups.Delete(context.Background(), grp.Id) }) + + ephemeral := false + sk, err := srv.API().SetupKeys.Create(ctx, api.PostApiSetupKeysJSONRequestBody{ + Name: "e2e-vllm-client", + Type: "reusable", + ExpiresIn: 86400, + UsageLimit: 0, + AutoGroups: []string{grp.Id}, + Ephemeral: &ephemeral, + }) + require.NoError(t, err, "mint setup key") + require.NotEmpty(t, sk.Key, "setup key plaintext") + + // vLLM provider pointed at the mock over plain HTTP. The mock ignores auth, + // so a dummy key satisfies the "Bearer ${API_KEY}" template. The served model + // is enumerated so the router dispatches this model string to this provider. + dummyKey := "sk-vllm-e2e" + prov, err := srv.CreateProvider(ctx, api.AgentNetworkProviderRequest{ + Name: "vllm", + ProviderId: "vllm", + UpstreamUrl: vllm.URL, + ApiKey: &dummyKey, + Enabled: ptr(true), + BootstrapCluster: ptr(harness.AgentNetworkCluster), + Models: &[]api.AgentNetworkProviderModel{ + {Id: harness.VLLMModel, InputPer1k: 0.001, OutputPer1k: 0.002}, + }, + }) + require.NoError(t, err, "create vllm provider") + t.Cleanup(func() { _ = srv.DeleteProvider(context.Background(), prov.Id) }) + + // Token limit far above the handful of tokens this test drives, so it never + // blocks but switches on usage metering — the switch that makes consumption + // rows get recorded. + enabled := true + pol, err := srv.CreatePolicy(ctx, api.AgentNetworkPolicyRequest{ + Name: "e2e-vllm-allow", + Enabled: &enabled, + SourceGroups: []string{grp.Id}, + DestinationProviderIds: []string{prov.Id}, + Limits: &api.AgentNetworkPolicyLimits{ + TokenLimit: api.AgentNetworkPolicyTokenLimit{ + Enabled: true, + GroupCap: 10_000_000, + UserCap: 10_000_000, + WindowSeconds: 60, + }, + }, + }) + require.NoError(t, err, "create policy") + t.Cleanup(func() { _ = srv.DeletePolicy(context.Background(), pol.Id) }) + + settings, err := srv.GetSettings(ctx) + require.NoError(t, err, "read settings") + require.NotEmpty(t, settings.Endpoint, "endpoint must be assigned") + + proxyToken, err := srv.CreateProxyTokenCLI(ctx, "e2e-vllm-proxy") + require.NoError(t, err, "mint proxy token") + px, err := harness.StartProxy(ctx, srv, proxyToken) + require.NoError(t, err, "start proxy") + t.Cleanup(func() { _ = px.Terminate(context.Background()) }) + + cl, err := harness.StartClient(ctx, srv, sk.Key) + require.NoError(t, err, "start client") + t.Cleanup(func() { _ = cl.Terminate(context.Background()) }) + + require.NoError(t, cl.WaitConnected(ctx, 90*time.Second), "client must connect to management") + if err := cl.WaitProxyPeer(ctx, 180*time.Second); err != nil { + t.Fatalf("client did not see the proxy peer: %v\n=== proxy logs ===\n%s", err, px.Logs(context.Background())) + } + proxyIP, err := cl.ResolveProxyIP(ctx, settings.Endpoint) + require.NoError(t, err, "resolve endpoint to proxy IP") + + before, _ := srv.ListAccessLogs(ctx) + sessionID := "e2e-session-vllm" + + // Retry to absorb tunnel/DNS jitter on the first call. + var code int + var body string + deadline := time.Now().Add(90 * time.Second) + for time.Now().Before(deadline) { + c, b, cerr := cl.Chat(ctx, settings.Endpoint, proxyIP, harness.WireChat, harness.VLLMModel, "Reply with exactly: pong", sessionID) + if cerr == nil { + code, body = c, b + if code == 200 { + break + } + } + time.Sleep(5 * time.Second) + } + require.Equal(t, 200, code, + "chat through the vLLM provider must return 200; body: %s\n=== vllm logs ===\n%s\n=== proxy logs ===\n%s", + body, vllm.Logs(context.Background()), px.Logs(context.Background())) + require.True(t, strings.Contains(body, "chat.completion"), + "body should be an OpenAI-compatible chat completion; got: %s", body) + + // The request must surface as an access-log row carrying our session id. + require.Eventually(t, func() bool { + logs, lerr := srv.ListAccessLogs(ctx) + return lerr == nil && logs.TotalRecords > before.TotalRecords + }, 30*time.Second, 2*time.Second, "an access-log row should be ingested for the vLLM provider") + + require.Eventually(t, func() bool { + logs, lerr := srv.ListAccessLogs(ctx) + if lerr != nil { + return false + } + for _, r := range logs.Data { + if r.SessionId != nil && *r.SessionId == sessionID { + return true + } + } + return false + }, 30*time.Second, 2*time.Second, "session id %q must be recorded in an access-log row", sessionID) + + // vLLM's usage block (prompt_tokens=11, completion_tokens=2) must be parsed + // and metered into a consumption row with positive token counts. + require.Eventually(t, func() bool { + rows, lerr := srv.ListConsumption(ctx) + if lerr != nil { + return false + } + for _, r := range rows { + if r.TokensInput > 0 && r.TokensOutput > 0 { + return true + } + } + return false + }, 60*time.Second, 3*time.Second, "vLLM usage must be metered into a consumption row") +} diff --git a/e2e/harness/vllm.go b/e2e/harness/vllm.go new file mode 100644 index 000000000..2f3d306cc --- /dev/null +++ b/e2e/harness/vllm.go @@ -0,0 +1,113 @@ +//go:build e2e + +package harness + +import ( + "context" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/docker/docker/api/types/container" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + vllmImage = "nginx:alpine" + vllmAlias = "vllm" + vllmPort = "8000/tcp" + // VLLMModel is the served model id the mock advertises and echoes back. It + // matches a real small model commonly served by vLLM so the provider's + // enumerated model and the client's request line up. + VLLMModel = "Qwen/Qwen2.5-0.5B-Instruct" +) + +// vllmNginxConf emulates a vLLM OpenAI-compatible server over plain HTTP (vLLM's +// default: no TLS, port 8000). It answers /v1/models with a one-model list and +// any chat/completions path with a canned OpenAI-shaped chat completion carrying +// a non-zero usage block, so the proxy's OpenAI parser records real token +// consumption. Running actual vLLM in CI is infeasible (GPU + multi-GB model +// download), so this stands in for the wire contract the proxy depends on. +const vllmNginxConf = `pid /tmp/nginx.pid; +events {} +http { + server { + listen 8000; + location = /v1/models { + default_type application/json; + return 200 '{"object":"list","data":[{"id":"Qwen/Qwen2.5-0.5B-Instruct","object":"model","owned_by":"vllm"}]}'; + } + location / { + default_type application/json; + return 200 '{"id":"chatcmpl-e2e-vllm","object":"chat.completion","created":1700000000,"model":"Qwen/Qwen2.5-0.5B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"pong"},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"completion_tokens":2,"total_tokens":13}}'; + } + } +} +` + +// VLLM is a mock vLLM OpenAI-compatible server on the combined server's network, +// reachable at http://vllm:8000. A "vllm" provider points at it to exercise the +// proxy's support for self-hosted OpenAI-compatible backends. +type VLLM struct { + container testcontainers.Container + workDir string + // URL is the upstream URL the vllm provider points at (http://:8000). + URL string +} + +// StartVLLM runs the mock vLLM server on the shared network over plain HTTP. +func StartVLLM(ctx context.Context, c *Combined) (*VLLM, error) { + workDir, err := os.MkdirTemp("/tmp", "nb-e2e-vllm-*") + if err != nil { + return nil, fmt.Errorf("create vllm work dir: %w", err) + } + // Widen so the (non-root worker) nginx container can traverse the bind mount. + if err := os.Chmod(workDir, 0o755); err != nil { //nolint:gosec // throwaway e2e config dir + return nil, fmt.Errorf("chmod vllm dir: %w", err) + } + if err := os.WriteFile(filepath.Join(workDir, "nginx.conf"), []byte(vllmNginxConf), 0o644); err != nil { //nolint:gosec // non-secret e2e config + return nil, fmt.Errorf("write nginx conf: %w", err) + } + + req := testcontainers.ContainerRequest{ + Image: vllmImage, + ExposedPorts: []string{vllmPort}, + Networks: []string{c.network.Name}, + NetworkAliases: map[string][]string{c.network.Name: {vllmAlias}}, + Cmd: []string{"nginx", "-c", "/conf/nginx.conf", "-g", "daemon off;"}, + HostConfigModifier: func(hc *container.HostConfig) { + hc.Binds = append(hc.Binds, workDir+":/conf:ro") + }, + WaitingFor: wait.ForListeningPort(vllmPort).WithStartupTimeout(60 * time.Second), + } + + ctr, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + if err != nil { + _ = os.RemoveAll(workDir) + return nil, fmt.Errorf("start vllm container: %w", err) + } + + return &VLLM{container: ctr, workDir: workDir, URL: "http://" + vllmAlias + ":8000"}, nil +} + +// Logs returns the vLLM container logs, for diagnostics on failure. +func (v *VLLM) Logs(ctx context.Context) string { + return containerLogs(ctx, v.container) +} + +// Terminate stops the vLLM container and cleans its work dir. +func (v *VLLM) Terminate(ctx context.Context) error { + var err error + if v.container != nil { + err = v.container.Terminate(ctx) + } + if v.workDir != "" { + _ = os.RemoveAll(v.workDir) + } + return err +} From 21aa9335846114319229ecf28e7faea13bf8fae5 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 2 Jul 2026 17:21:06 +0200 Subject: [PATCH 17/19] [misc] Fix GHCR image push after dockers_v2 migration (#6653) --- .github/workflows/release.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 16eae31fb..cd431a389 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -293,8 +293,11 @@ jobs: ${{ steps.goreleaser.outputs.artifacts }} JSON + # dockers_v2 artifacts have no top-level goarch field, so match the + # per-platform -amd64 tag suffix instead; it works for both the old + # dockers and the new dockers_v2 image naming. mapfile -t src_images < <( - jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json + jq -r '.[] | select(.type == "Docker Image") | .name | select(startswith("ghcr.io/") and endswith("-amd64"))' /tmp/goreleaser-artifacts.json ) for src in "${src_images[@]}"; do From 8e3b284f4bb374c830bd85041260cf3dcde5666e Mon Sep 17 00:00:00 2001 From: Riccardo Manfrin <3090891+riccardomanfrin@users.noreply.github.com> Date: Thu, 2 Jul 2026 17:50:18 +0200 Subject: [PATCH 18/19] [client] Increase mgmt grpc buff size to 16MB (#6641) --- shared/management/client/grpc.go | 15 ++++++++++----- shared/management/client/grpc_test.go | 8 ++++---- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index bd4585455..e3ba259c6 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -33,10 +33,15 @@ const ConnectTimeout = 10 * time.Second const healthCheckTimeout = 5 * time.Second const ( - // EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB) + // EnvMaxRecvMsgSize overrides the default gRPC max receive message size // for the management client connection. Value is in bytes. EnvMaxRecvMsgSize = "NB_MANAGEMENT_GRPC_MAX_MSG_SIZE" + // defaultMaxRecvMsgSize is the max gRPC receive message size used for the + // management client connection when EnvMaxRecvMsgSize is unset or invalid. + // It overrides the gRPC library default of 4 MB. + defaultMaxRecvMsgSize = 1024 * 1024 * 16 + errMsgMgmtPublicKey = "failed getting Management Service public key: %s" errMsgNoMgmtConnection = "no connection to management" ) @@ -84,22 +89,22 @@ type ExposeResponse struct { } // MaxRecvMsgSize returns the configured max gRPC receive message size from -// the environment, or 0 if unset (which uses the gRPC default of 4 MB). +// the environment, or defaultMaxRecvMsgSize (16 MB) if unset or invalid. func MaxRecvMsgSize() int { val := os.Getenv(EnvMaxRecvMsgSize) if val == "" { - return 0 + return defaultMaxRecvMsgSize } size, err := strconv.Atoi(val) if err != nil { log.Warnf("invalid %s value %q, using default: %v", EnvMaxRecvMsgSize, val, err) - return 0 + return defaultMaxRecvMsgSize } if size <= 0 { log.Warnf("invalid %s value %d, must be positive, using default", EnvMaxRecvMsgSize, size) - return 0 + return defaultMaxRecvMsgSize } return size diff --git a/shared/management/client/grpc_test.go b/shared/management/client/grpc_test.go index 462cc43af..c947130fd 100644 --- a/shared/management/client/grpc_test.go +++ b/shared/management/client/grpc_test.go @@ -21,11 +21,11 @@ func TestMaxRecvMsgSize(t *testing.T) { envValue string expected int }{ - {name: "unset returns 0", envValue: "", expected: 0}, + {name: "unset returns default", envValue: "", expected: defaultMaxRecvMsgSize}, {name: "valid value", envValue: "10485760", expected: 10485760}, - {name: "non-numeric returns 0", envValue: "abc", expected: 0}, - {name: "negative returns 0", envValue: "-1", expected: 0}, - {name: "zero returns 0", envValue: "0", expected: 0}, + {name: "non-numeric returns default", envValue: "abc", expected: defaultMaxRecvMsgSize}, + {name: "negative returns default", envValue: "-1", expected: defaultMaxRecvMsgSize}, + {name: "zero returns default", envValue: "0", expected: defaultMaxRecvMsgSize}, } for _, tt := range tests { From 4b3dd9103dcf4009fbb6b088fbc44fc0db760128 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 2 Jul 2026 20:42:43 +0200 Subject: [PATCH 19/19] [client] Fix slow wg operations (#6633) * [iface] Drop redundant device dump in kernel configure() wgctrl.ConfigureDevice already returns an error when the interface is missing, so the preceding wg.Device() existence check is redundant. That check dumps the entire device (all peers) on every configure() call, making it O(peers) per call and turning bulk peer insertion into O(peers^2): inserting N peers one by one re-parsed the whole growing peer list N times. Removing it keeps each peer write constant-time regardless of how many peers are already configured. * [iface] Cache WireGuard stats to collapse per-peer device dumps Each peer runs a WGWatcher that polls GetStats(), and every call dumps the whole device, so with N peers the watchers perform O(N) full dumps per poll cycle (O(N^2) work) while each keeps only its own peer's entry. Wrap the kernel and userspace configurer GetStats() in a short-TTL cache with singleflight: the staggered per-peer calls share a single device dump per window and concurrent misses collapse into one dump. The kernel and userspace WireGuard APIs have no per-peer stats query (a get always returns the whole device), so a shared cached snapshot avoids the repeated full dumps. * Ignore .claude directory --- .gitignore | 1 + client/iface/configurer/kernel_unix.go | 23 +++---- client/iface/configurer/stats_cache.go | 52 +++++++++++++++ client/iface/configurer/stats_cache_test.go | 70 +++++++++++++++++++++ client/iface/configurer/usp.go | 10 ++- 5 files changed, 144 insertions(+), 12 deletions(-) create mode 100644 client/iface/configurer/stats_cache.go create mode 100644 client/iface/configurer/stats_cache_test.go diff --git a/.gitignore b/.gitignore index 783fe77f3..305f3cb50 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.claude .idea .run *.iml diff --git a/client/iface/configurer/kernel_unix.go b/client/iface/configurer/kernel_unix.go index a29fe181a..da69c2a35 100644 --- a/client/iface/configurer/kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -17,12 +17,15 @@ import ( type KernelConfigurer struct { deviceName string + statsCache *statsCache } func NewKernelConfigurer(deviceName string) *KernelConfigurer { - return &KernelConfigurer{ + c := &KernelConfigurer{ deviceName: deviceName, } + c.statsCache = newStatsCache(statsCacheTTL, c.fetchStats) + return c } func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error { @@ -246,12 +249,6 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error { } }() - // validate if device with name exists - _, err = wg.Device(c.deviceName) - if err != nil { - return err - } - return wg.ConfigureDevice(c.deviceName, config) } @@ -300,6 +297,14 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) { } func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) { + return c.statsCache.get() +} + +func (c *KernelConfigurer) LastActivities() map[string]monotime.Time { + return nil +} + +func (c *KernelConfigurer) fetchStats() (map[string]WGStats, error) { stats := make(map[string]WGStats) wg, err := wgctrl.New() if err != nil { @@ -326,7 +331,3 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) { } return stats, nil } - -func (c *KernelConfigurer) LastActivities() map[string]monotime.Time { - return nil -} diff --git a/client/iface/configurer/stats_cache.go b/client/iface/configurer/stats_cache.go new file mode 100644 index 000000000..71a4e88fc --- /dev/null +++ b/client/iface/configurer/stats_cache.go @@ -0,0 +1,52 @@ +package configurer + +import ( + "sync" + "time" + + "golang.org/x/sync/singleflight" +) + +const statsCacheTTL = 1 * time.Second + +type statsCache struct { + ttl time.Duration + fetch func() (map[string]WGStats, error) + + mu sync.RWMutex + value map[string]WGStats + expireAt time.Time + + sf singleflight.Group +} + +func newStatsCache(ttl time.Duration, fetch func() (map[string]WGStats, error)) *statsCache { + return &statsCache{ttl: ttl, fetch: fetch} +} + +func (c *statsCache) get() (map[string]WGStats, error) { + c.mu.RLock() + if c.value != nil && time.Now().Before(c.expireAt) { + value := c.value + c.mu.RUnlock() + return value, nil + } + c.mu.RUnlock() + + value, err, _ := c.sf.Do("stats", func() (interface{}, error) { + res, err := c.fetch() + if err != nil { + return nil, err + } + + c.mu.Lock() + c.value = res + c.expireAt = time.Now().Add(c.ttl) + c.mu.Unlock() + return res, nil + }) + if err != nil { + return nil, err + } + return value.(map[string]WGStats), nil +} diff --git a/client/iface/configurer/stats_cache_test.go b/client/iface/configurer/stats_cache_test.go new file mode 100644 index 000000000..bcee5cd52 --- /dev/null +++ b/client/iface/configurer/stats_cache_test.go @@ -0,0 +1,70 @@ +package configurer + +import ( + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestStatsCache_CachesWithinTTL(t *testing.T) { + var calls atomic.Int64 + c := newStatsCache(50*time.Millisecond, func() (map[string]WGStats, error) { + calls.Add(1) + return map[string]WGStats{"p": {}}, nil + }) + + for i := 0; i < 10; i++ { + _, err := c.get() + require.NoError(t, err) + } + require.Equal(t, int64(1), calls.Load(), "within TTL only one underlying fetch") + + time.Sleep(60 * time.Millisecond) + _, err := c.get() + require.NoError(t, err) + require.Equal(t, int64(2), calls.Load(), "after TTL expiry a fresh fetch happens") +} + +func TestStatsCache_SingleFlight(t *testing.T) { + var calls atomic.Int64 + release := make(chan struct{}) + c := newStatsCache(time.Minute, func() (map[string]WGStats, error) { + calls.Add(1) + <-release + return map[string]WGStats{}, nil + }) + + const n = 50 + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + _, _ = c.get() + }() + } + time.Sleep(20 * time.Millisecond) + close(release) + wg.Wait() + + require.Equal(t, int64(1), calls.Load(), "concurrent misses collapse into one fetch") +} + +func TestStatsCache_ErrorNotCached(t *testing.T) { + var calls atomic.Int64 + wantErr := errors.New("dump failed") + c := newStatsCache(time.Minute, func() (map[string]WGStats, error) { + calls.Add(1) + return nil, wantErr + }) + + _, err := c.get() + require.ErrorIs(t, err, wantErr) + _, err = c.get() + require.ErrorIs(t, err, wantErr) + require.Equal(t, int64(2), calls.Load(), "errors are not cached; each call retries") +} diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 9b070aab8..0a25c55bc 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -40,6 +40,7 @@ type WGUSPConfigurer struct { device *device.Device deviceName string activityRecorder *bind.ActivityRecorder + statsCache *statsCache uapiListener net.Listener } @@ -50,16 +51,19 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder deviceName: deviceName, activityRecorder: activityRecorder, } + wgCfg.statsCache = newStatsCache(statsCacheTTL, wgCfg.fetchStats) wgCfg.startUAPI() return wgCfg } func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer { - return &WGUSPConfigurer{ + wgCfg := &WGUSPConfigurer{ device: device, deviceName: deviceName, activityRecorder: activityRecorder, } + wgCfg.statsCache = newStatsCache(statsCacheTTL, wgCfg.fetchStats) + return wgCfg } func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error { @@ -348,6 +352,10 @@ func (t *WGUSPConfigurer) Close() { } func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) { + return t.statsCache.get() +} + +func (t *WGUSPConfigurer) fetchStats() (map[string]WGStats, error) { ipc, err := t.device.IpcGet() if err != nil { return nil, fmt.Errorf("ipc get: %w", err)