Compare commits

..

14 Commits

Author SHA1 Message Date
dependabot[bot]
e847ac95aa Bump the actions group across 1 directory with 6 updates
Bumps the actions group with 6 updates in the / directory:

| Package | From | To |
| --- | --- | --- |
| [actions/setup-go](https://github.com/actions/setup-go) | `6.4.0` | `6.5.0` |
| [actions/cache](https://github.com/actions/cache) | `5.0.5` | `6.1.0` |
| [vmactions/freebsd-vm](https://github.com/vmactions/freebsd-vm) | `1.4.8` | `1.5.0` |
| [actions/cache/restore](https://github.com/actions/cache) | `6.0.0` | `6.1.0` |
| [golangci/golangci-lint-action](https://github.com/golangci/golangci-lint-action) | `9.2.1` | `9.3.0` |
| [goreleaser/goreleaser-action](https://github.com/goreleaser/goreleaser-action) | `7.2.2` | `7.2.3` |



Updates `actions/setup-go` from 6.4.0 to 6.5.0
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v6.4.0...924ae3a1cded613372ab5595356fb5720e22ba16)

Updates `actions/cache` from 5.0.5 to 6.1.0
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](https://github.com/actions/cache/compare/v5.0.5...55cc8345863c7cc4c66a329aec7e433d2d1c52a9)

Updates `vmactions/freebsd-vm` from 1.4.8 to 1.5.0
- [Release notes](https://github.com/vmactions/freebsd-vm/releases)
- [Commits](b84ab5559b...5a72679103)

Updates `actions/cache/restore` from 6.0.0 to 6.1.0
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](2c8a9bd745...55cc834586)

Updates `golangci/golangci-lint-action` from 9.2.1 to 9.3.0
- [Release notes](https://github.com/golangci/golangci-lint-action/releases)
- [Commits](82606bf257...ba0d7d2ec0)

Updates `goreleaser/goreleaser-action` from 7.2.2 to 7.2.3
- [Release notes](https://github.com/goreleaser/goreleaser-action/releases)
- [Commits](5daf1e915a...f06c13b6b1)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-version: 6.5.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
- dependency-name: actions/cache
  dependency-version: 6.1.0
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: actions
- dependency-name: vmactions/freebsd-vm
  dependency-version: 1.5.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
- dependency-name: actions/cache/restore
  dependency-version: 6.1.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
- dependency-name: golangci/golangci-lint-action
  dependency-version: 9.3.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
- dependency-name: goreleaser/goreleaser-action
  dependency-version: 7.2.3
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: actions
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-07-01 10:49:49 +00:00
Maycon Santos
92a66cdd20 [management,proxy,client] 0.74.0 version (#6563)
* [management,proxy] Agent network: per-account LLM gateway (policy, metering, multi-provider) (#6555)

* [agent-network] Shared proto, OpenAPI schema, and generated types

* [agent-network] Management: store, manager, synthesizer, policy engine, provider catalog, HTTP/gRPC API

Adds the account-scoped agent-network module: provider/policy/budget CRUD and
store, the reverse-proxy service synthesizer, policy selection + limit
enforcement, the provider catalog (incl. Vertex AI and AWS Bedrock entries),
and the management HTTP + proxy gRPC surfaces.

* [management] Fix agent-network proxy-peer fan-out on affected-peer recompute

The affected-peers resolver loaded only persisted reverse-proxy services, but
agent-network services are synthesized on demand and never persisted. As a
result the embedded proxy peer was never folded into the affected set when a
client's group changed, so the proxy received no network-map update for a newly
authorised client and rejected its handshake until a full resync (restart).

loadProxyServices now merges the synthesized agent-network services (injected
via a registration hook to avoid an import cycle), so proxy peers learn newly
authorised clients immediately.

* [proxy] Reverse-proxy middleware framework, chain, and request plumbing

The per-target middleware chain (slots, dispatcher, mutation gate, metadata
merger), body capture, access-log terminal sink, and the proxy wiring that
builds + runs chains for synthesized agent-network services.

* [proxy] LLM parsers, pricing, and builtin middlewares (OpenAI, Anthropic, Vertex AI, AWS Bedrock)

Request/response parsers and SSE/event-stream metering, the embedded pricing
table, and the builtin middleware set: request parser, router, policy
limit-check/record, cost meter, guardrail, identity inject, response parser.
Includes the path-routed providers — Google Vertex AI (keyfile:: service-account
OAuth minting) and AWS Bedrock (bearer auth, invoke/converse/streaming, optional
/bedrock prefix) — plus the Models allowlist and unmeterable-publisher deny.

* [proxy] IPv6 in-place apply and TCP accept-loop hardening on netstack listeners

* [agent-network] End-to-end test suite, module docs, and deployment preset

* [agent-network] Fix codespell typos and exclude false positives

- labelgen word pool: vermillion -> vermilion, racoon -> raccoon.
- codespell ignore list: add flate (Go compress/flate package), recordin
  (a test-local identifier), and unparseable (a valid alternative spelling used
  consistently across identifiers + a metadata-value constant).

* [management] Set LastSeen on injected proxy peer in realstack test (MySQL strict-mode)

The injected embedded proxy peer had a PeerStatus with a zero LastSeen, which
serializes to '0000-00-00' and is rejected by MySQL in strict mode (SQLite
tolerates it). Set LastSeen to a valid time so SaveAccount succeeds on both
engines.

* [agent-network] Remove e2e shell-script suite from this branch

The end-to-end shell scripts under scripts/e2e/ are maintained in a separate
testing suite and are not part of this change set.

* [agent-network] Polish module docs: remove internal review scaffolding, fix links, verify diagrams

Strip PR-review framing, commit references, absolute paths, and stale internal
references from the agent-network module docs; fix broken relative links; verify
all diagrams against the current architecture. Remove the internal AI-reviewer
prompt file.

* [management] Refine session expiration handling to support 3-state encoding for SSO deadlines

* [agent-network] Relocate agentnetwork package to internals/modules

Move management/server/agentnetwork (and its catalog/, labelgen/, types/
subpackages) to management/internals/modules/agentnetwork, alongside the
reverse-proxy module, and rewrite all importers. Pure relocation: package names,
the synthesizer + affectedpeers registration hook, and store access (shared
store.Store) are unchanged, so no import cycle is introduced (affectedpeers
still depends only on the agentnetwork/types leaf).

* [agent-network] Co-locate HTTP handlers in the module (RegisterEndpoints)

Move the agent-network HTTP handlers from server/http/handlers/agentnetwork into
the module at internals/modules/agentnetwork/handlers (package handlers) and
rename the entrypoint AddEndpoints -> RegisterEndpoints, matching the
reverse-proxy module convention. Wiring in http/handler.go updated accordingly.

* Update getting started to point to rc when agent network enabled

* Add a reference to a commercial license

* Fix docs localhost link

* Fix docs localhost link

* Add private services domain note

* [management] Add agent-network telemetry metrics (#6561)

Surface agent-network adoption and usage in the self-hosted metrics
worker: distinct accounts, providers, policies, budget rules, accounts
with log collection enabled, and aggregated input/output tokens plus
cost.

Tokens and cost are summed from agent_network_request_usage (the
always-written per-request ledger) so the figures are accurate
regardless of the log-collection toggle and carry no double-counting.
All values come from a handful of indexed aggregate queries run only on
the worker's periodic tick.

Adds store.AgentNetworkMetrics with GetAgentNetworkMetrics on the Store
interface, the SqlStore implementation, and a zero-valued FileStore stub.

* Update NetBird server and proxy image versions to 0.74.0-rc.2

* [management,proxy] Reduce agent-network cognitive complexity (#6566)

Address the SonarCloud quality-gate findings in new agent-network code
by extracting focused helpers. No behavior change.

- synthesizer.go: split buildIdentityInjectConfigJSON into per-shape
  rule builders; extract mergeGuardrail from mergeGuardrails to cut
  nesting depth.
- llm_identity_inject: extract injectionEmitsAnything validation
  predicate from New.
- llm_response_parser/streaming.go: extract applyOpenAIStreamUsage and
  applyAnthropicStreamUsage (via a named anthropicStreamUsage type) and
  simplify the OpenAI scanner loop.
- reverseproxy.go: decompose ServeHTTP into serveRouteError,
  buildTargetContext, serveDirect, serveWithChain, captureRequestForChain,
  serveDeny, newResponseWriter, observeResponse, and forwardUpstream,
  preserving the defer ordering so response observation still reads the
  captured writer before it is released.

* [management] Move agent-network access-log ingest into the agentnetwork module (#6568)

The agent-network access-log ingest path (metaKey wire contract, flatten,
usage derivation, and the dual-write of the usage ledger + settings-gated
full row) lived in the reverseproxy accesslogs manager, even though the
agentnetwork module already owns the rest of that domain — types, read
(ListAccessLogs / GetUsageOverview), the budget-counter writes, and
retention cleanup.

Move it next to the rest: a stateless agentnetwork.IngestAccessLog(ctx,
store, entry) that the reverseproxy SaveAccessLog delegates to when the
entry is agent-network. Removes the agentNetworkTypes import from the
reverseproxy manager. No behavior change; the write/read table separation
is unchanged.

Adds real-store coverage for the disable->enable log-collection toggle
(usage ledger always written, full row gated) plus the metadata parse and
group-dedup helpers, which previously had no dedicated tests.

* Add session view support in the access log

* [management,proxy] Container-based agent-network e2e harness (#6577)

* [e2e] Add container-based agent-network e2e harness (Pillar 1)

Introduce a self-contained, OIDC-free e2e harness that stands up NetBird
in containers, so suites no longer depend on the hand-maintained Tilt
stack or a real IdP.

- harness brings up the combined server (management + signal + relay +
  STUN + embedded IdP) in a single container built from
  combined/Dockerfile.multistage, and mints an admin PAT through the
  unauthenticated /api/setup bootstrap (NB_SETUP_PAT_ENABLED). API access
  goes through the existing shared/management/client/rest typed client.
- the image is built via the docker CLI (BuildKit) so the Dockerfile's
  cache mounts are honored; testcontainers then runs the tagged image.
- everything is behind the `e2e` build tag so normal builds and unit
  tests never pull in testcontainers.

Adds BuildKit cache mounts to combined/Dockerfile.multistage so source
changes recompile incrementally rather than from scratch.

Pillar 1 proven by TestCombinedBootstrap: server builds, boots, mints a
PAT, and the PAT authenticates a real management API call.

* [e2e] Add management-side agent-network scenarios (Pillar 2)

Port the API-driven agent-network scenarios from the bash suites to Go,
sharing one combined server per package run (TestMain) with each test
owning its resource cleanup. Drives the /api/agent-network/* endpoints
through the shared REST client's NewRequest primitive with the generated
api types.

Scenarios:
- provider lifecycle (create/get/list/delete + 404 after delete)
- provider validation (missing api_key, unknown catalog id → 4xx)
- settings collection-toggle round-trip with cluster/subdomain immutability
- policy window floor (reject <60s enabled limit, accept at 60s)
- consumption read endpoint returns an array

All deterministic and dependency-free (dummy provider keys; no upstream
calls), so they run headless in CI.

* [e2e] Add live chat-through-proxy scenario (Pillar 3)

Stand up the full agent-network data path in containers and drive a real
chat-completion through the gateway:

- harness: a shared docker network (combined server reachable by alias),
  a proxy container built from the published reverse-proxy image
  (NB_PROXY_PRIVATE, NB_PROXY_ALLOW_INSECURE, NB_RELAY_TRANSPORT=ws to match
  the combined server's WS-multiplexed relay) with a generated self-signed
  wildcard cert, and a netbird client container that joins via a setup key.
- the combined image, proxy image, and client image default to the
  published rc.2 releases (overridable via NB_E2E_*_IMAGE; a bare local tag
  is built from source instead). Geolocation download is disabled so the
  server starts without external fetches.
- one shared domain is used for the management exposed address, the proxy
  domain, and the agent-network cluster; the proxy token is minted via the
  server CLI (global) to match the manual install.

TestChatCompletionThroughProxy provisions provider+policy+group+setup key,
runs proxy+client, drives an OpenAI chat-completion through the tunnel, and
asserts a 200 plus the ingested access-log row. Requires OPENAI_TOKEN
(skips otherwise). The provider must be created with enabled=true explicitly
— the create default is false despite the API doc.

* [e2e] Run the live chat scenario across a provider matrix

Replace the single-provider chat test with a data-driven matrix that runs
the same scenario through every provider whose credentials are present in
the environment (keys/URLs sourced from ~/.llm-keys locally, Actions
secrets in CI):

- OpenAI (chat), Anthropic (messages), Vercel, OpenRouter, Cloudflare
  (OpenAI-compatible gateways), and Bedrock (path-routed, bearer, via the
  messages shape) — covering both wire shapes and the gateway routing.
- all providers are created enabled with a unique model string so the
  proxy's connect-time snapshot carries them all and model->provider
  routing is unambiguous (provider toggles after connect don't reconcile
  to a connected proxy).
- the client supports both wire shapes (/v1/chat/completions and
  /v1/messages); Cloudflare gets the openai provider segment appended to
  its gateway URL.

Each provider must return 200 through the tunnel and produce an ingested
access-log row. Vertex is intentionally excluded from the uniform matrix:
it needs a bespoke rawPredict request shape rather than the shared
chat/messages path, so it warrants a dedicated scenario.

* [ci] Add manual workflow for the agent-network e2e suite

The e2e suite (build tag `e2e`) stands up the combined server + proxy +
client in Docker and drives live chat-completions, so it is slow and needs
provider credentials. Gate it out of normal CI (it already is, via the
build tag) and run it on demand via workflow_dispatch. Provider scenarios
skip when their secret is unset, so it degrades gracefully.

* [e2e] Add Vertex to the provider matrix; run e2e on ubuntu-latest

Vertex (Anthropic-on-Vertex) doesn't share the chat/messages wire shapes:
the model travels in a rawPredict path and the proxy mints the service
account's OAuth token. Add a Vertex client method that posts
/v1/projects/<project>/locations/<region>/publishers/anthropic/models/<model>:rawPredict
with the Vertex anthropic_version body, and wire it into the matrix as a
path-routed provider (created without a models array). It is keyed off
GOOGLE_VERTEX_SA_BASE64 + GOOGLE_VERTEX_PROJECT (region defaults to
"global", model to a pinned claude snapshot, both overridable).

Also bump the e2e workflow runner to ubuntu-latest and add the Vertex
secrets.

* Add docker/docker and docker/go-connections as direct dependencies in go.mod

* [ci] Trigger agent-network e2e workflow on push to main and pull requests

* [e2e] Fix proxy cert permission denied on Linux CI runners

The proxy bind-mounts a temp dir of self-signed certs. MkdirTemp creates
it 0700 and the key was 0600, which Docker Desktop on macOS ignores but a
non-root proxy container on Linux runners cannot traverse/read, so the
cert watcher failed with "open /certs/tls.crt: permission denied" and the
container exited. Widen the cert dir to 0755 and write the throwaway key
0644 so the proxy uid can read the bind-mounted material.

* [e2e] Build images from source by default instead of pulling rc.2

The agent-network code under test lives in this branch, so the e2e should
exercise it rather than a frozen published release. Flip the harness
default: combined/proxy/client are now built from their in-repo
Dockerfiles (combined/Dockerfile.multistage, proxy/Dockerfile.multistage,
e2e/harness/Dockerfile.client) under local tags. Pulling a published image
stays available by setting NB_E2E_*_IMAGE to a registry reference.

Builds now go through buildx --load so the Dockerfile cache mounts are
honored and the result is loaded for testcontainers. The CI workflow adds
a container-driver builder and a local layer cache (NB_E2E_BUILDX_CACHE)
persisted via actions/cache, which caches the base/apt/dep-download layers
across runs. The Go compile still re-runs each time, as BuildKit mount
caches cannot be exported to the GitHub cache.

* [e2e] Cover real providers in lifecycle + assert real consumption metering

- TestProviderLifecycle now runs per available real provider (create → get →
  list → delete → 404) instead of a single dummy provider, exercising each
  catalog's create and field round-trip. Create is offline, so it stays fast
  and burns no provider quota; falls back to a synthetic OpenAI provider when
  no keys are set.
- TestProvidersMatrix attaches a token limit (high caps, 60s window) to its
  policy, which switches on usage metering, and asserts consumption rows are
  recorded with positive token counts after the live traffic. Consumption is
  account-scoped (keyed by source group / user and window, not per provider),
  so the assertion is aggregate.
- TestProviderValidation gains invalid-upstream and blank-name cases. Create
  validation is uniform across catalogs (no per-provider required-field rules),
  so per-provider rejection cases would be redundant.

* [e2e] Assert session id propagates per provider

Each matrix request now sends a unique session id as the universal
x-session-id header and asserts it round-trips into that provider's
access-log row. This guards the session-grouping contract end to end for
every provider (header extraction runs in llm_request_parser ahead of the
parser-specific body extraction, so it is provider-agnostic).

* [e2e] Drop accidentally committed sync-phases dashboard

netbird-sync-phases.json was swept into the Pillar 1 commit by a broad
git add; it belongs to the unrelated sync-phases metrics work, not this
e2e harness. Remove it from the branch so the PR diff is scoped to the
e2e changes.

* [e2e] Revert accidentally committed sync-phase ingest spec

The netbird_sync_phase measurement spec in metrics ingest was swept into
the Pillar 1 commit; it belongs to the unrelated sync-phases metrics work,
not this e2e harness. Its emission side never landed here, so the spec was
orphaned anyway. Restore ingest/main.go to its origin/main state.

* Fix golint issues

* Fix sonar

* Add access log session test

* Fix access log tests

---------

Co-authored-by: braginini <bangvalo@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
2026-07-01 12:45:14 +02:00
Bethuel Mmbaga
3be90f06b2 [management] Add peer expiration reason to activity meta (#6619) 2026-07-01 12:31:46 +03:00
Viktor Liu
4ef65294e9 [client] Reinject captured first packet on lazy connection activation (#6572) 2026-06-30 11:22:25 +02:00
Bethuel Mmbaga
5b5f11740a [misc] Require on-premise EULA acceptance in enterprise scripts (#6596) 2026-06-30 11:34:23 +03:00
Riccardo Manfrin
3de889d529 [client] bound system info / posture-check gathering with a timeout to prevent sync-loop freeze (#6512)
* Wraps syestem info / posture checks into a goroutine with timeout

e.checks = checks is set before doing the SyncMeta,
so if it fails next time isCheckEquals compares true and bypasses
the update. This is to avoid another repeating the 15 seconds hang.
The checks will be synced on reconnect or posture checks changes
push from mgmt.

* Propagate context to OS calls that can leverage its cancellation / timeout

* Distinguish timeout from cancellation in logs

* Dont log twice

* Block on timeout failure and reapply the exclude_ips

* Refactor for complexity
2026-06-30 08:18:51 +02:00
Zoltan Papp
04c3d19032 [client] Skip firewall ruleset rebuild when config is unchanged (#6508)
* [client] Skip firewall ruleset rebuild when config is unchanged

ApplyFiltering rebuilt every peer and route ACL and flushed the firewall
on every sync, with no guard for an unchanged configuration. Management
re-sends the same network map far more often than it actually changes
(account-wide updates, peer meta churn), so on busy accounts this is the
dominant client-side cost of redundant syncs — especially with a large
route set and a userspace firewall.

Hash the inputs ApplyFiltering consumes (peer rules, route rules, the
empty flag and the dns-route feature flag) and skip the rebuild + flush
when the hash matches the last successfully applied update. Mirrors the
guard the DNS server already uses (previousConfigHash). The hash is only
recorded after apply and flush both succeed, so a failed update is not
skipped on the next (possibly identical) sync and gets a chance to
reconcile the firewall state.

* [client] Include config hash in ACL skip debug log

* [client] Include RoutesFirewallRulesIsEmpty in firewall config hash

* [client] Add benchmarks for firewall config hash computation
2026-06-29 19:51:50 +02:00
Zoltan Papp
3f1fb3b52d [ingest] raise duration validation limit to 24 hours (#6598)
Peer connection timing fields (signaling_to_connection_seconds) can
legitimately exceed 5 minutes during long reconnections; the previous
300 s cap caused valid data points to be rejected.
2026-06-29 19:51:25 +02:00
Viktor Liu
b434cda062 [client] Refresh signal receive liveness when worker handoff drains (#6594) 2026-06-29 12:16:47 +02:00
Zoltan Papp
0b594c639a [client] report management unhealthy while Sync stream is failing (#6575)
* fix(mgm): report management unhealthy while Sync stream is failing

The health probe (IsHealthy) only checked the gRPC transport and a
GetServerKey call. GetServerKey succeeds even when the peer cannot sync
(e.g. the server returns "settings not found"), so the probe kept marking
management Connected while the Sync stream failed in a tight retry loop —
pinning the status to "Connected" forever despite no sync ever succeeding.

Track the last Sync stream error and have IsHealthy consult it, so a
healthy transport is no longer enough to report the connection healthy.

* fix(mgm): record disconnected state when sync stream setup fails

The connectToSyncStream failure path in handleSyncStream returned early
without updating syncStreamErr, so the client could still report healthy
even when stream setup failed. Mirror the receiveUpdatesEvents error path
by calling notifyDisconnected and setSyncStreamDisconnected.
2026-06-29 11:28:58 +02:00
Zoltan Papp
deff8af59f [client] Wait for signal receive watchdog to stop before reconnect (#6574)
* [client] Wait for signal receive watchdog to stop before reconnect

The per-stream watchReceiveStream goroutine was started fire-and-forget
and never joined. On reconnect a lingering watchdog could still flip
shared client state (receiveStalled, the disconnect notifier) on the
freshly established stream, since cancelStream only cancels its own
stream context.

Track the watchdog with a WaitGroup and wait for it to exit (after
cancelling its stream) before the operation returns, so each reconnect
starts with no stale watchdog.

* [client] Bind signal receive probe to the stream context

The watchdog probe reused the generic Send, which derives its per-attempt
timeouts from the long-lived client context, so cancelStream could not
interrupt an in-flight probe. After joining the watchdog on reconnect,
watchdogWg.Wait() could then block for the full send-attempt chain.

Split Send into a context-aware send and pass the stream context down
through sendReceiveProbe, so cancelStream aborts any in-flight probe and
the watchdog exits promptly.
2026-06-29 11:24:25 +02:00
Riccardo Manfrin
5711f0e38c [client] add per-phase timing metrics for sync processing (#6533)
* Adds metrics sync phases time split to monitor costs

* Address review fixes

* Increment README.md with description on usage with debug bundles
2026-06-29 11:02:02 +02:00
Maycon Santos
1409a1325a [misc] Update careers page link (#6538) 2026-06-29 09:19:01 +02:00
Viktor Liu
4400372f37 [client] Forward non-address DNS record types through route forwarders (#6455) 2026-06-28 18:50:17 +02:00
389 changed files with 43602 additions and 23619 deletions

69
.github/workflows/agent-network-e2e.yml vendored Normal file
View File

@@ -0,0 +1,69 @@
name: Agent Network E2E
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
e2e:
name: Agent Network E2E
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
# Container-driver builder so the harness can build the combined/proxy/
# client images from source with a local layer cache.
- name: Set up Buildx
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4.1.0
# Persist the Docker layer cache across runs. This caches the base, apt,
# and go-mod-download layers; the Go compile still re-runs, as BuildKit
# mount caches cannot be exported to the GitHub cache.
- name: Cache Docker layers
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-anet-e2e-buildx-${{ hashFiles('go.sum', 'combined/Dockerfile.multistage', 'proxy/Dockerfile.multistage', 'e2e/harness/Dockerfile.client') }}
restore-keys: |
${{ runner.os }}-anet-e2e-buildx-
- name: Run agent-network e2e
env:
# Build the images from source (this branch's code) with the shared
# local layer cache.
NB_E2E_BUILDX_CACHE: /tmp/.buildx-cache
# Provider credentials. Each provider scenario skips if its
# token (and URL, for gateways) is unset, so partial coverage is fine.
OPENAI_TOKEN: ${{ secrets.E2E_OPENAI_TOKEN }}
ANTHROPIC_TOKEN: ${{ secrets.E2E_ANTHROPIC_TOKEN }}
VERCEL_URL: ${{ secrets.E2E_VERCEL_URL }}
VERCEL_TOKEN: ${{ secrets.E2E_VERCEL_TOKEN }}
OPENROUTER_URL: ${{ secrets.E2E_OPENROUTER_URL }}
OPENROUTER_TOKEN: ${{ secrets.E2E_OPENROUTER_TOKEN }}
CLOUDFLARE_URL: ${{ secrets.E2E_CLOUDFLARE_URL }}
CLOUDFLARE_TOKEN: ${{ secrets.E2E_CLOUDFLARE_TOKEN }}
AWS_BEARER_TOKEN_BEDROCK: ${{ secrets.E2E_AWS_BEARER_TOKEN_BEDROCK }}
AWS_REGION: ${{ secrets.E2E_AWS_REGION }}
# Vertex (Anthropic-on-Vertex): SA + project required; region defaults
# to "global", model to a pinned claude snapshot.
GOOGLE_VERTEX_SA_BASE64: ${{ secrets.E2E_GOOGLE_VERTEX_SA_BASE64 }}
GOOGLE_VERTEX_PROJECT: ${{ secrets.E2E_GOOGLE_VERTEX_PROJECT }}
GOOGLE_VERTEX_REGION: ${{ secrets.E2E_GOOGLE_VERTEX_REGION }}
GOOGLE_VERTEX_MODEL: ${{ secrets.E2E_GOOGLE_VERTEX_MODEL }}
run: go test -tags e2e -timeout 40m -v ./e2e/...

View File

@@ -27,7 +27,7 @@ jobs:
cache: false
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }}

View File

@@ -28,7 +28,7 @@ jobs:
id: test
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
uses: vmactions/freebsd-vm@5a72679103d223925653750faa878a143340fbd0 # v1.5.0
with:
usesh: true
copyback: false

View File

@@ -41,7 +41,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
id: cache
with:
path: |
@@ -135,7 +135,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -192,7 +192,7 @@ jobs:
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
id: cache-restore
with:
path: |
@@ -266,7 +266,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -325,7 +325,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -383,7 +383,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -440,7 +440,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -545,7 +545,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -640,7 +640,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -710,7 +710,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}

View File

@@ -35,7 +35,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}

View File

@@ -21,7 +21,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals,flate,recordin,unparseable
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:
@@ -56,7 +56,7 @@ jobs:
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
uses: golangci/golangci-lint-action@ba0d7d2ec06a0ea1cb5fa41b2e4a3ab91d21278a #v9.3.0
with:
version: latest
skip-cache: true

View File

@@ -34,7 +34,7 @@ jobs:
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620

View File

@@ -64,7 +64,7 @@ jobs:
if: steps.check_diff.outputs.diff_exists == 'true'
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
uses: vmactions/freebsd-vm@5a72679103d223925653750faa878a143340fbd0 # v1.5.0
with:
usesh: true
copyback: false
@@ -171,7 +171,7 @@ jobs:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
~/go/pkg/mod
@@ -221,7 +221,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
uses: goreleaser/goreleaser-action@f06c13b6b1a9625abc9e6e439d9c05a8f2190e94 # v7.2.3
with:
version: ${{ env.GORELEASER_VER }}
args: release --clean ${{ env.flags }}
@@ -379,7 +379,7 @@ jobs:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
~/go/pkg/mod
@@ -420,7 +420,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
uses: goreleaser/goreleaser-action@f06c13b6b1a9625abc9e6e439d9c05a8f2190e94 # v7.2.3
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
@@ -474,7 +474,7 @@ jobs:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
~/go/pkg/mod
@@ -488,7 +488,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
uses: goreleaser/goreleaser-action@f06c13b6b1a9625abc9e6e439d9c05a8f2190e94 # v7.2.3
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}

View File

@@ -78,7 +78,7 @@ jobs:
go-version-file: "go.mod"
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}

View File

@@ -29,7 +29,7 @@ jobs:
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
uses: golangci/golangci-lint-action@ba0d7d2ec06a0ea1cb5fa41b2e4a3ab91d21278a #v9.3.0
with:
version: latest
install-mode: binary

View File

@@ -33,7 +33,7 @@
<br/>
<br/>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
🚀 <a href="https://netbird.io/careers">We are hiring! Join us at https://netbird.io/careers</a>
</strong>
</p>

View File

@@ -401,12 +401,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
req.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(disableVNCApprovalFlag).Changed {
req.DisableVNCApproval = &disableVNCApproval
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
@@ -513,14 +507,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
ic.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(disableVNCApprovalFlag).Changed {
ic.DisableVNCApproval = &disableVNCApproval
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
}
applySSHFlagsToConfig(cmd, &ic)
if cmd.Flag(enableSSHSFTPFlag).Changed {
ic.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
@@ -596,49 +606,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
return &ic, nil
}
func applySSHFlagsToConfig(cmd *cobra.Command, ic *profilemanager.ConfigInput) {
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
ic.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
}
func applySSHFlagsToLogin(cmd *cobra.Command, req *proto.LoginRequest) {
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
req.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ttl := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &ttl
}
}
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
@@ -668,14 +635,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(serverSSHAllowedFlag).Changed {
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
loginRequest.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(disableVNCApprovalFlag).Changed {
loginRequest.DisableVNCApproval = &disableVNCApproval
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot
}
applySSHFlagsToLogin(cmd, &loginRequest)
if cmd.Flag(enableSSHSFTPFlag).Changed {
loginRequest.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled

View File

@@ -1,100 +0,0 @@
//go:build windows || (darwin && !ios)
package cmd
import (
"fmt"
"net"
"net/netip"
"os"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
var (
vncAgentSocket string
vncAgentTargetUID uint32
)
func init() {
vncAgentCmd.Flags().StringVar(&vncAgentSocket, "socket", "", "Unix-domain socket path the agent listens on (required)")
vncAgentCmd.Flags().Uint32Var(&vncAgentTargetUID, "target-uid", 0, "uid the agent should drop privileges to before listening (darwin only; 0 = stay as current uid)")
rootCmd.AddCommand(vncAgentCmd)
}
// vncAgentCmd runs a VNC server inside the user's interactive session,
// listening on a Unix-domain socket. The NetBird service spawns it: on
// Windows via CreateProcessAsUser into the console session, on macOS via
// launchctl asuser into the Aqua session.
var vncAgentCmd = &cobra.Command{
Use: "vnc-agent",
Short: "Run VNC capture agent (internal, spawned by service)",
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
log.SetReportCaller(true)
log.SetFormatter(&log.JSONFormatter{})
log.SetOutput(os.Stderr)
if vncAgentSocket == "" {
return fmt.Errorf("--socket is required")
}
token := os.Getenv("NB_VNC_AGENT_TOKEN")
if token == "" {
return fmt.Errorf("NB_VNC_AGENT_TOKEN not set; agent requires a token from the service")
}
// Purge the token from env so it doesn't leak via /proc/<pid>/environ.
if err := os.Unsetenv("NB_VNC_AGENT_TOKEN"); err != nil {
log.Debugf("unset NB_VNC_AGENT_TOKEN: %v", err)
}
// Drop root privileges to the target console user BEFORE creating
// the listening socket: keeps a post-auth bug in the encoder /
// input / capture paths confined to the user's own privileges
// rather than escalating to host root, and makes the daemon's
// LOCAL_PEERCRED check see the right uid. No-op on Windows
// (both processes run as SYSTEM) and when --target-uid is 0.
if vncAgentTargetUID != 0 {
if err := dropAgentPrivileges(vncAgentTargetUID); err != nil {
return fmt.Errorf("drop privileges to uid %d: %w", vncAgentTargetUID, err)
}
}
if err := os.Remove(vncAgentSocket); err != nil && !os.IsNotExist(err) {
log.Debugf("remove stale socket %s: %v", vncAgentSocket, err)
}
ln, err := net.Listen("unix", vncAgentSocket)
if err != nil {
return fmt.Errorf("listen on %s: %w", vncAgentSocket, err)
}
if err := os.Chmod(vncAgentSocket, 0o600); err != nil {
log.Debugf("chmod %s: %v", vncAgentSocket, err)
}
capturer, injector, err := newAgentResources()
if err != nil {
_ = ln.Close()
return err
}
srv := vncserver.New(vncserver.Config{
Capturer: capturer,
Injector: injector,
DisableAuth: true,
AgentTokenHex: token,
Listener: ln,
})
if err := srv.Start(cmd.Context(), netip.AddrPort{}, netip.Prefix{}); err != nil {
return fmt.Errorf("start vnc server: %w", err)
}
log.Infof("vnc-agent listening on %s, ready", vncAgentSocket)
<-cmd.Context().Done()
log.Info("vnc-agent context cancelled, shutting down")
return srv.Stop()
},
SilenceUsage: true,
}

View File

@@ -1,18 +0,0 @@
//go:build darwin && !ios
package cmd
import (
"fmt"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
capturer := vncserver.NewMacPoller()
injector, err := vncserver.NewMacInputInjector()
if err != nil {
return nil, nil, fmt.Errorf("macOS input injector: %w", err)
}
return capturer, injector, nil
}

View File

@@ -1,74 +0,0 @@
//go:build darwin && !ios
package cmd
import (
"fmt"
"os"
"os/user"
"strconv"
"syscall"
)
// dropAgentPrivileges drops the vnc-agent process from root (its
// launchctl-asuser-inherited starting uid) to the target console user
// before any other initialisation runs. Without this the agent runs as
// root for the lifetime of the session; any post-auth memory-safety
// issue in the capture/input/encode paths would then be a root-level
// RCE on the host instead of a user-level one. Also makes the daemon's
// LOCAL_PEERCRED check correctly identify the agent as the console user,
// not as root.
//
// Returns an error when the agent is running as a non-root uid that
// differs from targetUID: non-root can only setuid to itself, so a
// mismatch here means the spawn went to the wrong session.
func dropAgentPrivileges(targetUID uint32) error {
if targetUID == 0 {
return fmt.Errorf("refusing to keep agent running as root (target uid 0)")
}
cur := uint32(os.Getuid())
if cur == targetUID {
return nil
}
if cur != 0 {
return fmt.Errorf("agent uid %d does not match expected %d and we lack root to fix it", cur, targetUID)
}
// Resolve the target user's real primary group rather than reusing
// targetUID as the gid: a user's primary group on macOS is typically
// staff(20), not gid==uid. Fail closed if the lookup fails.
targetGID, err := primaryGroupID(targetUID)
if err != nil {
return err
}
// Drop supplementary groups first: setgid alone doesn't touch the
// auxiliary group list, leaving root's groups attached would let the
// dropped process write to root-only group-writable files.
if err := syscall.Setgroups([]int{}); err != nil {
return fmt.Errorf("setgroups([]): %w", err)
}
if err := syscall.Setgid(targetGID); err != nil {
return fmt.Errorf("setgid(%d): %w", targetGID, err)
}
if err := syscall.Setuid(int(targetUID)); err != nil {
return fmt.Errorf("setuid(%d): %w", targetUID, err)
}
if uint32(os.Getuid()) != targetUID || uint32(os.Geteuid()) != targetUID {
return fmt.Errorf("setuid verification: uid=%d euid=%d, expected %d", os.Getuid(), os.Geteuid(), targetUID)
}
return nil
}
// primaryGroupID resolves the real primary group id of the user with the
// given uid. Fails closed: a lookup or parse error returns an error so the
// caller never falls back to using uid as the gid.
func primaryGroupID(targetUID uint32) (int, error) {
u, err := user.LookupId(strconv.Itoa(int(targetUID)))
if err != nil {
return 0, fmt.Errorf("look up uid %d: %w", targetUID, err)
}
gid, err := strconv.Atoi(u.Gid)
if err != nil {
return 0, fmt.Errorf("parse gid %q for uid %d: %w", u.Gid, targetUID, err)
}
return gid, nil
}

View File

@@ -1,55 +0,0 @@
//go:build darwin && !ios
package cmd
import (
"strings"
"testing"
)
// TestDropAgentPrivileges_RefusesRootTarget locks in the contract that
// dropAgentPrivileges must never be a no-op when asked to keep the
// agent as root (target uid 0). A future caller that passes 0 by
// mistake would otherwise leave the post-auth attack surface running
// with full root privileges.
func TestDropAgentPrivileges_RefusesRootTarget(t *testing.T) {
err := dropAgentPrivileges(0)
if err == nil {
t.Fatal("expected refusal for target uid 0, got nil")
}
if !strings.Contains(err.Error(), "root") {
t.Fatalf("error should mention root, got: %v", err)
}
}
// TestDropAgentPrivileges_NoOpWhenAlreadyTarget covers the dev path
// where the agent is launched by hand as the target user (no root
// available, no setuid needed). The helper must succeed silently
// instead of trying (and failing) a setuid to its current uid.
func TestDropAgentPrivileges_NoOpWhenAlreadyTarget(t *testing.T) {
// Skip when running as root: the early-return path we want to
// cover only fires when current uid == target uid.
uid := currentUIDForTest()
if uid == 0 {
t.Skip("test must not run as root; cannot exercise the no-op early-return")
}
if err := dropAgentPrivileges(uid); err != nil {
t.Fatalf("expected no-op when current uid == target, got: %v", err)
}
}
// TestDropAgentPrivileges_RefusesMismatchedNonRoot guards the "non-root
// caller tries to setuid to a different uid" path: setuid would fail
// with EPERM anyway, but the helper should surface a clear error
// before issuing the syscall so a misconfigured spawn (wrong --target-uid
// flag) is debuggable.
func TestDropAgentPrivileges_RefusesMismatchedNonRoot(t *testing.T) {
uid := currentUIDForTest()
if uid == 0 {
t.Skip("test must not run as root; covered case requires non-root caller")
}
err := dropAgentPrivileges(uid + 1)
if err == nil {
t.Fatal("expected refusal when non-root caller asks to setuid elsewhere")
}
}

View File

@@ -1,11 +0,0 @@
//go:build darwin && !ios
package cmd
import "os"
// currentUIDForTest exposes os.Getuid for the darwin dropprivs tests
// without leaking an os import into the test file itself.
func currentUIDForTest() uint32 {
return uint32(os.Getuid())
}

View File

@@ -1,14 +0,0 @@
//go:build windows
package cmd
// dropAgentPrivileges is a no-op on Windows: the agent and the daemon
// both run as SYSTEM (the daemon spawns the agent into the interactive
// session via CreateProcessAsUser with an impersonation token, but the
// resulting process still runs under SYSTEM, not under the user's
// account). The Windows path relies on the DACL-restricted socket
// directory, the unpredictable per-spawn socket name, the listen-readiness
// gate, and the per-spawn token for integrity instead.
func dropAgentPrivileges(_ uint32) error {
return nil
}

View File

@@ -1,15 +0,0 @@
//go:build windows
package cmd
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
sessionID := vncserver.GetCurrentSessionID()
log.Infof("VNC agent running in Windows session %d", sessionID)
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), nil
}

View File

@@ -1,16 +0,0 @@
package cmd
const (
serverVNCAllowedFlag = "allow-server-vnc"
disableVNCApprovalFlag = "disable-vnc-approval"
)
var (
serverVNCAllowed bool
disableVNCApproval bool
)
func init() {
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
upCmd.PersistentFlags().BoolVar(&disableVNCApproval, disableVNCApprovalFlag, false, "Disable per-connection user approval prompts for the embedded VNC server")
}

View File

@@ -6,30 +6,19 @@ import (
"runtime"
)
var (
// StateDir holds persistent state (config, profiles, install metadata).
StateDir string
// RuntimeDir holds ephemeral artifacts that should not survive reboot,
// such as Unix sockets for daemon and per-session IPC. Empty on
// platforms without a conventional /var/run-style location.
RuntimeDir string
)
var StateDir string
func init() {
StateDir = os.Getenv("NB_STATE_DIR")
if StateDir != "" {
return
}
switch runtime.GOOS {
case "windows":
StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
case "darwin", "linux":
StateDir = "/var/lib/netbird"
RuntimeDir = "/var/run/netbird"
case "freebsd", "openbsd", "netbsd", "dragonfly":
StateDir = "/var/db/netbird"
RuntimeDir = "/var/run/netbird"
}
if v := os.Getenv("NB_STATE_DIR"); v != "" {
StateDir = v
}
if v := os.Getenv("NB_RUNTIME_DIR"); v != "" {
RuntimeDir = v
}
}

View File

@@ -136,6 +136,11 @@ func (p *ProxyBind) CloseConn() error {
return p.close()
}
// InjectPacket is a no-op for the userspace proxy: first-packet reinjection is kernel-only.
func (p *ProxyBind) InjectPacket(_ []byte) error {
return nil
}
func (p *ProxyBind) close() error {
if p.remoteConn == nil {
return nil

View File

@@ -219,6 +219,17 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Unlock()
}
// InjectPacket writes b to the remote peer over the underlying transport.
func (p *ProxyWrapper) InjectPacket(b []byte) error {
if p.remoteConn == nil {
return errors.New("proxy not started")
}
if _, err := p.remoteConn.Write(b); err != nil {
return err
}
return nil
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (p *ProxyWrapper) CloseConn() error {
if p.cancel == nil {

View File

@@ -18,4 +18,9 @@ type Proxy interface {
RedirectAs(endpoint *net.UDPAddr)
CloseConn() error
SetDisconnectListener(disconnected func())
// InjectPacket writes a raw packet directly to the remote peer over the underlying transport,
// bypassing WireGuard. Used to replay the captured lazyconn handshake initiation. Only the
// kernel-mode proxies act on it; the userspace proxy is a no-op since reinjection is kernel-only.
InjectPacket(b []byte) error
}

View File

@@ -147,6 +147,17 @@ func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
p.sendPkg = p.srcFakerConn.SendPkg
}
// InjectPacket writes b to the remote peer over the underlying transport.
func (p *WGUDPProxy) InjectPacket(b []byte) error {
if p.remoteConn == nil {
return errors.New("proxy not started")
}
if _, err := p.remoteConn.Write(b); err != nil {
return err
}
return nil
}
// CloseConn close the localConn
func (p *WGUDPProxy) CloseConn() error {
if p.cancel == nil {

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/hashicorp/go-multierror"
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
@@ -30,11 +31,13 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{}
mutex sync.Mutex
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{}
previousConfigHash uint64
hasAppliedConfig bool
mutex sync.Mutex
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
@@ -57,6 +60,23 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
return
}
// Skip the full rebuild + flush when the inputs that drive the firewall
// state are byte-for-byte identical to the last successfully applied
// update. Management re-sends the same network map far more often than it
// actually changes (account-wide updates, peer meta churn), and rebuilding
// every peer/route ACL and flushing the firewall on every such sync is the
// dominant client-side cost when nothing changed. Mirrors the same guard the
// DNS server already uses (previousConfigHash). Only the fields ApplyFiltering
// consumes participate in the hash, so an unrelated map change cannot mask a
// real ACL change.
hash, err := d.firewallConfigHash(networkMap, dnsRouteFeatureFlag)
if err != nil {
log.Errorf("unable to hash firewall configuration, applying unconditionally: %v", err)
} else if d.hasAppliedConfig && d.previousConfigHash == hash {
log.Debugf("not applying the firewall configuration update as there is nothing new (hash: %d)", hash)
return
}
start := time.Now()
defer func() {
total := 0
@@ -70,13 +90,49 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.applyPeerACLs(networkMap)
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)
routeErr := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag)
if routeErr != nil {
log.Errorf("Failed to apply route ACLs: %v", routeErr)
}
if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
flushErr := d.firewall.Flush()
if flushErr != nil {
log.Error("failed to flush firewall rules: ", flushErr)
}
// Only remember the hash once the firewall actually reflects this config.
// If applying or flushing failed, leave the previous hash untouched so the
// next (possibly identical) update is not skipped and gets a chance to
// reconcile the firewall state.
if err == nil && routeErr == nil && flushErr == nil {
d.previousConfigHash = hash
d.hasAppliedConfig = true
} else {
d.hasAppliedConfig = false
}
}
// firewallConfigHash hashes exactly the inputs ApplyFiltering uses to build the
// firewall state, so an identical hash means an identical resulting ruleset.
func (d *DefaultManager) firewallConfigHash(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) (uint64, error) {
return hashstructure.Hash(struct {
PeerRules []*mgmProto.FirewallRule
PeerRulesIsEmpty bool
RouteRules []*mgmProto.RouteFirewallRule
RouteRulesIsEmpty bool
DNSRouteFeatureFlag bool
}{
PeerRules: networkMap.GetFirewallRules(),
PeerRulesIsEmpty: networkMap.GetFirewallRulesIsEmpty(),
RouteRules: networkMap.GetRoutesFirewallRules(),
RouteRulesIsEmpty: networkMap.GetRoutesFirewallRulesIsEmpty(),
DNSRouteFeatureFlag: dnsRouteFeatureFlag,
}, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
}
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {

View File

@@ -1,6 +1,7 @@
package acl
import (
"fmt"
"net/netip"
"testing"
@@ -485,3 +486,149 @@ func TestPortInfoEmpty(t *testing.T) {
})
}
}
// TestApplyFilteringSkipsUnchangedConfig verifies that an identical network map
// re-applied is recognized as a no-op (hash unchanged), while a real change to
// any firewall-relevant input forces a re-apply (hash changes). This is the
// guard that prevents a full ruleset rebuild + flush on every redundant sync.
func TestApplyFilteringSkipsUnchangedConfig(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap, false)
require.True(t, acl.hasAppliedConfig, "config should be marked applied after first apply")
firstHash := acl.previousConfigHash
require.NotZero(t, firstHash)
// Re-applying the identical map must not change the recorded hash: the
// expensive rebuild path was skipped.
acl.ApplyFiltering(networkMap, false)
assert.Equal(t, firstHash, acl.previousConfigHash,
"identical re-apply must be a no-op (hash unchanged)")
// A real change must produce a different hash and re-apply.
networkMap.FirewallRules[0].Action = mgmProto.RuleAction_DROP
acl.ApplyFiltering(networkMap, false)
assert.NotEqual(t, firstHash, acl.previousConfigHash,
"changing a rule's action must force a re-apply (hash changed)")
// The dnsRouteFeatureFlag also participates in the hash.
changedHash := acl.previousConfigHash
acl.ApplyFiltering(networkMap, true)
assert.NotEqual(t, changedHash, acl.previousConfigHash,
"flipping dnsRouteFeatureFlag must force a re-apply (hash changed)")
}
func buildNetworkMap(peerRules, routeRules int) *mgmProto.NetworkMap {
nm := &mgmProto.NetworkMap{
FirewallRulesIsEmpty: peerRules == 0,
RoutesFirewallRulesIsEmpty: routeRules == 0,
}
for i := range peerRules {
nm.FirewallRules = append(nm.FirewallRules, &mgmProto.FirewallRule{
PeerIP: fmt.Sprintf("10.%d.%d.%d", i>>16&0xff, i>>8&0xff, i&0xff),
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: fmt.Sprintf("%d", 1024+i%64511),
})
}
for i := range routeRules {
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, &mgmProto.RouteFirewallRule{
Destination: fmt.Sprintf("192.168.%d.0/24", i%256),
SourceRanges: []string{fmt.Sprintf("10.0.%d.0/24", i%256)},
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
})
}
return nm
}
func BenchmarkFirewallConfigHash_Small(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(10, 5)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
func BenchmarkFirewallConfigHash_Medium(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(100, 50)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
func BenchmarkFirewallConfigHash_Large(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(1000, 200)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
// TestFirewallConfigHashDeterministic verifies the hash is stable for equal
// inputs and order-independent for the rule slices (management does not
// guarantee rule order).
func TestFirewallConfigHashDeterministic(t *testing.T) {
d := &DefaultManager{}
nm1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{PeerIP: "10.0.0.1", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_ACCEPT, Protocol: mgmProto.RuleProtocol_TCP, Port: "22"},
{PeerIP: "10.0.0.2", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_DROP, Protocol: mgmProto.RuleProtocol_TCP, Port: "80"},
},
}
// Same rules, reversed order.
nm2 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
nm1.FirewallRules[1],
nm1.FirewallRules[0],
},
}
h1, err := d.firewallConfigHash(nm1, false)
require.NoError(t, err)
h2, err := d.firewallConfigHash(nm2, false)
require.NoError(t, err)
assert.Equal(t, h1, h2, "hash must be order-independent for rule slices")
}

View File

@@ -1,219 +0,0 @@
// Package approval brokers per-attempt user-accept prompts for inbound
// remote access (VNC today, SSH and others in the future). A caller pushes
// a Prompt; the broker emits a SystemEvent on the daemon→UI stream and
// blocks until the UI calls the daemon's RespondApproval RPC, the per-
// request timeout fires, or no subscriber is connected. The latter case
// fails closed so a backgrounded UI cannot silently bypass the gate.
package approval
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
)
// Metadata keys the broker reserves on the emitted SystemEvent. Callers
// should not set these themselves; values in Prompt.Metadata that collide
// are overwritten by the broker.
const (
MetaRequestID = "request_id"
MetaKind = "kind"
MetaExpiresAt = "expires_at"
)
// ShortKeyFingerprint formats a hex-encoded Noise_IK static pubkey as a
// short, eyeball-able fingerprint to display in the approval dialog.
// The dashboard-supplied display name attached to a SessionPubKey isn't
// cryptographically asserted by the connecting client, so the prompt
// must also show something that IS: the key fingerprint, a hash of
// the static public key the client just proved possession of during the
// Noise handshake. Returns the empty string when the input is too short
// to plausibly be a hex pubkey, so the row is omitted rather than
// rendered as a misleading partial.
//
// Output format: 16 hex chars grouped as XXXX-XXXX-XXXX-XXXX (64 bits of
// fingerprint, resistant to random-prefix collisions and easy for a human
// to compare with an out-of-band reference).
func ShortKeyFingerprint(hexKey string) string {
if len(hexKey) < 8 {
return ""
}
src := hexKey
if len(src) > 16 {
src = src[:16]
}
var out []byte
for i, c := range src {
if i > 0 && i%4 == 0 {
out = append(out, '-')
}
out = append(out, byte(c))
}
return string(out)
}
// Kind values for the well-known prompt subjects. New subsystems should
// add a constant here so the UI can dispatch on a known string.
const (
KindVNC = "vnc"
KindSSH = "ssh"
)
// DefaultTimeout is the wall-clock window the user has to accept or deny a
// pending approval before the broker fails closed and returns ErrTimeout.
// Kept well under typical VNC client and dashboard connection timeouts so
// the RFB rejection actually reaches the browser instead of racing the
// browser's own "connection timed out" message.
const DefaultTimeout = 15 * time.Second
// timeoutValue returns the active timeout. It's a var so tests in this
// package can shorten the wait without exposing a setter on the public
// API. Production code always sees DefaultTimeout.
var timeoutValue = func() time.Duration { return DefaultTimeout }
// ErrNoSubscriber indicates no UI is connected to consume the prompt.
// The caller must reject the underlying connection (fail-closed).
var ErrNoSubscriber = errors.New("no UI subscriber connected for approval")
// ErrTimeout indicates the user did not respond within DefaultTimeout.
var ErrTimeout = errors.New("approval timed out")
// ErrDenied indicates the user explicitly denied the connection.
var ErrDenied = errors.New("approval denied")
// EventPublisher is the subset of peer.Status used to emit prompts.
type EventPublisher interface {
PublishEvent(
severity proto.SystemEvent_Severity,
category proto.SystemEvent_Category,
msg string,
userMsg string,
metadata map[string]string,
)
HasEventSubscribers() bool
}
// Prompt describes the pending request shown to the user. Kind selects
// the UI dispatch path (e.g. "vnc", "ssh"). Subject is the human-readable
// one-liner the UI may show as a title or notification body. Metadata is
// passed through verbatim and is the subsystem-specific payload (peer
// name, source IP, mode, etc.).
type Prompt struct {
Kind string
Subject string
Metadata map[string]string
}
// Decision carries the user's response to an approval prompt. ViewOnly is
// only meaningful when Accept is true; it lets the host grant the
// connection but signal the requester that input control is withheld.
type Decision struct {
Accept bool
ViewOnly bool
}
// Broker holds in-flight approval requests keyed by request ID.
type Broker struct {
pub EventPublisher
mu sync.Mutex
pending map[string]chan Decision
}
// New returns a broker that publishes prompts via pub.
func New(pub EventPublisher) *Broker {
return &Broker{
pub: pub,
pending: make(map[string]chan Decision),
}
}
// Request emits a SystemEvent for p and blocks until the UI calls Respond,
// ctx is cancelled, or DefaultTimeout elapses. Returns a Decision when
// the user replied; ErrDenied / ErrTimeout / ErrNoSubscriber / ctx.Err
// otherwise. Callers must treat any non-nil error as a deny.
func (b *Broker) Request(ctx context.Context, p Prompt) (Decision, error) {
var zero Decision
if b == nil || b.pub == nil {
return zero, fmt.Errorf("approval broker not configured")
}
if !b.pub.HasEventSubscribers() {
return zero, ErrNoSubscriber
}
id := uuid.NewString()
resp := make(chan Decision, 1)
b.mu.Lock()
b.pending[id] = resp
b.mu.Unlock()
defer b.dropPending(id)
timeout := timeoutValue()
expiresAt := time.Now().Add(timeout)
meta := make(map[string]string, len(p.Metadata)+3)
for k, v := range p.Metadata {
meta[k] = v
}
meta[MetaRequestID] = id
meta[MetaKind] = p.Kind
meta[MetaExpiresAt] = expiresAt.UTC().Format(time.RFC3339)
subject := p.Subject
if subject == "" {
subject = fmt.Sprintf("%s connection requires approval", p.Kind)
}
b.pub.PublishEvent(proto.SystemEvent_INFO, proto.SystemEvent_APPROVAL, subject, subject, meta)
log.Debugf("approval request %s (%s) emitted: %s", id, p.Kind, subject)
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case d := <-resp:
if !d.Accept {
return zero, ErrDenied
}
return d, nil
case <-timer.C:
return zero, ErrTimeout
case <-ctx.Done():
return zero, ctx.Err()
}
}
// Respond delivers the user's decision for id. Returns true when a pending
// request matched and was woken, false when id was unknown or already done.
func (b *Broker) Respond(id string, d Decision) bool {
if b == nil {
return false
}
b.mu.Lock()
ch, ok := b.pending[id]
if ok {
delete(b.pending, id)
}
b.mu.Unlock()
if !ok {
return false
}
select {
case ch <- d:
default:
}
return true
}
func (b *Broker) dropPending(id string) {
b.mu.Lock()
delete(b.pending, id)
b.mu.Unlock()
}

View File

@@ -1,434 +0,0 @@
package approval
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/proto"
)
// fakePublisher records published events and reports whether subscribers
// are connected. The subscribers flag is the security-critical signal:
// when false the broker must refuse to emit and the gate must fail closed.
type fakePublisher struct {
mu sync.Mutex
subscribers bool
events []*proto.SystemEvent
}
func (p *fakePublisher) PublishEvent(
severity proto.SystemEvent_Severity,
category proto.SystemEvent_Category,
msg string,
userMsg string,
metadata map[string]string,
) {
p.mu.Lock()
p.events = append(p.events, &proto.SystemEvent{
Severity: severity,
Category: category,
Message: msg,
UserMessage: userMsg,
Metadata: metadata,
})
p.mu.Unlock()
}
func (p *fakePublisher) HasEventSubscribers() bool {
p.mu.Lock()
defer p.mu.Unlock()
return p.subscribers
}
func (p *fakePublisher) lastEvent(t *testing.T) *proto.SystemEvent {
t.Helper()
p.mu.Lock()
defer p.mu.Unlock()
require.NotEmpty(t, p.events, "publisher saw no events")
return p.events[len(p.events)-1]
}
func (p *fakePublisher) eventCount() int {
p.mu.Lock()
defer p.mu.Unlock()
return len(p.events)
}
// TestRequestNoSubscriberFailsClosed is the core fail-closed invariant:
// when the UI is not subscribed, the broker must refuse without emitting
// an event or arming a waiter. A regression here is a silent bypass.
func TestRequestNoSubscriberFailsClosed(t *testing.T) {
pub := &fakePublisher{subscribers: false}
b := New(pub)
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
assert.ErrorIs(t, err, ErrNoSubscriber)
assert.Equal(t, 0, pub.eventCount(), "no event must be emitted when fail-closed")
b.mu.Lock()
pending := len(b.pending)
b.mu.Unlock()
assert.Equal(t, 0, pending, "no waiter must be registered on fail-closed")
}
// TestRequestTimeoutDenies verifies that a request without a UI response
// returns ErrTimeout (deny) rather than nil (silent accept). Uses a short
// per-test broker timeout via Respond after the fact to keep the test fast.
func TestRequestTimeoutDenies(t *testing.T) {
// Replace DefaultTimeout for the lifetime of this test.
orig := DefaultTimeout
defaultTimeout(t, 60*time.Millisecond)
defer defaultTimeout(t, orig)
pub := &fakePublisher{subscribers: true}
b := New(pub)
start := time.Now()
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
assert.ErrorIs(t, err, ErrTimeout, "missing user response must yield ErrTimeout, not nil")
assert.GreaterOrEqual(t, time.Since(start), 50*time.Millisecond, "timeout fired prematurely")
}
// TestRequestDenied returns ErrDenied when the UI responds with false.
func TestRequestDenied(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
var requestID string
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
}()
requestID = waitForRequestID(t, pub)
require.True(t, b.Respond(requestID, Decision{Accept: false}))
select {
case err := <-done:
assert.ErrorIs(t, err, ErrDenied)
case <-time.After(time.Second):
t.Fatal("Request did not return after Respond(false)")
}
}
// TestRequestAccepted is the happy path. Failure here doesn't bypass the
// gate but breaks the feature.
func TestRequestAccepted(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
}()
id := waitForRequestID(t, pub)
require.True(t, b.Respond(id, Decision{Accept: true}))
select {
case err := <-done:
assert.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("Request did not return after Respond(true)")
}
}
// TestRequestCtxCancelDenies verifies that an upstream cancel (e.g. the
// engine shutting down mid-prompt) returns the cancel error rather than
// nil. A nil here would be a silent bypass on shutdown races.
func TestRequestCtxCancelDenies(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan error, 1)
go func() {
done <- requestErr(b, ctx, Prompt{Kind: KindVNC, Subject: "test"})
}()
// Wait until the prompt is in flight so cancel races a live waiter.
_ = waitForRequestID(t, pub)
cancel()
select {
case err := <-done:
assert.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
t.Fatal("Request did not return after ctx cancel")
}
}
// TestRespondUnknownIsNoop ensures a stray RespondApproval RPC cannot
// affect or accidentally accept any in-flight request whose id it doesn't
// match. Also confirms it doesn't panic.
func TestRespondUnknownIsNoop(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
// No in-flight prompts: Respond returns false.
assert.False(t, b.Respond("does-not-exist", Decision{Accept: true}))
// With an in-flight prompt, a wrong id still returns false and the
// prompt remains armed (eventually timing out as a deny).
defaultTimeout(t, 60*time.Millisecond)
defer defaultTimeout(t, DefaultTimeout)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
realID := waitForRequestID(t, pub)
assert.False(t, b.Respond("totally-bogus", Decision{Accept: true}), "unknown id must not match")
assert.NotEqual(t, "totally-bogus", realID)
select {
case err := <-done:
assert.ErrorIs(t, err, ErrTimeout, "armed prompt must still time out, not accept")
case <-time.After(time.Second):
t.Fatal("prompt did not resolve")
}
}
// TestRespondAfterTimeoutNoop confirms a late accept response can't
// retroactively flip a denied (timed-out) request. The dropPending defer
// in Request must have removed the entry by the time Respond races in.
func TestRespondAfterTimeoutNoop(t *testing.T) {
defaultTimeout(t, 30*time.Millisecond)
defer defaultTimeout(t, DefaultTimeout)
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
id := waitForRequestID(t, pub)
select {
case err := <-done:
require.ErrorIs(t, err, ErrTimeout)
case <-time.After(time.Second):
t.Fatal("prompt did not time out")
}
assert.False(t, b.Respond(id, Decision{Accept: true}), "late respond must be no-op")
}
// TestRespondDoubleNoop ensures a duplicate ack from the UI doesn't leak
// past the matched waiter or panic on a closed/full channel.
func TestRespondDoubleNoop(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
id := waitForRequestID(t, pub)
require.True(t, b.Respond(id, Decision{Accept: true}))
assert.False(t, b.Respond(id, Decision{Accept: false}), "second response must be no-op")
select {
case err := <-done:
assert.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("prompt did not resolve")
}
}
// TestNilBrokerRequestErrors guards the engine pre-init path where the
// broker may not yet exist (or its publisher is nil): Request must
// error, never silently accept.
func TestNilBrokerRequestErrors(t *testing.T) {
var b *Broker
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
assert.Error(t, err, "nil broker must error, never silently accept")
b2 := New(nil)
_, err = b2.Request(context.Background(), Prompt{Kind: KindVNC})
assert.Error(t, err, "broker with nil publisher must error, never silently accept")
}
// TestPromptMetadataInjected confirms the broker stamps request_id, kind,
// and expires_at on the emitted event. The UI relies on these keys; if
// they are dropped, the user cannot route the prompt and the response
// path breaks (which fails closed via timeout).
func TestPromptMetadataInjected(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{
Kind: KindVNC,
Subject: "VNC connection from peerA",
Metadata: map[string]string{"peer_name": "peerA"},
})
}()
id := waitForRequestID(t, pub)
ev := pub.lastEvent(t)
assert.Equal(t, proto.SystemEvent_APPROVAL, ev.Category)
assert.Equal(t, KindVNC, ev.Metadata[MetaKind])
assert.Equal(t, id, ev.Metadata[MetaRequestID])
assert.NotEmpty(t, ev.Metadata[MetaExpiresAt])
assert.Equal(t, "peerA", ev.Metadata["peer_name"], "caller metadata must pass through")
require.True(t, b.Respond(id, Decision{Accept: true}))
<-done
}
// TestConcurrentRequests verifies that two concurrent prompts are tracked
// independently. A bug that aliases ids would let one Respond unblock
// the wrong waiter (a silent accept across prompts).
func TestConcurrentRequests(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
const n = 20
results := make(chan error, n)
for i := 0; i < n; i++ {
go func() {
results <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
}
ids := waitForNRequestIDs(t, pub, n)
require.Len(t, ids, n)
// Deny exactly half, accept the rest. Track outcome per id so we can
// match each Request's return value against the response we sent.
denySet := make(map[string]bool, n)
for i, id := range ids {
deny := i%2 == 0
denySet[id] = deny
require.True(t, b.Respond(id, Decision{Accept: !deny}))
}
// Collect all returns and check no nil errors slipped past a deny.
var accepted, denied atomic.Int32
for i := 0; i < n; i++ {
select {
case err := <-results:
if err == nil {
accepted.Add(1)
} else {
assert.ErrorIs(t, err, ErrDenied)
denied.Add(1)
}
case <-time.After(2 * time.Second):
t.Fatalf("only got %d/%d responses", i, n)
}
}
assert.Equal(t, int32(n/2), denied.Load())
assert.Equal(t, int32(n/2), accepted.Load())
}
// waitForRequestID blocks until the publisher sees its next event and
// returns the request_id stamped on it.
func waitForRequestID(t *testing.T, pub *fakePublisher) string {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
pub.mu.Lock()
count := len(pub.events)
var id string
if count > 0 {
id = pub.events[count-1].Metadata[MetaRequestID]
}
pub.mu.Unlock()
if id != "" {
return id
}
time.Sleep(2 * time.Millisecond)
}
t.Fatal("timeout waiting for emitted event")
return ""
}
func waitForNRequestIDs(t *testing.T, pub *fakePublisher, n int) []string {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
pub.mu.Lock()
count := len(pub.events)
pub.mu.Unlock()
if count >= n {
break
}
time.Sleep(2 * time.Millisecond)
}
pub.mu.Lock()
defer pub.mu.Unlock()
out := make([]string, 0, len(pub.events))
seen := make(map[string]struct{}, len(pub.events))
for _, ev := range pub.events {
id := ev.Metadata[MetaRequestID]
if id == "" {
continue
}
if _, dup := seen[id]; dup {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
if len(out) < n {
t.Fatalf("only got %d/%d request ids", len(out), n)
}
return out
}
// defaultTimeout swaps the broker's per-request wall-clock window so the
// timeout tests run quickly. Restores the prior value on the next call.
func defaultTimeout(t *testing.T, d time.Duration) {
t.Helper()
if d <= 0 {
t.Fatal("defaultTimeout must be > 0")
}
timeoutValue = func() time.Duration { return d }
}
// requestErr wraps Broker.Request to drop the Decision when tests only
// care about the error path. Keeps the goroutine bodies tight.
func requestErr(b *Broker, ctx context.Context, p Prompt) error {
_, err := b.Request(ctx, p)
return err
}
// TestRequestViewOnly checks the view-only outcome flows through Request's
// Decision return without being silently swallowed.
func TestRequestViewOnly(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
type result struct {
d Decision
err error
}
done := make(chan result, 1)
go func() {
d, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
done <- result{d, err}
}()
id := waitForRequestID(t, pub)
require.True(t, b.Respond(id, Decision{Accept: true, ViewOnly: true}))
select {
case r := <-done:
assert.NoError(t, r.err)
assert.True(t, r.d.Accept)
assert.True(t, r.d.ViewOnly, "ViewOnly must survive the round-trip")
case <-time.After(time.Second):
t.Fatal("view-only request did not resolve")
}
}

View File

@@ -1,62 +0,0 @@
package approval
import "testing"
// TestShortKeyFingerprint locks in the format the VNC approval prompt
// shows to the user. The fingerprint is the user's only cryptographic
// anchor against a malicious management server that pushes a spoofed
// display name, so accidental changes to its format would silently
// undermine that defence.
func TestShortKeyFingerprint(t *testing.T) {
cases := []struct {
name string
in string
want string
}{
{
name: "full_32_byte_pubkey",
in: "0123456789abcdeffedcba9876543210ffeeddccbbaa99887766554433221100",
want: "0123-4567-89ab-cdef",
},
{
name: "exactly_16_chars",
in: "0123456789abcdef",
want: "0123-4567-89ab-cdef",
},
{
name: "borderline_8_chars",
in: "01234567",
want: "0123-4567",
},
{
name: "too_short_returns_empty",
in: "0123",
want: "",
},
{
name: "empty_returns_empty",
in: "",
want: "",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := ShortKeyFingerprint(tc.in)
if got != tc.want {
t.Fatalf("ShortKeyFingerprint(%q) = %q, want %q", tc.in, got, tc.want)
}
})
}
}
// TestShortKeyFingerprint_DistinctKeysDistinctOutputs guards against a
// formatting bug that would collapse different prefixes onto the same
// displayed fingerprint and let an attacker substitute their pubkey for
// a victim's while keeping the prompt visually identical.
func TestShortKeyFingerprint_DistinctKeysDistinctOutputs(t *testing.T) {
a := ShortKeyFingerprint("0123456789abcdef" + "rest_of_pubkey_ignored")
b := ShortKeyFingerprint("0123456789abcde0" + "rest_of_pubkey_ignored")
if a == b {
t.Fatalf("expected distinct outputs for distinct prefixes, both = %q", a)
}
}

View File

@@ -315,7 +315,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.ServerVNCAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,

View File

@@ -581,8 +581,6 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
DisableVNCApproval: config.DisableVNCApproval,
EnableSSHRoot: config.EnableSSHRoot,
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
@@ -665,7 +663,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.ServerVNCAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,

View File

@@ -655,12 +655,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.SSHJWTCacheTTL != nil {
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
}
if g.internalConfig.ServerVNCAllowed != nil {
configContent.WriteString(fmt.Sprintf("ServerVNCAllowed: %v\n", *g.internalConfig.ServerVNCAllowed))
}
if g.internalConfig.DisableVNCApproval != nil {
configContent.WriteString(fmt.Sprintf("DisableVNCApproval: %v\n", *g.internalConfig.DisableVNCApproval))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))

View File

@@ -864,8 +864,6 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
RosenpassEnabled: true,
RosenpassPermissive: true,
ServerSSHAllowed: &bTrue,
ServerVNCAllowed: &bTrue,
DisableVNCApproval: &bTrue,
EnableSSHRoot: &bTrue,
EnableSSHSFTP: &bTrue,
EnableSSHLocalPortForwarding: &bTrue,

View File

@@ -8,6 +8,7 @@ import (
"errors"
"net"
"net/netip"
"slices"
"strings"
"github.com/miekg/dns"
@@ -167,7 +168,10 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
case dns.TypeA:
alternativeNetwork = "ip6"
default:
return dns.RcodeNameError
// Non-address types reach LookupIP only unexpectedly; without an
// address pair to probe we cannot prove the name is absent, so answer
// NODATA rather than a poisoning NXDOMAIN.
return dns.RcodeSuccess
}
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
@@ -184,6 +188,230 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
return dns.RcodeSuccess
}
// RecordResolver is the host resolver surface used to forward non-address
// record queries. net.DefaultResolver satisfies it.
type RecordResolver interface {
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
}
// LookupRecords resolves a non-address DNS record type through the host
// resolver and returns the resource records and the DNS rcode. Types the host
// resolver cannot answer (anything not covered by the net.Resolver Lookup*
// methods) yield NODATA so that a routed name is never poisoned with NXDOMAIN
// for an unsupported type.
func LookupRecords(ctx context.Context, r RecordResolver, name string, qtype uint16, ttl uint32) ([]dns.RR, int) {
fqdn := dns.Fqdn(name)
switch qtype {
case dns.TypeMX:
return lookupMX(ctx, r, name, fqdn, ttl)
case dns.TypeTXT:
return lookupTXT(ctx, r, name, fqdn, ttl)
case dns.TypeNS:
return lookupNS(ctx, r, name, fqdn, ttl)
case dns.TypeSRV:
return lookupSRV(ctx, r, name, fqdn, ttl)
case dns.TypeCNAME:
return lookupCNAME(ctx, r, name, fqdn, ttl)
case dns.TypePTR:
return lookupPTR(ctx, r, name, fqdn, ttl)
default:
return nil, dns.RcodeSuccess
}
}
func recordHeader(fqdn string, rrtype uint16, ttl uint32) dns.RR_Header {
return dns.RR_Header{Name: fqdn, Rrtype: rrtype, Class: dns.ClassINET, Ttl: ttl}
}
func lookupMX(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupMX(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, mx := range recs {
rrs = append(rrs, &dns.MX{
Hdr: recordHeader(fqdn, dns.TypeMX, ttl),
Preference: mx.Pref,
Mx: dns.Fqdn(mx.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupTXT(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupTXT(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, txt := range recs {
rrs = append(rrs, &dns.TXT{
Hdr: recordHeader(fqdn, dns.TypeTXT, ttl),
Txt: chunkTXT(txt),
})
}
return rrs, dns.RcodeSuccess
}
func lookupNS(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupNS(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, ns := range recs {
rrs = append(rrs, &dns.NS{
Hdr: recordHeader(fqdn, dns.TypeNS, ttl),
Ns: dns.Fqdn(ns.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupSRV(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
_, recs, err := r.LookupSRV(ctx, "", "", name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, srv := range recs {
rrs = append(rrs, &dns.SRV{
Hdr: recordHeader(fqdn, dns.TypeSRV, ttl),
Priority: srv.Priority,
Weight: srv.Weight,
Port: srv.Port,
Target: dns.Fqdn(srv.Target),
})
}
return rrs, dns.RcodeSuccess
}
func lookupCNAME(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
cname, err := r.LookupCNAME(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
// LookupCNAME returns the queried name itself when the name resolves but
// has no CNAME record; that is a NODATA result, not a CNAME.
if strings.EqualFold(dns.Fqdn(cname), fqdn) {
return nil, dns.RcodeSuccess
}
return []dns.RR{&dns.CNAME{
Hdr: recordHeader(fqdn, dns.TypeCNAME, ttl),
Target: dns.Fqdn(cname),
}}, dns.RcodeSuccess
}
func lookupPTR(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
addr, ok := ptrQueryAddr(name)
if !ok {
return nil, dns.RcodeSuccess
}
names, err := r.LookupAddr(ctx, addr)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(names))
for _, n := range names {
rrs = append(rrs, &dns.PTR{
Hdr: recordHeader(fqdn, dns.TypePTR, ttl),
Ptr: dns.Fqdn(n),
})
}
return rrs, dns.RcodeSuccess
}
// ptrQueryAddr converts a reverse-DNS query name (in-addr.arpa or ip6.arpa)
// into the address string expected by net.Resolver.LookupAddr. It reports false
// when the name is not a well-formed reverse name.
func ptrQueryAddr(qname string) (string, bool) {
name := strings.TrimSuffix(strings.ToLower(dns.Fqdn(qname)), ".")
switch {
case strings.HasSuffix(name, ".in-addr.arpa"):
return parseInAddrArpa(strings.TrimSuffix(name, ".in-addr.arpa"))
case strings.HasSuffix(name, ".ip6.arpa"):
return parseIP6Arpa(strings.TrimSuffix(name, ".ip6.arpa"))
default:
return "", false
}
}
// parseInAddrArpa turns the label portion of an in-addr.arpa name into an IPv4
// address string, reporting false when it is not a well-formed reverse name.
func parseInAddrArpa(labelPart string) (string, bool) {
labels := strings.Split(labelPart, ".")
if len(labels) != 4 {
return "", false
}
slices.Reverse(labels)
addr, err := netip.ParseAddr(strings.Join(labels, "."))
if err != nil || !addr.Is4() {
return "", false
}
return addr.String(), true
}
// parseIP6Arpa turns the nibble portion of an ip6.arpa name into an IPv6
// address string, reporting false when it is not a well-formed reverse name.
func parseIP6Arpa(nibblePart string) (string, bool) {
nibbles := strings.Split(nibblePart, ".")
if len(nibbles) != 32 {
return "", false
}
slices.Reverse(nibbles)
var sb strings.Builder
for i, n := range nibbles {
if i > 0 && i%4 == 0 {
sb.WriteByte(':')
}
sb.WriteString(n)
}
addr, err := netip.ParseAddr(sb.String())
if err != nil || !addr.Is6() {
return "", false
}
return addr.String(), true
}
// rcodeForRecordError maps a non-address lookup error to a DNS rcode. A
// not-found result becomes NODATA rather than NXDOMAIN: net.DNSError.IsNotFound
// does not distinguish a missing name from a name that exists only with records
// of other types, so the name cannot be proven absent and must not be poisoned.
func rcodeForRecordError(err error) int {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
return dns.RcodeSuccess
}
return dns.RcodeServerFailure
}
// chunkTXT splits a TXT string into character-strings no longer than 255 bytes
// so the record can be packed. The chunks form one TXT resource record.
func chunkTXT(s string) []string {
const maxLen = 255
if len(s) <= maxLen {
return []string{s}
}
var chunks []string
for len(s) > maxLen {
chunks = append(chunks, s[:maxLen])
s = s[maxLen:]
}
if len(s) > 0 {
chunks = append(chunks, s)
}
return chunks
}
// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 {

View File

@@ -5,6 +5,7 @@ import (
"errors"
"net"
"net/netip"
"strings"
"testing"
"github.com/miekg/dns"
@@ -121,6 +122,164 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
}
func TestPtrQueryAddr(t *testing.T) {
tests := []struct {
name string
qname string
want string
wantOK bool
}{
{name: "ipv4", qname: "4.3.2.1.in-addr.arpa.", want: "1.2.3.4", wantOK: true},
{name: "ipv4 no trailing dot", qname: "1.0.0.127.in-addr.arpa", want: "127.0.0.1", wantOK: true},
{
name: "ipv6",
qname: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
want: "2001:db8::1",
wantOK: true,
},
{name: "ipv4 wrong label count", qname: "2.1.in-addr.arpa.", wantOK: false},
{name: "ipv6 wrong nibble count", qname: "1.0.ip6.arpa.", wantOK: false},
{name: "not a reverse name", qname: "example.com.", wantOK: false},
{name: "ipv4 bad octet", qname: "4.3.2.999.in-addr.arpa.", wantOK: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := ptrQueryAddr(tt.qname)
assert.Equal(t, tt.wantOK, ok, "parse success mismatch")
if tt.wantOK {
assert.Equal(t, tt.want, got, "parsed address mismatch")
}
})
}
}
type mockRecordResolver struct {
mx []*net.MX
txt []string
ns []*net.NS
srv []*net.SRV
cname string
ptr []string
err error
}
func (m *mockRecordResolver) LookupMX(context.Context, string) ([]*net.MX, error) {
return m.mx, m.err
}
func (m *mockRecordResolver) LookupTXT(context.Context, string) ([]string, error) {
return m.txt, m.err
}
func (m *mockRecordResolver) LookupNS(context.Context, string) ([]*net.NS, error) {
return m.ns, m.err
}
func (m *mockRecordResolver) LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) {
return "", m.srv, m.err
}
func (m *mockRecordResolver) LookupCNAME(context.Context, string) (string, error) {
return m.cname, m.err
}
func (m *mockRecordResolver) LookupAddr(context.Context, string) ([]string, error) {
return m.ptr, m.err
}
func TestLookupRecords(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com."}
t.Run("MX success", func(t *testing.T) {
r := &mockRecordResolver{mx: []*net.MX{{Host: "mail.example.com.", Pref: 10}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "mail.example.com.", rrs[0].(*dns.MX).Mx)
})
t.Run("TXT short string is one character-string", func(t *testing.T) {
r := &mockRecordResolver{txt: []string{"v=spf1 -all"}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, []string{"v=spf1 -all"}, rrs[0].(*dns.TXT).Txt)
})
t.Run("TXT chunks long strings", func(t *testing.T) {
long := strings.Repeat("a", 300)
r := &mockRecordResolver{txt: []string{long}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
txt := rrs[0].(*dns.TXT).Txt
require.Len(t, txt, 2, "300-byte string should split into two character-strings")
assert.Equal(t, 255, len(txt[0]))
assert.Equal(t, 45, len(txt[1]))
})
t.Run("NS success", func(t *testing.T) {
r := &mockRecordResolver{ns: []*net.NS{{Host: "ns1.example.com."}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeNS, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "ns1.example.com.", rrs[0].(*dns.NS).Ns)
})
t.Run("SRV success", func(t *testing.T) {
r := &mockRecordResolver{srv: []*net.SRV{{Target: "sip.example.com.", Port: 5060}}}
rrs, rcode := LookupRecords(context.Background(), r, "_sip._tcp.example.com.", dns.TypeSRV, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, uint16(5060), rrs[0].(*dns.SRV).Port)
})
t.Run("CNAME success", func(t *testing.T) {
r := &mockRecordResolver{cname: "target.example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "www.example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "target.example.com.", rrs[0].(*dns.CNAME).Target)
})
t.Run("CNAME equal to name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{cname: "example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs, "self-referential CNAME is NODATA")
})
t.Run("PTR success", func(t *testing.T) {
r := &mockRecordResolver{ptr: []string{"host.example.com."}}
rrs, rcode := LookupRecords(context.Background(), r, "4.3.2.1.in-addr.arpa.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "host.example.com.", rrs[0].(*dns.PTR).Ptr)
})
t.Run("PTR malformed name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
t.Run("not found is NODATA never NXDOMAIN", func(t *testing.T) {
r := &mockRecordResolver{err: notFound}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode, "missing record must not poison the name")
})
t.Run("server failure maps to SERVFAIL", func(t *testing.T) {
r := &mockRecordResolver{err: &net.DNSError{Err: "server misbehaving", IsTemporary: true}}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeServerFailure, rcode)
})
t.Run("unsupported type is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCAA, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{

View File

@@ -37,6 +37,12 @@ const (
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
}
type firewaller interface {
@@ -210,12 +216,6 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
f.writeResponse(logger, w, resp, qname, startTime)
return
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
if mostSpecificResId == "" {
@@ -227,9 +227,46 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
reqHasEdns := query.IsEdns0() != nil
switch question.Qtype {
case dns.TypeA, dns.TypeAAAA:
f.handleAddressQuery(ctx, logger, w, resp, mostSpecificResId, matchingEntries, reqHasEdns, startTime)
case dns.TypeMX, dns.TypeTXT, dns.TypeNS, dns.TypeSRV, dns.TypeCNAME, dns.TypePTR:
f.handleRecordQuery(ctx, logger, w, resp, startTime)
default:
// The domain is routed here, so any other type is answered NODATA
// (NOERROR, empty answer) rather than falling back to a resolver that
// would poison the name with NXDOMAIN. The Extended DNS Error lets a
// client tell this capability-driven NODATA apart from an
// authoritative one. The OPT pseudo-record must not appear unless the
// query advertised EDNS0.
if reqHasEdns {
attachEDE(resp, dns.ExtendedErrorCodeNotSupported, "netbird forwarder: unsupported query type")
}
f.writeResponse(logger, w, resp, qname, startTime)
}
}
// handleAddressQuery resolves A/AAAA queries, programs the firewall sets and
// resolved-IP state, and caches the answer for resilience on upstream failure.
func (f *DNSForwarder) handleAddressQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
mostSpecificResId route.ResID,
matchingEntries []*ForwarderEntry,
reqHasEdns bool,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
network := resutil.NetworkForQtype(question.Qtype)
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
f.handleDNSError(ctx, logger, w, question, resp, qname, result, reqHasEdns, startTime)
return
}
@@ -240,6 +277,25 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
f.writeResponse(logger, w, resp, qname, startTime)
}
// handleRecordQuery resolves non-address record types (MX, TXT, NS, SRV,
// CNAME, PTR) through the host resolver. Missing records are answered NODATA so
// the routed name is never poisoned with NXDOMAIN.
func (f *DNSForwarder) handleRecordQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
records, rcode := resutil.LookupRecords(ctx, f.resolver, qname, question.Qtype, f.ttl)
resp.Rcode = rcode
resp.Answer = append(resp.Answer, records...)
f.writeResponse(logger, w, resp, qname, startTime)
}
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)

View File

@@ -133,6 +133,41 @@ func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([
return args.Get(0).([]netip.Addr), args.Error(1)
}
func (m *MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.MX)
return recs, args.Error(1)
}
func (m *MockResolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func (m *MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.NS)
return recs, args.Error(1)
}
func (m *MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
args := m.Called(ctx, service, proto, name)
recs, _ := args.Get(1).([]*net.SRV)
return args.String(0), recs, args.Error(2)
}
func (m *MockResolver) LookupCNAME(ctx context.Context, host string) (string, error) {
args := m.Called(ctx, host)
return args.String(0), args.Error(1)
}
func (m *MockResolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
args := m.Called(ctx, addr)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
@@ -545,12 +580,15 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
// A type with no net.Resolver Lookup method (CAA) must answer NODATA
// (NOERROR, empty) rather than NXDOMAIN/NOTIMP to avoid poisoning the name.
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
expectEDE bool
description string
}{
{
@@ -562,28 +600,13 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
name: "unsupported query type returns NODATA",
queryType: dns.TypeCAA,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
expectedCode: dns.RcodeSuccess,
expectEDE: true,
description: "Unsupported types answer NODATA, not NXDOMAIN/NOTIMP",
},
}
@@ -599,6 +622,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
query.SetEdns0(dns.DefaultMsgSize, false)
// Capture the written response
var writtenResp *dns.Msg
@@ -614,10 +638,213 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
assert.Empty(t, writtenResp.Answer, "Non-address response should carry no answers")
if tt.expectEDE {
require.NotNil(t, writtenResp.IsEdns0(), "EDNS0 client should get an OPT in the reply")
assert.True(t, hasEDE(writtenResp, dns.ExtendedErrorCodeNotSupported),
"unsupported type NODATA should carry EDE Not Supported")
}
})
}
}
func hasEDE(m *dns.Msg, code uint16) bool {
opt := m.IsEdns0()
if opt == nil {
return false
}
for _, o := range opt.Option {
if ede, ok := o.(*dns.EDNS0_EDE); ok && ede.InfoCode == code {
return true
}
}
return false
}
func TestDNSForwarder_RecordQueries(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com"}
t.Run("MX records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return([]*net.MX{{Host: "mail.example.com.", Pref: 10}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
mx, ok := resp.Answer[0].(*dns.MX)
require.True(t, ok, "answer should be an MX record")
assert.Equal(t, uint16(10), mx.Preference)
assert.Equal(t, "mail.example.com.", mx.Mx)
mockResolver.AssertExpectations(t)
})
t.Run("missing MX is NODATA not NXDOMAIN", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// A not-found cannot prove the name is absent (it may exist with only
// other record types), so it must answer NODATA, never NXDOMAIN.
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "missing record must be NODATA")
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("NS records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return([]*net.NS{{Host: "ns1.example.com."}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ns, ok := resp.Answer[0].(*dns.NS)
require.True(t, ok, "answer should be an NS record")
assert.Equal(t, "ns1.example.com.", ns.Ns)
mockResolver.AssertExpectations(t)
})
t.Run("missing NS is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("SRV records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", []*net.SRV{{Target: "sip.example.com.", Port: 5060, Priority: 10, Weight: 5}}, nil).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
srv, ok := resp.Answer[0].(*dns.SRV)
require.True(t, ok, "answer should be an SRV record")
assert.Equal(t, "sip.example.com.", srv.Target)
assert.Equal(t, uint16(5060), srv.Port)
assert.Equal(t, uint16(10), srv.Priority)
mockResolver.AssertExpectations(t)
})
t.Run("missing SRV is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("TXT records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupTXT", mock.Anything, "example.com.").
Return([]string{"v=spf1 -all"}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeTXT)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
txt, ok := resp.Answer[0].(*dns.TXT)
require.True(t, ok, "answer should be a TXT record")
assert.Equal(t, []string{"v=spf1 -all"}, txt.Txt)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "www.example.com")
mockResolver.On("LookupCNAME", mock.Anything, "www.example.com.").
Return("target.example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "www.example.com", dns.TypeCNAME)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok, "answer should be a CNAME record")
assert.Equal(t, "target.example.com.", cname.Target)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME equal to the name is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// No CNAME exists: LookupCNAME echoes the queried name back.
mockResolver.On("LookupCNAME", mock.Anything, "example.com.").
Return("example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeCNAME)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer, "self-referential CNAME means no CNAME record")
mockResolver.AssertExpectations(t)
})
t.Run("PTR record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "*.in-addr.arpa")
// The reverse name is parsed back to the address LookupAddr expects.
mockResolver.On("LookupAddr", mock.Anything, "1.2.3.4").
Return([]string{"host.example.com."}, nil).Once()
resp := runRecordQuery(t, forwarder, "4.3.2.1.in-addr.arpa", dns.TypePTR)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok, "answer should be a PTR record")
assert.Equal(t, "host.example.com.", ptr.Ptr)
mockResolver.AssertExpectations(t)
})
}
func newRecordTestForwarder(t *testing.T, r resolver, configured string) *DNSForwarder {
t.Helper()
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = r
d, err := domain.FromString(configured)
require.NoError(t, err)
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
return forwarder
}
func runRecordQuery(t *testing.T, forwarder *DNSForwarder, qname string, qtype uint16) *dns.Msg {
t.Helper()
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(qname), qtype)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "expected response to be written")
return resp
}
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
tests := []struct {
name string

View File

@@ -34,7 +34,6 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/approval"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
@@ -83,6 +82,12 @@ const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
disableAutoUpdate = "disabled"
// systemInfoTimeout bounds how long the sync loop waits for system info / posture
// check gathering. The gathering runs uncancellable system calls (process scan,
// exec, os.Stat); without this bound a single stuck call freezes handleSync, and
// thus syncMsgMux, for as long as the call hangs (observed multi-minute freezes).
systemInfoTimeout = 15 * time.Second
)
var ErrResetConnection = fmt.Errorf("reset connection")
@@ -126,8 +131,6 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
DisableVNCApproval *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -215,9 +218,7 @@ type Engine struct {
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
vncSrv vncServer
approvalBroker *approval.Broker
sshServer sshServer
statusRecorder *peer.Status
@@ -309,7 +310,6 @@ func NewEngine(
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
approvalBroker: approval.New(services.StatusRecorder),
stateManager: services.StateManager,
portForwardManager: portforward.NewManager(),
checks: services.Checks,
@@ -372,10 +372,6 @@ func (e *Engine) stopLocked() {
log.Warnf("failed to stop SSH server: %v", err)
}
if err := e.stopVNCServer(); err != nil {
log.Warnf("failed to stop VNC server: %v", err)
}
e.cleanupSSHConfig()
if e.ingressGatewayMgr != nil {
@@ -905,6 +901,16 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
}
// phase times a sync sub-phase: it returns a function that records the elapsed
// duration when called. Starting the timer at the call site keeps inter-phase
// glue code out of the measurement.
func (e *Engine) phase(name string) func() {
start := time.Now()
return func() {
e.clientMetrics.RecordSyncPhase(e.ctx, name, time.Since(start))
}
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
@@ -924,7 +930,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
}
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
done := e.phase("netbird_config")
err := e.updateNetbirdConfig(update.GetNetbirdConfig())
done()
if err != nil {
return err
}
@@ -938,11 +947,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
done = e.phase("checks")
err = e.updateChecksIfNew(update.Checks)
done()
if err != nil {
return err
}
done = e.phase("persist")
e.persistSyncResponse(update)
done()
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
@@ -1076,16 +1090,26 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks, e.overlayAddresses()...)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, checks, e.overlayAddresses()...)
if !ok {
// Gathering timed out; skip the meta sync this cycle rather than blocking the
// sync loop (and syncMsgMux) on a stuck system call. A later sync will retry.
return nil
}
e.applyInfoFlags(info)
if err := e.mgmClient.SyncMeta(info); err != nil {
return fmt.Errorf("could not sync meta: error %s", err)
}
return nil
}
// applyInfoFlags sets the engine's config-derived feature flags on the gathered system info.
func (e *Engine) applyInfoFlags(info *system.Info) {
info.SetFlags(
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1100,12 +1124,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
return nil
}
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it
@@ -1147,10 +1165,6 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
}
if err := e.updateVNC(); err != nil {
log.Warnf("failed handling VNC server setup: %v", err)
}
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.IPv6 = e.wgInterface.Address().IPv6String()
@@ -1269,32 +1283,15 @@ func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.overlayAddresses()...)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, e.checks, e.overlayAddresses()...)
if !ok {
// Gathering timed out; connect the stream with base info so management
// connectivity still comes up rather than blocking here.
info = system.GetInfo(e.ctx)
}
info.SetFlags(
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
e.config.DisableFirewall,
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.DisableIPv6,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
e.applyInfoFlags(info)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
err := e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
@@ -1387,13 +1384,16 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address())
done := e.phase("dns_server")
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
done()
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
// apply routes first, route related actions might depend on routing being enabled
done = e.phase("routes_classify")
routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1402,29 +1402,60 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.connMgr.UpdateRouteHAMap(clientRoutes)
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
}
done()
done = e.phase("routes_apply")
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update routes: %v", err)
}
done()
done = e.phase("filtering")
if e.acl != nil {
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
}
done()
done = e.phase("dns_forwarder")
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
done()
// Ingress forward rules
done = e.phase("forward_rules")
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
if err != nil {
log.Errorf("failed to update forward rules, err: %v", err)
}
done()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
done = e.phase("offline_peers")
e.updateOfflinePeers(networkMap.GetOfflinePeers())
done()
remotePeers, err := e.reconcilePeers(networkMap)
if err != nil {
return err
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
done = e.phase("lazy_exclude")
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
done()
e.networkSerial = serial
return nil
}
// reconcilePeers applies the remote peer list from the network map (removing,
// modifying and adding peers, then updating SSH config) and returns the remote
// peers with our own peer filtered out, for use by later sync steps.
func (e *Engine) reconcilePeers(networkMap *mgmProto.NetworkMap) ([]*mgmProto.RemotePeerConfig, error) {
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
@@ -1439,47 +1470,43 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
err := e.removeAllPeers()
e.statusRecorder.FinishPeerListModifications()
if err != nil {
return err
return nil, err
}
} else {
err := e.removePeers(remotePeers)
if err != nil {
return err
}
err = e.modifyPeers(remotePeers)
if err != nil {
return err
}
err = e.addNewPeers(remotePeers)
if err != nil {
return err
}
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
return remotePeers, nil
}
// VNC auth: always sync, including nil so cleared auth on the management
// side is applied locally, and so it isn't skipped on the RemotePeersIsEmpty
// cleanup path.
e.updateVNCServerAuth(networkMap.GetVncAuth())
done := e.phase("removed_peers")
err := e.removePeers(remotePeers)
done()
if err != nil {
return nil, err
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
done = e.phase("modified_peers")
err = e.modifyPeers(remotePeers)
done()
if err != nil {
return nil, err
}
e.networkSerial = serial
done = e.phase("added_peers")
err = e.addNewPeers(remotePeers)
done()
if err != nil {
return nil, err
}
return nil
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
return remotePeers, nil
}
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
@@ -1952,7 +1979,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -2721,16 +2747,3 @@ func decodeRelayIP(b []byte) netip.Addr {
}
return ip.Unmap()
}
// RespondApproval relays the user's decision for a pending approval to
// the broker. viewOnly is honoured only when accept is true. Returns
// true when the request_id matched a live prompt.
func (e *Engine) RespondApproval(requestID string, accept, viewOnly bool) bool {
if e == nil || e.approvalBroker == nil {
return false
}
return e.approvalBroker.Respond(requestID, approval.Decision{
Accept: accept,
ViewOnly: accept && viewOnly,
})
}

View File

@@ -12,10 +12,10 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
@@ -237,18 +237,22 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
return errors.New("wg interface not initialized")
}
wgAddr := e.wgInterface.Address()
serverConfig := &sshserver.Config{
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
NetstackNet: e.wgInterface.GetNet(),
NetworkValidation: wgAddr,
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
}
server := sshserver.New(serverConfig)
wgAddr := e.wgInterface.Address()
server.SetNetworkValidation(wgAddr)
netbirdIP := wgAddr.IP
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
server.SetNetstackNet(netstackNet)
}
e.configureSSHServer(server)
if err := server.Start(e.ctx, listenAddr); err != nil {

View File

@@ -178,6 +178,10 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
return nil
}
func (m *MockWGIface) MTU() uint16 {
return 1280
}
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
return nil
}

View File

@@ -1,302 +0,0 @@
//go:build !js && !ios && !android
package internal
import (
"context"
"errors"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/approval"
"github.com/netbirdio/netbird/client/internal/metrics"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/vnc"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
type vncServer interface {
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
Stop() error
ActiveSessions() []vncserver.ActiveSessionInfo
}
func (e *Engine) setupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
return fmt.Errorf("add VNC port redirection: %w", err)
}
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vnc.ExternalPort, localAddr, vnc.InternalPort)
return nil
}
func (e *Engine) cleanupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
return fmt.Errorf("remove VNC port redirection: %w", err)
}
return nil
}
// updateVNC handles starting/stopping the VNC server based on the config flag.
func (e *Engine) updateVNC() error {
if !e.config.ServerVNCAllowed {
if e.vncSrv != nil {
log.Info("VNC server disabled, stopping")
}
return e.stopVNCServer()
}
if e.config.BlockInbound {
log.Info("VNC server disabled because inbound connections are blocked")
return e.stopVNCServer()
}
if e.vncSrv != nil {
return nil
}
return e.startVNCServer()
}
func (e *Engine) startVNCServer() error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
capturer, injector, ok := newPlatformVNC()
if !ok {
log.Debug("VNC server not supported on this platform")
return nil
}
netbirdIP := e.wgInterface.Address().IP
var sessionRecorder func(vncserver.SessionTick)
if e.clientMetrics != nil {
sessionRecorder = func(t vncserver.SessionTick) {
e.clientMetrics.RecordVNCSessionTick(e.ctx, metrics.VNCSessionTick{
Period: t.Period,
BytesOut: t.BytesOut,
Writes: t.Writes,
FBUs: t.FBUs,
MaxFBUBytes: t.MaxFBUBytes,
MaxFBURects: t.MaxFBURects,
MaxWriteBytes: t.MaxWriteBytes,
WriteNanos: t.WriteNanos,
})
}
}
serviceMode := vncNeedsServiceMode()
if serviceMode {
log.Info("VNC: running as system service, enabling service mode (per-session agent proxy)")
}
requireApproval := e.config.DisableVNCApproval == nil || !*e.config.DisableVNCApproval
srv := vncserver.New(vncserver.Config{
Capturer: capturer,
Injector: injector,
IdentityKey: e.config.WgPrivateKey[:],
ServiceMode: serviceMode,
SessionRecorder: sessionRecorder,
NetstackNet: e.wgInterface.GetNet(),
RequireApproval: requireApproval,
Approver: &vncApprover{broker: e.approvalBroker, statusRecorder: e.statusRecorder},
})
listenAddr := netip.AddrPortFrom(netbirdIP, vnc.InternalPort)
network := e.wgInterface.Address().Network
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
return fmt.Errorf("start VNC server: %w", err)
}
e.vncSrv = srv
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, vnc.InternalPort)
log.Debugf("registered VNC service with netstack for TCP:%d", vnc.InternalPort)
}
}
if err := e.setupVNCPortRedirection(); err != nil {
log.Warnf("setup VNC port redirection: %v", err)
}
log.Info("VNC server enabled")
return nil
}
// updateVNCServerAuth updates VNC fine-grained access control from management.
// A nil vncAuth clears all authorized users and session pubkeys so management
// can revoke access by omitting the field on the next sync.
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
if e.vncSrv == nil {
return
}
vncSrv, ok := e.vncSrv.(*vncserver.Server)
if !ok {
return
}
if vncAuth == nil {
vncSrv.UpdateVNCAuth(&sshauth.Config{})
return
}
protoUsers := vncAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range vncAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
sessionPubKeys := make([]sshauth.SessionPubKey, 0, len(vncAuth.GetSessionPubKeys()))
for _, pk := range vncAuth.GetSessionPubKeys() {
pub := pk.GetPubKey()
if len(pub) != 32 {
log.Warnf("VNC session pubkey wrong length %d", len(pub))
continue
}
hash := pk.GetUserIdHash()
if len(hash) != 16 {
log.Warnf("VNC session user id hash wrong length %d", len(hash))
continue
}
sessionPubKeys = append(sessionPubKeys, sshauth.SessionPubKey{
PubKey: pub,
UserIDHash: sshuserhash.UserIDHash(hash),
DisplayName: pk.GetDisplayName(),
})
}
vncSrv.UpdateVNCAuth(&sshauth.Config{
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
SessionPubKeys: sessionPubKeys,
})
}
// GetVNCServerStatus returns whether the VNC server is running and the list
// of active VNC sessions. The pointer is captured under syncMsgMux so a
// concurrent updateVNC/stopVNCServer cannot swap it out between the nil
// check and the ActiveSessions call.
func (e *Engine) GetVNCServerStatus() (enabled bool, sessions []vncserver.ActiveSessionInfo) {
e.syncMsgMux.Lock()
vncSrv := e.vncSrv
e.syncMsgMux.Unlock()
if vncSrv == nil {
return false, nil
}
return true, vncSrv.ActiveSessions()
}
func (e *Engine) stopVNCServer() error {
if e.vncSrv == nil {
return nil
}
if err := e.cleanupVNCPortRedirection(); err != nil {
log.Warnf("cleanup VNC port redirection: %v", err)
}
if e.wgInterface != nil && e.wgInterface.GetNet() != nil {
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, vnc.InternalPort)
}
}
log.Info("stopping VNC server")
err := e.vncSrv.Stop()
e.vncSrv = nil
if err != nil {
return fmt.Errorf("stop VNC server: %w", err)
}
return nil
}
// vncApprover adapts the generic approval.Broker for the VNC server.
type vncApprover struct {
broker *approval.Broker
statusRecorder *peer.Status
}
func (a *vncApprover) Request(ctx context.Context, info vncserver.ApprovalInfo) (vncserver.ApprovalDecision, error) {
// Resolve the source overlay IP to a peer FQDN for the prompt label.
if info.PeerName == "" && info.SourceIP != "" && a.statusRecorder != nil {
if fqdn, ok := a.statusRecorder.PeerByIP(info.SourceIP); ok {
info.PeerName = fqdn
}
}
subject := fmt.Sprintf("VNC connection from %s", displayPeer(info))
meta := map[string]string{
"peer_name": info.PeerName,
"peer_pubkey": info.PeerPubKey,
"source_ip": info.SourceIP,
"mode": info.Mode,
"username": info.Username,
"initiator": info.Initiator,
}
d, err := a.broker.Request(ctx, approval.Prompt{
Kind: approval.KindVNC,
Subject: subject,
Metadata: meta,
})
if err != nil {
return vncserver.ApprovalDecision{}, err
}
return vncserver.ApprovalDecision{ViewOnly: d.ViewOnly}, nil
}
func displayPeer(info vncserver.ApprovalInfo) string {
if info.Initiator != "" {
return info.Initiator
}
if info.PeerName != "" {
return info.PeerName
}
if info.SourceIP != "" {
return info.SourceIP
}
if info.PeerPubKey != "" {
return info.PeerPubKey
}
return "unknown peer"
}

View File

@@ -1,31 +0,0 @@
//go:build freebsd
package internal
import (
"fmt"
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
// newConsoleVNC builds the FreeBSD console fallback: vt(4) framebuffer
// for capture, /dev/uinput for input. The uinput device requires the
// `uinput` kernel module (`kldload uinput`); without it, input init
// fails and we drop to a stub injector so the user still gets a
// view-only screen mirror.
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
poller := vncserver.NewFBPoller("")
w, h := poller.Width(), poller.Height()
if w == 0 || h == 0 {
poller.Close()
return nil, nil, fmt.Errorf("vt framebuffer init failed (vt may not allow mmap on this driver)")
}
if inj, err := vncserver.NewUInputInjector(w, h); err == nil {
return poller, inj, nil
} else {
log.Infof("VNC console: uinput unavailable (%v); view-only mode. Run `kldload uinput` to enable input.", err)
return poller, &vncserver.StubInputInjector{}, nil
}
}

View File

@@ -1,30 +0,0 @@
//go:build linux && !android
package internal
import (
"fmt"
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
// newConsoleVNC builds a framebuffer + uinput VNC backend for boxes
// without a running X server. Used as the auto-fallback when
// newPlatformVNC can't reach X. Returns an error when /dev/fb0 or
// /dev/uinput aren't usable so the caller can drop back to a stub.
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
poller := vncserver.NewFBPoller("")
w, h := poller.Width(), poller.Height()
if w == 0 || h == 0 {
poller.Close()
return nil, nil, fmt.Errorf("framebuffer capturer init failed (is /dev/fb0 readable?)")
}
inj, err := vncserver.NewUInputInjector(w, h)
if err != nil {
log.Debugf("uinput unavailable, falling back to view-only VNC: %v", err)
return poller, &vncserver.StubInputInjector{}, nil
}
return poller, inj, nil
}

View File

@@ -1,34 +0,0 @@
//go:build darwin && !ios
package internal
import (
"os"
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
capturer := vncserver.NewMacPoller()
// Prompt for Screen Recording at server-enable time rather than first
// client-connect. The native prompt is far easier for users to act on
// in the moment they toggled VNC on than later when "the screen looks
// like wallpaper" would otherwise be the only clue.
vncserver.PrimeScreenCapturePermission()
injector, err := vncserver.NewMacInputInjector()
if err != nil {
log.Debugf("VNC: macOS input injector: %v", err)
return capturer, &vncserver.StubInputInjector{}, true
}
return capturer, injector, true
}
// vncNeedsServiceMode reports whether the running process is a system
// LaunchDaemon (root, parented by launchd). Daemons sit in the global
// bootstrap namespace and cannot talk to WindowServer; we route capture
// through a per-user agent in that case.
func vncNeedsServiceMode() bool {
return os.Geteuid() == 0 && os.Getppid() == 1
}

View File

@@ -1,23 +0,0 @@
//go:build js || ios || android
package internal
import (
log "github.com/sirupsen/logrus"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
type vncServer interface{}
func (e *Engine) updateVNC() error { return nil }
func (e *Engine) updateVNCServerAuth(auth *mgmProto.VNCAuth) {
if auth == nil {
return
}
log.Debugf("ignoring VNC auth push on platform without a VNC server: %d session pubkeys, %d authorized users",
len(auth.GetSessionPubKeys()), len(auth.GetAuthorizedUsers()))
}
func (e *Engine) stopVNCServer() error { return nil }

View File

@@ -1,13 +0,0 @@
//go:build windows
package internal
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), true
}
func vncNeedsServiceMode() bool {
return vncserver.GetCurrentSessionID() == 0
}

View File

@@ -1,35 +0,0 @@
//go:build (linux && !android) || freebsd
package internal
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
// Prefer X11 when an X server is reachable. NewX11InputInjector probes
// DISPLAY (and /proc) eagerly, so a non-nil error here means no X.
injector, err := vncserver.NewX11InputInjector("", "", "")
if err == nil {
return vncserver.NewX11Poller("", ""), injector, true
}
log.Debugf("VNC: X11 not available: %v", err)
// Fallback for headless / pre-X states (kernel console, login manager
// without X, physical server in recovery): stream the framebuffer and
// inject input via /dev/uinput.
consoleCap, consoleInj, err := newConsoleVNC()
if err == nil {
log.Infof("VNC: using framebuffer console capture (%dx%d)", consoleCap.Width(), consoleCap.Height())
return consoleCap, consoleInj, true
}
log.Debugf("VNC: framebuffer console fallback unavailable: %v", err)
return &vncserver.StubCapturer{}, &vncserver.StubInputInjector{}, false
}
func vncNeedsServiceMode() bool {
return false
}

View File

@@ -44,4 +44,5 @@ type wgIfaceBase interface {
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
MTU() uint16
}

View File

@@ -124,6 +124,11 @@ func (d *BindListener) ReadPackets() {
d.done.Done()
}
// CapturedPacket is unused in userspace bind mode: first-packet reinjection is kernel-only.
func (d *BindListener) CapturedPacket() []byte {
return nil
}
// Close stops the listener and cleans up resources.
func (d *BindListener) Close() {
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")

View File

@@ -45,10 +45,6 @@ type MockWGIfaceBind struct {
endpointMgr *mockEndpointManager
}
func (m *MockWGIfaceBind) RemovePeer(string) error {
return nil
}
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
@@ -68,6 +64,10 @@ func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
return m.endpointMgr
}
func (m *MockWGIfaceBind) MTU() uint16 {
return 1280
}
func TestBindListener_Creation(t *testing.T) {
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
@@ -207,8 +207,9 @@ func TestManager_BindMode(t *testing.T) {
require.NoError(t, err)
select {
case peerConnID := <-mgr.OnActivityChan:
assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
case ev := <-mgr.OnActivityChan:
assert.Equal(t, cfg.PeerConnID, ev.PeerConnID, "Received peer connection ID should match")
assert.Nil(t, ev.FirstPacket, "Bind mode does not capture packets: reinjection is kernel-only")
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notification")
}
@@ -266,8 +267,8 @@ func TestManager_BindMode_MultiplePeers(t *testing.T) {
receivedPeers := make(map[peerid.ConnID]bool)
for i := 0; i < 2; i++ {
select {
case peerConnID := <-mgr.OnActivityChan:
receivedPeers[peerConnID] = true
case ev := <-mgr.OnActivityChan:
receivedPeers[ev.PeerConnID] = true
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notifications")
}

View File

@@ -3,11 +3,13 @@ package activity
import (
"fmt"
"net"
"slices"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
@@ -20,6 +22,8 @@ type UDPListener struct {
done sync.Mutex
isClosed atomic.Bool
capturedPacket []byte
}
// NewUDPListener creates a listener that detects activity via UDP socket reads.
@@ -46,9 +50,13 @@ func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener,
}
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
// The first packet that triggers activity is captured so it can be reinjected through the real
// transport once it is established. Without this, kernel WireGuard's handshake initiation would be
// dropped and WG would only retry after REKEY_TIMEOUT.
func (d *UDPListener) ReadPackets() {
for {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
buf := make([]byte, int(d.wgIface.MTU())+bufsize.WGBufferOverhead)
n, remoteAddr, err := d.conn.ReadFromUDP(buf)
if err != nil {
if d.isClosed.Load() {
d.peerCfg.Log.Infof("exit from activity listener")
@@ -62,20 +70,24 @@ func (d *UDPListener) ReadPackets() {
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
continue
}
d.peerCfg.Log.Infof("activity detected")
d.capturedPacket = slices.Clone(buf[:n])
d.peerCfg.Log.Infof("activity detected, captured %d bytes for reinjection", n)
break
}
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
// Ignore close error as it may return "use of closed network connection" if already closed.
// Leave the peer in place. ConfigureWGEndpoint will UpdatePeer with the real endpoint;
// removing the peer here wipes kernel WG's staged queue and drops the user packet that
// triggered activation.
_ = d.conn.Close()
d.done.Unlock()
}
// CapturedPacket returns the first packet that triggered activity, or nil if none was captured.
// Safe to call after ReadPackets returns.
func (d *UDPListener) CapturedPacket() []byte {
return d.capturedPacket
}
// Close stops the listener and cleans up resources.
func (d *UDPListener) Close() {
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())

View File

@@ -19,17 +19,25 @@ import (
type listener interface {
ReadPackets()
Close()
CapturedPacket() []byte
}
// Event reports activity on a managed peer. FirstPacket is the bytes that triggered activation,
// captured for reinjection through the real transport.
type Event struct {
PeerConnID peerid.ConnID
FirstPacket []byte
}
type WgInterface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool
Address() wgaddr.Address
MTU() uint16
}
type Manager struct {
OnActivityChan chan peerid.ConnID
OnActivityChan chan Event
wgIface WgInterface
@@ -41,7 +49,7 @@ type Manager struct {
func NewManager(wgIface WgInterface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
OnActivityChan: make(chan Event, 1),
wgIface: wgIface,
peers: make(map[peerid.ConnID]listener),
done: make(chan struct{}),
@@ -116,12 +124,12 @@ func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
delete(m.peers, peerConnID)
m.mu.Unlock()
m.notify(peerConnID)
m.notify(Event{PeerConnID: peerConnID, FirstPacket: l.CapturedPacket()})
}
func (m *Manager) notify(peerConnID peerid.ConnID) {
func (m *Manager) notify(ev Event) {
select {
case <-m.done:
case m.OnActivityChan <- peerConnID:
case m.OnActivityChan <- ev:
}
}

View File

@@ -1,6 +1,7 @@
package activity
import (
"bytes"
"net"
"net/netip"
"testing"
@@ -25,10 +26,6 @@ func (m *MocPeer) ConnID() peerid.ConnID {
type MocWGIface struct {
}
func (m MocWGIface) RemovePeer(string) error {
return nil
}
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
@@ -44,6 +41,10 @@ func (m MocWGIface) Address() wgaddr.Address {
}
}
func (m MocWGIface) MTU() uint16 {
return 1280
}
// GetPeerListener is a test helper to access listeners
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
m.mu.Lock()
@@ -86,11 +87,15 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
}
select {
case peerConnID := <-mgr.OnActivityChan:
if peerConnID != peerCfg1.PeerConnID {
t.Fatalf("unexpected peerConnID: %v", peerConnID)
case ev := <-mgr.OnActivityChan:
if ev.PeerConnID != peerCfg1.PeerConnID {
t.Fatalf("unexpected peerConnID: %v", ev.PeerConnID)
}
if !bytes.Equal(ev.FirstPacket, []byte{0x01, 0x02, 0x03, 0x04, 0x05}) {
t.Fatalf("unexpected first packet: %v", ev.FirstPacket)
}
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for activity")
}
}

View File

@@ -130,8 +130,8 @@ func (m *Manager) Start(ctx context.Context) {
select {
case <-ctx.Done():
return
case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(peerConnID)
case ev := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ev)
case peerIDs := <-m.inactivityManager.InactivePeersChan():
m.onPeerInactivityTimedOut(peerIDs)
}
@@ -513,13 +513,13 @@ func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string,
return false
}
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
func (m *Manager) onPeerActivity(ev activity.Event) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
mp, ok := m.managedPeersByConnID[ev.PeerConnID]
if !ok {
log.Errorf("peer not found by conn id: %v", peerConnID)
log.Errorf("peer not found by conn id: %v", ev.PeerConnID)
return
}
@@ -536,7 +536,7 @@ func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
m.activateHAGroupPeers(mp.peerCfg)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
m.peerStore.PeerConnOpenWithFirstPacket(m.engineCtx, mp.peerCfg.PublicKey, ev.FirstPacket)
}
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {

View File

@@ -17,4 +17,5 @@ type WGIface interface {
IsUserspaceBind() bool
Address() wgaddr.Address
LastActivities() map[string]monotime.Time
MTU() uint16
}

View File

@@ -120,30 +120,24 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
m.trimLocked()
}
func (m *influxDBMetrics) RecordVNCSessionTick(_ context.Context, agentInfo AgentInfo, tick VNCSessionTick) {
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s",
func (m *influxDBMetrics) RecordSyncPhase(_ context.Context, agentInfo AgentInfo, phase string, duration time.Duration) {
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,phase=%s",
agentInfo.DeploymentType.String(),
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
phase,
)
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_vnc_traffic",
measurement: "netbird_sync_phase",
tags: tags,
fields: map[string]float64{
"period_seconds": tick.Period.Seconds(),
"bytes_out": float64(tick.BytesOut),
"writes": float64(tick.Writes),
"fbus": float64(tick.FBUs),
"max_fbu_bytes": float64(tick.MaxFBUBytes),
"max_fbu_rects": float64(tick.MaxFBURects),
"max_write_bytes": float64(tick.MaxWriteBytes),
"write_time_seconds": float64(tick.WriteNanos) / 1e9,
"duration_seconds": duration.Seconds(),
},
timestamp: time.Now(),
})

View File

@@ -78,6 +78,25 @@ Tags:
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
### Sync Phase Timing
Measurement: `netbird_sync_phase`
Breaks down where time goes inside a single sync, so the total `netbird_sync` duration can be attributed to the sub-step that dominates.
| Field | Description |
|-------|-------------|
| `duration_seconds` | Time spent in one sub-phase of sync processing |
Tags:
- `phase`: the sub-phase — `netbird_config`, `checks`, `persist`, `dns_server`, `routes_classify`, `routes_apply`, `filtering`, `dns_forwarder`, `forward_rules`, `offline_peers`, `removed_peers`, `modified_peers`, `added_peers`, `lazy_exclude`
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
- `version`: NetBird version string
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
**Note:** this is wall-time per phase — it includes both CPU work and time spent waiting on locks. A slow phase points to *where* the time goes, not *why*; pair it with lock-wait metrics to tell contention apart from real work.
### Login Duration
Measurement: `netbird_login`
@@ -191,4 +210,52 @@ docker compose exec influxdb influx query \
# Check ingest server health
curl http://localhost:8087/health
```
```
## Analyzing a Debug Bundle
Metrics collection is always on, so every debug bundle ships a `metrics.txt` in InfluxDB line protocol — a timestamped time series of all recorded events (sync durations, sync phases, connection stages, login). You can replay it into the local stack and graph it, without a running client.
The bundle's `metrics.txt` is a rolling window (capped at 5 days / ~20k samples, see [Buffer Limits](#buffer-limits)). For a connection incident the relevant window is short (connection setup is seconds), so a bundle captured during the issue is enough.
### 1. Start the stack
```bash
# From this directory (client/internal/metrics/infra)
INFLUXDB_ADMIN_TOKEN=admin123 INFLUXDB_ADMIN_PASSWORD=admin123 GRAFANA_ADMIN_PASSWORD=admin123 \
docker compose up -d
```
(`admin123` are throwaway local credentials — fine for offline analysis.)
### 2. Clear any previous data
So you only see this bundle:
```bash
docker exec influxdb influx delete --org netbird --bucket metrics --token admin123 \
--start 1970-01-01T00:00:00Z --stop 2100-01-01T00:00:00Z
```
### 3. Import the bundle's metrics.txt
InfluxDB is not exposed on the host, so import inside the container:
```bash
docker cp /path/to/bundle/metrics.txt influxdb:/tmp/m.txt
docker exec influxdb influx write --org netbird --bucket metrics --precision ns \
--token admin123 --file /tmp/m.txt
```
Re-importing the same file is idempotent (same measurement+tags+timestamp overwrites).
### 4. View the dashboards
Grafana on http://localhost:3001 (login `admin` / `admin123`), datasource pre-provisioned:
- **Where sync time goes:** http://localhost:3001/d/netbird-sync-phases/netbird-sync-phases-where-time-goes
- **General client metrics:** http://localhost:3001/d/netbird-influxdb-metrics
**Set the time range** to cover the bundle's timestamps (e.g. "Last 7 days" or an absolute range matching when the bundle was taken) — with the default short range the panels look empty.
Bundles are distinguishable by the `version` tag; add a tag at import time (e.g. `sed 's/^netbird_\([a-z_]*\),/netbird_\1,bundle=mycase,/' metrics.txt`) if you want to compare several side by side.

View File

@@ -0,0 +1,259 @@
{
"annotations": {
"list": []
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 1,
"links": [],
"refresh": "",
"schemaVersion": 39,
"tags": [
"netbird",
"sync"
],
"templating": {
"list": [
{
"current": {
"text": "All",
"value": "$__all"
},
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"definition": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
"includeAll": true,
"label": "version",
"multi": true,
"name": "version",
"query": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
"refresh": 2,
"type": "query",
"allValue": ".*"
}
]
},
"time": {
"from": "now-2d",
"to": "now"
},
"timepicker": {},
"timezone": "",
"title": "NetBird Sync Phases (where time goes)",
"uid": "netbird-sync-phases",
"version": 1,
"panels": [
{
"id": 1,
"title": "Time per phase over time (stacked, ms)",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 10,
"w": 24,
"x": 0,
"y": 0
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"custom": {
"drawStyle": "bars",
"stacking": {
"mode": "normal",
"group": "A"
},
"fillOpacity": 80,
"lineWidth": 0
}
},
"overrides": []
},
"options": {
"legend": {
"displayMode": "table",
"placement": "right",
"calcs": [
"max",
"mean"
]
},
"tooltip": {
"mode": "multi",
"sort": "desc"
}
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"phase\"])\n |> group(columns: [\"phase\"])"
}
]
},
{
"id": 2,
"title": "p95 per phase (ms)",
"type": "bargauge",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 11,
"w": 12,
"x": 0,
"y": 10
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"color": {
"mode": "continuous-GrYlRd"
}
},
"overrides": []
},
"options": {
"displayMode": "gradient",
"orientation": "horizontal",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showUnfilled": true
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> sort(columns: [\"_value\"], desc: true)"
}
]
},
{
"id": 3,
"title": "Per-phase stats (ms): mean / p95 / max",
"type": "table",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 11,
"w": 12,
"x": 12,
"y": 10
},
"fieldConfig": {
"defaults": {
"unit": "ms"
},
"overrides": []
},
"options": {
"showHeader": true,
"sortBy": [
{
"displayName": "max",
"desc": true
}
]
},
"transformations": [
{
"id": "merge",
"options": {}
}
],
"targets": [
{
"refId": "mean",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> mean()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"mean\"})"
},
{
"refId": "p95",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"p95\"})"
},
{
"refId": "max",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> max()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"max\"})"
}
]
},
{
"id": 4,
"title": "Total sync duration (netbird_sync, ms) \u2014 reference",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 24,
"x": 0,
"y": 21
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"custom": {
"drawStyle": "points",
"pointSize": 5
}
},
"overrides": []
},
"options": {
"legend": {
"displayMode": "table",
"placement": "right",
"calcs": [
"max",
"mean"
]
},
"tooltip": {
"mode": "single"
}
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"version\"])\n |> group(columns: [\"version\"])"
}
]
}
]
}

View File

@@ -19,7 +19,7 @@ const (
defaultListenAddr = ":8087"
defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns"
maxBodySize = 50 * 1024 * 1024 // 50 MB max request body
maxDurationSeconds = 300.0 // reject any duration field > 5 minutes
maxDurationSeconds = 86400.0 // reject any duration field > 24 hours
peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars
maxTagValueLength = 64 // reject tag values longer than this
)
@@ -59,6 +59,19 @@ var allowedMeasurements = map[string]measurementSpec{
"peer_id": true,
},
},
"netbird_sync_phase": {
allowedFields: map[string]bool{
"duration_seconds": true,
},
allowedTags: map[string]bool{
"deployment_type": true,
"version": true,
"os": true,
"arch": true,
"peer_id": true,
"phase": true,
},
},
"netbird_login": {
allowedFields: map[string]bool{
"duration_seconds": true,

View File

@@ -53,14 +53,14 @@ func TestValidateLine_NegativeValue(t *testing.T) {
}
func TestValidateLine_DurationTooLarge(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=999 1234567890`
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=100000 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")
}
func TestValidateLine_TotalSecondsTooLarge(t *testing.T) {
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=500 1234567890`
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=100000 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")

View File

@@ -56,14 +56,12 @@ type metricsImplementation interface {
// RecordSyncDuration records how long it took to process a sync message
RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration)
// RecordSyncPhase records how long a single sub-phase of sync processing took
RecordSyncPhase(ctx context.Context, agentInfo AgentInfo, phase string, duration time.Duration)
// RecordLoginDuration records how long the login to management took
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
// RecordVNCSessionTick records a periodic snapshot of one VNC
// session's wire activity. Called once per metricsConn tick interval
// (and once at session close), only when the tick saw activity.
RecordVNCSessionTick(ctx context.Context, agentInfo AgentInfo, tick VNCSessionTick)
// Export exports metrics in InfluxDB line protocol format
Export(w io.Writer) error
@@ -83,21 +81,6 @@ type ClientMetrics struct {
pushCancel context.CancelFunc
}
// VNCSessionTick is one sampling slice of a VNC session's wire activity.
// BytesOut / Writes / FBUs / WriteNanos are deltas observed during this
// tick; Max* fields are the high-water marks observed during the tick.
// Period is the wall-clock duration the deltas cover.
type VNCSessionTick struct {
Period time.Duration
BytesOut uint64
Writes uint64
FBUs uint64
MaxFBUBytes uint64
MaxFBURects uint64
MaxWriteBytes uint64
WriteNanos uint64
}
// ConnectionStageTimestamps holds timestamps for each connection stage
type ConnectionStageTimestamps struct {
SignalingReceived time.Time // First signal received from remote peer (both initial and reconnection)
@@ -147,15 +130,16 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
}
// RecordVNCSessionTick records a periodic snapshot of one VNC session.
func (c *ClientMetrics) RecordVNCSessionTick(ctx context.Context, tick VNCSessionTick) {
// RecordSyncPhase records the duration of a single sub-phase of sync processing
func (c *ClientMetrics) RecordSyncPhase(ctx context.Context, phase string, duration time.Duration) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
c.impl.RecordVNCSessionTick(ctx, agentInfo, tick)
c.impl.RecordSyncPhase(ctx, agentInfo, phase, duration)
}
// RecordLoginDuration records how long the login to management server took

View File

@@ -70,10 +70,10 @@ func (m *mockMetrics) RecordConnectionStages(_ context.Context, _ AgentInfo, _ s
func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) {
}
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
func (m *mockMetrics) RecordSyncPhase(_ context.Context, _ AgentInfo, _ string, _ time.Duration) {
}
func (m *mockMetrics) RecordVNCSessionTick(_ context.Context, _ AgentInfo, _ VNCSessionTick) {
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
}
func (m *mockMetrics) Export(w io.Writer) error {

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"runtime"
"slices"
"sync"
"time"
@@ -136,6 +137,39 @@ type Conn struct {
// Connection stage timestamps for metrics
metricsRecorder MetricsRecorder
metricsStages *MetricsStages
// pendingFirstPacket is the lazyconn-captured handshake init, replayed once the real
// transport is up.
pendingFirstPacket []byte
}
// injectPendingFirstPacket replays the captured handshake through the proxy if present, else
// directly through the ICE conn. The packet is cleared only after a successful write, so a failed
// or transport-less attempt leaves it available for a later reinjection. Caller must hold conn.mu.
func (conn *Conn) injectPendingFirstPacket(proxy wgproxy.Proxy, directConn net.Conn) {
pkt := conn.pendingFirstPacket
if len(pkt) == 0 {
return
}
switch {
case proxy != nil:
if err := proxy.InjectPacket(pkt); err != nil {
conn.Log.Debugf("failed to reinject captured first packet via proxy: %v", err)
return
}
case directConn != nil:
if _, err := directConn.Write(pkt); err != nil {
conn.Log.Debugf("failed to reinject captured first packet via direct conn: %v", err)
return
}
default:
conn.Log.Debugf("no transport available to reinject captured first packet")
return
}
conn.pendingFirstPacket = nil
conn.Log.Debugf("reinjected captured first packet (%d bytes)", len(pkt))
}
// NewConn creates a new not opened Conn to the remote peer.
@@ -172,6 +206,16 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used.
func (conn *Conn) Open(engineCtx context.Context) error {
return conn.open(engineCtx, nil)
}
// OpenWithFirstPacket opens the connection like Open and stashes firstPacket to be replayed once
// the real transport is established. The packet is retained only on a successful open.
func (conn *Conn) OpenWithFirstPacket(engineCtx context.Context, firstPacket []byte) error {
return conn.open(engineCtx, firstPacket)
}
func (conn *Conn) open(engineCtx context.Context, firstPacket []byte) error {
conn.mu.Lock()
defer conn.mu.Unlock()
@@ -227,6 +271,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
defer conn.wg.Done()
conn.guard.Start(conn.ctx, conn.onGuardEvent)
}()
if len(firstPacket) > 0 {
conn.pendingFirstPacket = slices.Clone(firstPacket)
}
conn.opened = true
return nil
}
@@ -423,6 +470,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep)
}
conn.injectPendingFirstPacket(wgProxy, iceConnInfo.RemoteConn)
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo, updateTime)
@@ -546,6 +595,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround()
conn.injectPendingFirstPacket(wgProxy, nil)
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay
conn.statusRelay.SetConnected()

View File

@@ -1241,15 +1241,6 @@ func (d *Status) SubscribeToEvents() *EventSubscription {
}
}
// HasEventSubscribers reports whether any client is currently subscribed
// to the daemon's SystemEvent stream. Used by the VNC approval broker to
// fail closed when no UI is connected to prompt the user.
func (d *Status) HasEventSubscribers() bool {
d.eventMux.Lock()
defer d.eventMux.Unlock()
return len(d.eventStreams) > 0
}
// UnsubscribeFromEvents removes an event subscription
func (d *Status) UnsubscribeFromEvents(sub *EventSubscription) {
if sub == nil {

View File

@@ -88,11 +88,24 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
if !ok {
return
}
// this can be blocked because of the connect open limiter semaphore
if err := p.Open(ctx); err != nil {
p.Log.Errorf("failed to open peer connection: %v", err)
}
}
// PeerConnOpenWithFirstPacket opens the peer connection and stashes a first packet to be
// reinjected once the real transport is established.
func (s *Store) PeerConnOpenWithFirstPacket(ctx context.Context, pubKey string, firstPacket []byte) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return
}
if err := p.OpenWithFirstPacket(ctx, firstPacket); err != nil {
p.Log.Errorf("failed to open peer connection: %v", err)
}
}
func (s *Store) PeerConnIdle(pubKey string) {

View File

@@ -70,8 +70,6 @@ type ConfigInput struct {
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
ServerVNCAllowed *bool
DisableVNCApproval *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -127,8 +125,6 @@ type Config struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
ServerVNCAllowed *bool
DisableVNCApproval *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -458,33 +454,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.ServerVNCAllowed != nil {
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
if *input.ServerVNCAllowed {
log.Infof("enabling VNC server")
} else {
log.Infof("disabling VNC server")
}
config.ServerVNCAllowed = input.ServerVNCAllowed
updated = true
}
} else if config.ServerVNCAllowed == nil {
config.ServerVNCAllowed = util.False()
updated = true
}
if input.DisableVNCApproval != nil {
if config.DisableVNCApproval == nil || *input.DisableVNCApproval != *config.DisableVNCApproval {
if *input.DisableVNCApproval {
log.Infof("disabling VNC connection approval prompt")
} else {
log.Infof("enabling VNC connection approval prompt")
}
config.DisableVNCApproval = input.DisableVNCApproval
updated = true
}
}
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")

View File

@@ -226,12 +226,11 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
return
}
// All query types for an intercepted domain are forwarded to the peer's
// DNS forwarder, which owns the name. Falling through to the system
// resolver would let it answer NXDOMAIN for a name it isn't authoritative
// for, poisoning the whole name (including the A/AAAA records the route
// does serve). The forwarder answers NODATA for types it cannot resolve.
d.mu.RLock()
peerKey := d.currentPeerKey
d.mu.RUnlock()
@@ -293,19 +292,6 @@ func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger
}
}
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed writing DNS continue response: %v", err)
}
}
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
if !exists {

View File

@@ -74,14 +74,6 @@ func New(filePath string) *Manager {
}
}
// FilePath returns the path of the underlying state file.
func (m *Manager) FilePath() string {
if m == nil {
return ""
}
return m.filePath
}
// Start starts the state manager periodic save routine
func (m *Manager) Start() {
if m == nil {

File diff suppressed because it is too large Load Diff

View File

@@ -121,14 +121,6 @@ service DaemonService {
// ExposeService exposes a local port via the NetBird reverse proxy
rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {}
// RespondApproval delivers the user's accept/deny decision for a
// pending user-approval prompt. The daemon pushes the prompt as a
// SystemEvent with category APPROVAL and metadata key "request_id";
// the UI calls this RPC with the same request_id to unblock whichever
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
// the UI which subsystem the prompt belongs to.
rpc RespondApproval(RespondApprovalRequest) returns (RespondApprovalResponse) {}
}
@@ -215,10 +207,6 @@ message LoginRequest {
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
optional bool disable_ipv6 = 40;
optional bool serverVNCAllowed = 41;
optional bool disableVNCApproval = 42;
}
message LoginResponse {
@@ -329,16 +317,12 @@ message GetConfigResponse {
bool disable_ipv6 = 27;
bool serverVNCAllowed = 28;
bool disableVNCApproval = 29;
// mDMManagedFields lists the names of configuration keys whose value is
// currently enforced by an MDM policy. Names match mdm.Key* constants
// (e.g. "managementURL", "disableClientRoutes"). UI/CLI clients should
// render the corresponding inputs as read-only and display a "managed
// by MDM" indicator.
repeated string mDMManagedFields = 30;
repeated string mDMManagedFields = 28;
}
// PeerState contains the latest state of a peer
@@ -423,25 +407,6 @@ message SSHServerState {
repeated SSHSessionInfo sessions = 2;
}
// VNCSessionInfo contains information about an active VNC session
message VNCSessionInfo {
string remoteAddress = 1;
string mode = 2;
string username = 3;
// userID is the Noise-verified session identity (hashed user ID from
// the ACL session-key entry), empty when auth is disabled.
string userID = 4;
// initiator is the human-readable display name of the dashboard user
// who minted the SessionPubKey, when known.
string initiator = 5;
}
// VNCServerState contains the latest state of the VNC server
message VNCServerState {
bool enabled = 1;
repeated VNCSessionInfo sessions = 2;
}
// FullStatus contains the full state held by the Status instance
message FullStatus {
ManagementState managementState = 1;
@@ -456,7 +421,6 @@ message FullStatus {
bool lazyConnectionEnabled = 9;
SSHServerState sshServerState = 10;
VNCServerState vncServerState = 11;
}
// Networks
@@ -645,7 +609,6 @@ message SystemEvent {
AUTHENTICATION = 2;
CONNECTIVITY = 3;
SYSTEM = 4;
APPROVAL = 5;
}
string id = 1;
@@ -736,10 +699,6 @@ message SetConfigRequest {
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
optional bool disable_ipv6 = 35;
optional bool serverVNCAllowed = 36;
optional bool disableVNCApproval = 37;
}
message SetConfigResponse{}
@@ -970,18 +929,3 @@ message StartBundleCaptureRequest {
message StartBundleCaptureResponse {}
message StopBundleCaptureRequest {}
message StopBundleCaptureResponse {}
message RespondApprovalRequest {
// request_id matches the SystemEvent metadata key emitted by the daemon
// when a subsystem awaits user approval for an inbound connection.
string request_id = 1;
// accept is true if the user approved the request, false if they
// denied it. A missing or unknown request_id is treated as a no-op.
bool accept = 2;
// view_only signals that the user granted the connection but withheld
// input control. Only meaningful when accept is true; ignored when
// accept is false.
bool view_only = 3;
}
message RespondApprovalResponse {}

View File

@@ -59,7 +59,6 @@ const (
DaemonService_StopCPUProfile_FullMethodName = "/daemon.DaemonService/StopCPUProfile"
DaemonService_GetInstallerResult_FullMethodName = "/daemon.DaemonService/GetInstallerResult"
DaemonService_ExposeService_FullMethodName = "/daemon.DaemonService/ExposeService"
DaemonService_RespondApproval_FullMethodName = "/daemon.DaemonService/RespondApproval"
)
// DaemonServiceClient is the client API for DaemonService service.
@@ -137,13 +136,6 @@ type DaemonServiceClient interface {
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExposeServiceEvent], error)
// RespondApproval delivers the user's accept/deny decision for a
// pending user-approval prompt. The daemon pushes the prompt as a
// SystemEvent with category APPROVAL and metadata key "request_id";
// the UI calls this RPC with the same request_id to unblock whichever
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
// the UI which subsystem the prompt belongs to.
RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error)
}
type daemonServiceClient struct {
@@ -581,16 +573,6 @@ func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServi
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DaemonService_ExposeServiceClient = grpc.ServerStreamingClient[ExposeServiceEvent]
func (c *daemonServiceClient) RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RespondApprovalResponse)
err := c.cc.Invoke(ctx, DaemonService_RespondApproval_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility.
@@ -666,13 +648,6 @@ type DaemonServiceServer interface {
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error
// RespondApproval delivers the user's accept/deny decision for a
// pending user-approval prompt. The daemon pushes the prompt as a
// SystemEvent with category APPROVAL and metadata key "request_id";
// the UI calls this RPC with the same request_id to unblock whichever
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
// the UI which subsystem the prompt belongs to.
RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -803,9 +778,6 @@ func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *Ins
func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error {
return status.Error(codes.Unimplemented, "method ExposeService not implemented")
}
func (UnimplementedDaemonServiceServer) RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error) {
return nil, status.Error(codes.Unimplemented, "method RespondApproval not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {}
@@ -1526,24 +1498,6 @@ func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStr
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DaemonService_ExposeServiceServer = grpc.ServerStreamingServer[ExposeServiceEvent]
func _DaemonService_RespondApproval_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RespondApprovalRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).RespondApproval(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: DaemonService_RespondApproval_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).RespondApproval(ctx, req.(*RespondApprovalRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1699,10 +1653,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetInstallerResult",
Handler: _DaemonService_GetInstallerResult_Handler,
},
{
MethodName: "RespondApproval",
Handler: _DaemonService_RespondApproval_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -111,7 +111,7 @@ func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.Daemo
return status.Errorf(codes.Internal, "create capture session: %v", err)
}
engine, err := s.claimCapture(sess, func() { pw.Close() })
engine, err := s.claimCapture(sess)
if err != nil {
sess.Stop()
pw.Close()
@@ -190,7 +190,10 @@ func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCap
s.stopBundleCaptureLocked()
s.cleanupBundleCapture()
s.evictActiveCaptureLocked()
if s.activeCapture != nil {
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
}
engine, err := s.getCaptureEngineLocked()
if err != nil {
@@ -301,58 +304,29 @@ func (s *Server) cleanupBundleCapture() {
s.bundleCapture = nil
}
// claimCapture reserves the engine's capture slot for sess. If another
// capture is already running it is evicted: a previous streaming session
// whose gRPC client died and never freed the slot stays stuck otherwise,
// and a bundle capture is just informational state.
func (s *Server) claimCapture(sess *capture.Session, cancel func()) (*internal.Engine, error) {
// claimCapture reserves the engine's capture slot for sess. Returns
// FailedPrecondition if another capture is already active.
func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.evictActiveCaptureLocked()
if s.activeCapture != nil {
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
}
engine, err := s.getCaptureEngineLocked()
if err != nil {
return nil, err
}
s.activeCapture = sess
s.activeCaptureCancel = cancel
return engine, nil
}
// evictActiveCaptureLocked tears down whatever capture currently owns
// the engine slot so a fresh claim can succeed. Caller must hold mutex.
func (s *Server) evictActiveCaptureLocked() {
if s.activeCapture == nil {
return
}
if s.bundleCapture != nil && s.bundleCapture.sess == s.activeCapture {
log.Infof("evicting running bundle capture to start a new capture")
s.stopBundleCaptureLocked()
return
}
log.Infof("evicting previous streaming capture to start a new one")
prev := s.activeCapture
cancel := s.activeCaptureCancel
if engine, err := s.getCaptureEngineLocked(); err == nil {
if err := engine.SetCapture(nil); err != nil {
log.Debugf("clear previous capture: %v", err)
}
}
s.activeCapture = nil
s.activeCaptureCancel = nil
prev.Stop()
if cancel != nil {
cancel()
}
}
// releaseCapture clears the active-capture owner if it still matches sess.
func (s *Server) releaseCapture(sess *capture.Session) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.activeCapture == sess {
s.activeCapture = nil
s.activeCaptureCancel = nil
}
}
@@ -367,7 +341,6 @@ func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Eng
log.Debugf("clear capture: %v", err)
}
s.activeCapture = nil
s.activeCaptureCancel = nil
}
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {

View File

@@ -100,12 +100,8 @@ type Server struct {
captureEnabled bool
bundleCapture *bundleCapture
// activeCapture is the session currently installed on the engine; guarded by s.mutex.
activeCapture *capture.Session
// activeCaptureCancel tears down the streaming pipe/cancel for the
// active streaming capture so eviction unblocks the StartCapture RPC
// handler. Nil for bundle captures (they own their own context).
activeCaptureCancel func()
networksDisabled bool
activeCapture *capture.Session
networksDisabled bool
sleepHandler *sleephandler.SleepHandler
@@ -460,8 +456,6 @@ func (s *Server) setConfigInputFromRequest(msg *proto.SetConfigRequest) (profile
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect
config.ServerSSHAllowed = msg.ServerSSHAllowed
config.ServerVNCAllowed = msg.ServerVNCAllowed
config.DisableVNCApproval = msg.DisableVNCApproval
config.NetworkMonitor = msg.NetworkMonitor
config.DisableClientRoutes = msg.DisableClientRoutes
config.DisableServerRoutes = msg.DisableServerRoutes
@@ -1257,7 +1251,6 @@ func (s *Server) Status(
pbFullStatus := fullStatus.ToProto()
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
pbFullStatus.VncServerState = s.getVNCServerState()
statusResponse.FullStatus = pbFullStatus
}
@@ -1297,38 +1290,6 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
return sshServerState
}
// getVNCServerState retrieves the current VNC server state.
func (s *Server) getVNCServerState() *proto.VNCServerState {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
}
enabled, sessions := engine.GetVNCServerStatus()
pbSessions := make([]*proto.VNCSessionInfo, 0, len(sessions))
for _, sess := range sessions {
pbSessions = append(pbSessions, &proto.VNCSessionInfo{
RemoteAddress: sess.RemoteAddress,
Mode: sess.Mode,
Username: sess.Username,
UserID: sess.UserID,
Initiator: sess.Initiator,
})
}
return &proto.VNCServerState{
Enabled: enabled,
Sessions: pbSessions,
}
}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
func (s *Server) GetPeerSSHHostKey(
ctx context.Context,
@@ -1569,27 +1530,6 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
return nil
}
// RespondApproval relays the user's accept/deny decision for a pending
// approval prompt to the engine's broker. Unknown or already-resolved
// request_ids are silently no-op'd so a slow UI cannot deny a prompt the
// user already handled (or that already timed out).
func (s *Server) RespondApproval(_ context.Context, msg *proto.RespondApprovalRequest) (*proto.RespondApprovalResponse, error) {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "engine not running")
}
if !engine.RespondApproval(msg.GetRequestId(), msg.GetAccept(), msg.GetViewOnly()) {
log.Debugf("approval response for unknown request_id %s", msg.GetRequestId())
}
return &proto.RespondApprovalResponse{}, nil
}
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
@@ -1705,8 +1645,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
DisableVNCApproval: cfg.DisableVNCApproval != nil && *cfg.DisableVNCApproval,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,

View File

@@ -58,8 +58,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
serverVNCAllowed := true
disableVNCApproval := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
@@ -85,8 +83,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
ServerVNCAllowed: &serverVNCAllowed,
DisableVNCApproval: &disableVNCApproval,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
@@ -131,10 +127,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.NotNil(t, cfg.ServerVNCAllowed)
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
require.NotNil(t, cfg.DisableVNCApproval)
require.Equal(t, disableVNCApproval, *cfg.DisableVNCApproval)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
@@ -187,8 +179,6 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"ServerVNCAllowed": true,
"DisableVNCApproval": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
@@ -250,8 +240,6 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"allow-server-vnc": "ServerVNCAllowed",
"disable-vnc-approval": "DisableVNCApproval",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",

View File

@@ -1,4 +1,4 @@
package sessionauth
package auth
import (
"errors"
@@ -15,8 +15,6 @@ const (
DefaultUserIDClaim = "sub"
// Wildcard is a special user ID that matches all users
Wildcard = "*"
// sessionPubKeyLen is the size of an X25519 static public key in bytes.
sessionPubKeyLen = 32
)
var (
@@ -24,7 +22,6 @@ var (
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
ErrSessionKeyNotKnown = errors.New("session pubkey not registered")
)
// Authorizer handles SSH fine-grained access control authorization
@@ -38,17 +35,6 @@ type Authorizer struct {
// machineUsers maps OS login usernames to lists of authorized user indexes
machineUsers map[string][]uint32
// sessionPubKeys maps an X25519 static public key (as map-safe
// array) to the hashed user identity that key authenticates as.
// Populated from management's temporary-access flow; used by VNC to
// authenticate via the Noise_IK handshake.
sessionPubKeys map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash
// sessionDisplayNames mirrors sessionPubKeys with the optional
// human-readable display name management associated with each
// session key. Used by the per-connection UI approval prompt; not
// consulted by any authorization decision.
sessionDisplayNames map[[sessionPubKeyLen]byte]string
// mu protects the list of users
mu sync.RWMutex
}
@@ -64,29 +50,13 @@ type Config struct {
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
// If a user wants to login as a specific OS user, their index must be in the corresponding list
MachineUsers map[string][]uint32
// SessionPubKeys binds ephemeral X25519 static public keys to hashed
// user identities. Populated for VNC; ignored on the SSH side.
SessionPubKeys []SessionPubKey
}
// SessionPubKey is a single ephemeral-key entry: the 32-byte X25519
// static public key plus the hashed user identity it authenticates as,
// optionally plus a human-readable display name for the UI approval
// prompt to identify the requester.
type SessionPubKey struct {
PubKey []byte
UserIDHash sshuserhash.UserIDHash
DisplayName string
}
// NewAuthorizer creates a new SSH authorizer with empty configuration
func NewAuthorizer() *Authorizer {
a := &Authorizer{
userIDClaim: DefaultUserIDClaim,
machineUsers: make(map[string][]uint32),
sessionPubKeys: make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash),
sessionDisplayNames: make(map[[sessionPubKeyLen]byte]string),
userIDClaim: DefaultUserIDClaim,
machineUsers: make(map[string][]uint32),
}
return a
@@ -102,8 +72,6 @@ func (a *Authorizer) Update(config *Config) {
a.userIDClaim = DefaultUserIDClaim
a.authorizedUsers = []sshuserhash.UserIDHash{}
a.machineUsers = make(map[string][]uint32)
a.sessionPubKeys = make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash)
a.sessionDisplayNames = make(map[[sessionPubKeyLen]byte]string)
log.Info("SSH authorization cleared")
return
}
@@ -126,35 +94,8 @@ func (a *Authorizer) Update(config *Config) {
}
a.machineUsers = machineUsers
sessionPubKeys := make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash, len(config.SessionPubKeys))
sessionDisplayNames := make(map[[sessionPubKeyLen]byte]string, len(config.SessionPubKeys))
conflicted := make(map[[sessionPubKeyLen]byte]struct{})
for _, e := range config.SessionPubKeys {
if len(e.PubKey) != sessionPubKeyLen {
continue
}
var key [sessionPubKeyLen]byte
copy(key[:], e.PubKey)
if _, bad := conflicted[key]; bad {
continue
}
if existing, ok := sessionPubKeys[key]; ok && existing != e.UserIDHash {
log.Warnf("SSH auth: session pubkey bound to conflicting user hashes; dropping binding")
delete(sessionPubKeys, key)
delete(sessionDisplayNames, key)
conflicted[key] = struct{}{}
continue
}
sessionPubKeys[key] = e.UserIDHash
if e.DisplayName != "" {
sessionDisplayNames[key] = e.DisplayName
}
}
a.sessionPubKeys = sessionPubKeys
a.sessionDisplayNames = sessionDisplayNames
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings, %d session pubkeys",
len(config.AuthorizedUsers), len(machineUsers), len(sessionPubKeys))
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
len(config.AuthorizedUsers), len(machineUsers))
}
// Authorize validates if a user is authorized to login as the specified OS user.
@@ -214,54 +155,6 @@ func (a *Authorizer) GetUserIDClaim() string {
return a.userIDClaim
}
// LookupSessionKey resolves a Noise-verified static public key to the
// hashed user identity registered with it. Fails closed when the key is
// unknown.
func (a *Authorizer) LookupSessionKey(pubKey []byte) (sshuserhash.UserIDHash, error) {
var zero sshuserhash.UserIDHash
if len(pubKey) != sessionPubKeyLen {
return zero, fmt.Errorf("session pubkey wrong length: %d", len(pubKey))
}
var key [sessionPubKeyLen]byte
copy(key[:], pubKey)
a.mu.RLock()
hash, ok := a.sessionPubKeys[key]
a.mu.RUnlock()
if !ok {
return zero, ErrSessionKeyNotKnown
}
return hash, nil
}
// LookupSessionDisplayName returns the human-readable display name
// management associated with a session pubkey, or empty string when none
// is recorded. Never returns an error: a missing/unknown key reports as
// "" and the caller falls back to other identifiers.
func (a *Authorizer) LookupSessionDisplayName(pubKey []byte) string {
if len(pubKey) != sessionPubKeyLen {
return ""
}
var key [sessionPubKeyLen]byte
copy(key[:], pubKey)
a.mu.RLock()
name := a.sessionDisplayNames[key]
a.mu.RUnlock()
return name
}
// AuthorizeOSUserBySessionKey resolves the OS-user mapping for a session
// key. Mirrors Authorize but skips the JWT-hash step since the key has
// already been verified and the user identity hash is in hand.
func (a *Authorizer) AuthorizeOSUserBySessionKey(userIDHash sshuserhash.UserIDHash, osUsername string) (string, error) {
a.mu.RLock()
defer a.mu.RUnlock()
userIndex, found := a.findUserIndex(userIDHash)
if !found {
return "", fmt.Errorf("session user (hash: %s) not in authorized list for OS user %q: %w", userIDHash, osUsername, ErrUserNotAuthorized)
}
return a.checkMachineUserMapping("session", osUsername, userIndex)
}
// findUserIndex finds the index of a hashed user ID in the authorized users list
// Returns the index and true if found, 0 and false if not found
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {

View File

@@ -1,7 +1,6 @@
package sessionauth
package auth
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
@@ -611,61 +610,3 @@ func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
}
func TestAuthorizer_LookupSessionKey_Valid(t *testing.T) {
pub := bytesRepeat(0x11, sessionPubKeyLen)
userHash, err := sshauth.HashUserID("alice")
require.NoError(t, err)
a := NewAuthorizer()
a.Update(&Config{
AuthorizedUsers: []sshauth.UserIDHash{userHash},
MachineUsers: map[string][]uint32{Wildcard: {0}},
SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}},
})
got, err := a.LookupSessionKey(pub)
require.NoError(t, err)
assert.Equal(t, userHash, got)
if _, err := a.AuthorizeOSUserBySessionKey(got, "alice"); err != nil {
t.Fatalf("AuthorizeOSUserBySessionKey: %v", err)
}
}
func TestAuthorizer_LookupSessionKey_UnknownPub(t *testing.T) {
a := NewAuthorizer()
a.Update(&Config{})
_, err := a.LookupSessionKey(bytesRepeat(0x22, sessionPubKeyLen))
require.ErrorIs(t, err, ErrSessionKeyNotKnown)
}
func TestAuthorizer_LookupSessionKey_WrongLength(t *testing.T) {
a := NewAuthorizer()
_, err := a.LookupSessionKey([]byte("short"))
require.Error(t, err)
}
func TestAuthorizer_LookupSessionKey_UpdateClears(t *testing.T) {
pub := bytesRepeat(0x33, sessionPubKeyLen)
userHash, err := sshauth.HashUserID("alice")
require.NoError(t, err)
a := NewAuthorizer()
a.Update(&Config{SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}}})
if _, err := a.LookupSessionKey(pub); err != nil {
t.Fatalf("setup lookup: %v", err)
}
a.Update(&Config{})
if _, err := a.LookupSessionKey(pub); !errors.Is(err, ErrSessionKeyNotKnown) {
t.Fatalf("expected ErrSessionKeyNotKnown, got %v", err)
}
}
func bytesRepeat(b byte, n int) []byte {
out := make([]byte, n)
for i := range out {
out[i] = b
}
return out
}

View File

@@ -26,10 +26,10 @@ import (
cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)

View File

@@ -23,11 +23,11 @@ import (
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)

View File

@@ -23,10 +23,10 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
"github.com/netbirdio/netbird/util/netrelay"
"github.com/netbirdio/netbird/version"
)
@@ -197,14 +197,6 @@ type Config struct {
// HostKey is the SSH server host key in PEM format
HostKeyPEM []byte
// NetstackNet, when non-nil, makes the SSH server listen via the
// supplied userspace network stack instead of an OS socket.
NetstackNet *netstack.Net
// NetworkValidation, when non-zero, restricts inbound connections to
// peers inside the NetBird overlay defined by this WireGuard address.
NetworkValidation wgaddr.Address
}
// SessionInfo contains information about an active SSH session
@@ -216,15 +208,12 @@ type SessionInfo struct {
PortForwards []string
}
// New creates an SSH server instance from the supplied Config. Fields are
// read once at construction; mutating Config afterwards has no effect.
// JWT == nil disables JWT authentication.
// New creates an SSH server instance with the provided host key and optional JWT configuration
// If jwtConfig is nil, JWT authentication is disabled
func New(config *Config) *Server {
s := &Server{
mu: sync.RWMutex{},
hostKeyPEM: config.HostKeyPEM,
netstackNet: config.NetstackNet,
wgAddress: config.NetworkValidation,
sessions: make(map[sessionKey]*sessionState),
pendingAuthJWT: make(map[authKey]string),
remoteForwardListeners: make(map[forwardKey]net.Listener),
@@ -445,6 +434,20 @@ func (s *Server) buildSessionInfo(state *sessionState) SessionInfo {
return info
}
// SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock()
defer s.mu.Unlock()
s.netstackNet = net
}
// SetNetworkValidation configures network-based connection filtering
func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.mu.Lock()
defer s.mu.Unlock()
s.wgAddress = addr
}
// UpdateSSHAuth updates the SSH fine-grained access control configuration
// This should be called when network map updates include new SSH auth configuration
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {

View File

@@ -132,19 +132,6 @@ type SSHServerStateOutput struct {
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
}
type VNCSessionOutput struct {
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
Mode string `json:"mode" yaml:"mode"`
Username string `json:"username,omitempty" yaml:"username,omitempty"`
UserID string `json:"userID,omitempty" yaml:"userID,omitempty"`
Initiator string `json:"initiator,omitempty" yaml:"initiator,omitempty"`
}
type VNCServerStateOutput struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Sessions []VNCSessionOutput `json:"sessions" yaml:"sessions"`
}
type OutputOverview struct {
Peers PeersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
@@ -168,7 +155,6 @@ type OutputOverview struct {
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
ProfileName string `json:"profileName" yaml:"profileName"`
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
}
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
@@ -189,7 +175,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
relayOverview := mapRelays(pbFullStatus.GetRelays())
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
vncServerOverview := mapVNCServer(pbFullStatus.GetVncServerState())
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
overview := OutputOverview{
@@ -215,7 +200,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
ProfileName: opts.ProfileName,
SSHServerState: sshServerOverview,
VNCServerState: vncServerOverview,
}
if opts.Anonymize {
@@ -297,26 +281,6 @@ func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
}
}
func mapVNCServer(state *proto.VNCServerState) VNCServerStateOutput {
if state == nil {
return VNCServerStateOutput{Sessions: []VNCSessionOutput{}}
}
sessions := make([]VNCSessionOutput, 0, len(state.GetSessions()))
for _, sess := range state.GetSessions() {
sessions = append(sessions, VNCSessionOutput{
RemoteAddress: sess.GetRemoteAddress(),
Mode: sess.GetMode(),
Username: sess.GetUsername(),
UserID: sess.GetUserID(),
Initiator: sess.GetInitiator(),
})
}
return VNCServerStateOutput{
Enabled: state.GetEnabled(),
Sessions: sessions,
}
}
func mapPeers(
peers []*proto.PeerState,
statusFilter string,
@@ -581,26 +545,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
}
}
vncServerStatus := "Disabled"
if o.VNCServerState.Enabled {
vncSessionCount := len(o.VNCServerState.Sessions)
if vncSessionCount > 0 {
sessionWord := "session"
if vncSessionCount > 1 {
sessionWord = "sessions"
}
vncServerStatus = fmt.Sprintf("Enabled (%d active %s)", vncSessionCount, sessionWord)
} else {
vncServerStatus = "Enabled"
}
if showSSHSessions && vncSessionCount > 0 {
for _, sess := range o.VNCServerState.Sessions {
vncServerStatus += "\n " + formatVNCSessionLine(sess)
}
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
var forwardingRulesString string
@@ -647,7 +591,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
"SSH Server: %s\n"+
"VNC Server: %s\n"+
"Networks: %s\n"+
"%s"+
"Peers count: %s\n",
@@ -667,7 +610,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,
vncServerStatus,
networks,
forwardingRulesString,
peersCountString,
@@ -1027,26 +969,6 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *PeerStateDetailOutput) {
}
}
// formatVNCSessionLine renders a single VNC session row for the detailed
// status output. The leading slot identifies the initiator (display name
// when known, hashed UserID otherwise); the post-arrow slot is the OS
// user the session targets and is omitted in attach mode where the
// destination is the current console user (unknown to the daemon).
func formatVNCSessionLine(sess VNCSessionOutput) string {
who := sess.Initiator
if who == "" {
who = sess.UserID
}
prefix := sess.RemoteAddress
if who != "" {
prefix = fmt.Sprintf("%s@%s", who, sess.RemoteAddress)
}
if sess.Username != "" {
return fmt.Sprintf("[%s -> %s] mode=%s", prefix, sess.Username, sess.Mode)
}
return fmt.Sprintf("[%s] mode=%s", prefix, sess.Mode)
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
@@ -1067,19 +989,6 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
overview.Relays.Details[i] = detail
}
anonymizeNSServerGroups(a, overview)
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
anonymizeEvents(a, overview)
anonymizeServerSessions(a, overview)
}
func anonymizeNSServerGroups(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
@@ -1091,9 +1000,13 @@ func anonymizeNSServerGroups(a *anonymize.Anonymizer, overview *OutputOverview)
}
}
}
}
func anonymizeEvents(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
for i, event := range overview.Events {
overview.Events[i].Message = a.AnonymizeString(event.Message)
overview.Events[i].UserMessage = a.AnonymizeString(event.UserMessage)
@@ -1102,24 +1015,13 @@ func anonymizeEvents(a *anonymize.Anonymizer, overview *OutputOverview) {
event.Metadata[k] = a.AnonymizeString(v)
}
}
}
func anonymizeRemoteAddress(a *anonymize.Anonymizer, addr string) string {
if host, port, err := net.SplitHostPort(addr); err == nil {
return fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
return a.AnonymizeIPString(addr)
}
func anonymizeServerSessions(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, session := range overview.SSHServerState.Sessions {
overview.SSHServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, session.RemoteAddress)
if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil {
overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
} else {
overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress)
}
overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
}
for i, sess := range overview.VNCServerState.Sessions {
overview.VNCServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, sess.RemoteAddress)
overview.VNCServerState.Sessions[i].Username = a.AnonymizeString(sess.Username)
overview.VNCServerState.Sessions[i].UserID = a.AnonymizeString(sess.UserID)
overview.VNCServerState.Sessions[i].Initiator = a.AnonymizeString(sess.Initiator)
}
}

View File

@@ -242,10 +242,6 @@ var overview = OutputOverview{
Enabled: false,
Sessions: []SSHSessionOutput{},
},
VNCServerState: VNCServerStateOutput{
Enabled: false,
Sessions: []VNCSessionOutput{},
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
@@ -411,10 +407,6 @@ func TestParsingToJSON(t *testing.T) {
"sshServer":{
"enabled":false,
"sessions":[]
},
"vncServer":{
"enabled":false,
"sessions":[]
}
}`
// @formatter:on
@@ -525,9 +517,6 @@ profileName: ""
sshServer:
enabled: false
sessions: []
vncServer:
enabled: false
sessions: []
`
assert.Equal(t, expectedYAML, yaml)
@@ -598,7 +587,6 @@ Wireguard port: %d
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion, overview.WgPort)
@@ -625,7 +613,6 @@ Wireguard port: 51820
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`

View File

@@ -2,9 +2,11 @@ package system
import (
"context"
"errors"
"net/netip"
"slices"
"strings"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/metadata"
@@ -63,7 +65,6 @@ type Info struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
DisableClientRoutes bool
DisableServerRoutes bool
@@ -85,7 +86,6 @@ type Info struct {
func (i *Info) SetFlags(
rosenpassEnabled, rosenpassPermissive bool,
serverSSHAllowed *bool,
serverVNCAllowed *bool,
disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool,
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
@@ -96,9 +96,6 @@ func (i *Info) SetFlags(
if serverSSHAllowed != nil {
i.ServerSSHAllowed = *serverSSHAllowed
}
if serverVNCAllowed != nil {
i.ServerVNCAllowed = *serverVNCAllowed
}
i.DisableClientRoutes = disableClientRoutes
i.DisableServerRoutes = disableServerRoutes
@@ -179,7 +176,7 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs .
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
}
files, err := checkFileAndProcess(processCheckPaths)
files, err := checkFileAndProcess(ctx, processCheckPaths)
if err != nil {
return nil, err
}
@@ -192,3 +189,43 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs .
log.Debugf("all system information gathered successfully")
return info, nil
}
// GetInfoWithChecksTimeout is GetInfoWithChecks bounded by timeout. Posture-check gathering
// runs uncancellable system calls (process enumeration, os.Stat), so calling it inline can
// block the caller for as long as such a call hangs. It runs in a goroutine instead: if it
// does not return within timeout the caller gets (nil, false) and should proceed with
// degraded behavior rather than block. On a gathering error it falls back to base GetInfo.
//
// The buffered channel lets the abandoned goroutine finish and exit once its blocking call
// returns, so it does not leak beyond the duration of that call.
func GetInfoWithChecksTimeout(ctx context.Context, timeout time.Duration, checks []*proto.Checks, excludeIPs ...netip.Addr) (*Info, bool) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
infoCh := make(chan *Info, 1)
go func() {
info, err := GetInfoWithChecks(ctx, checks, excludeIPs...)
if err != nil {
if ctx.Err() != nil {
return
}
log.Warnf("failed to get system info with checks: %v", err)
info = GetInfo(ctx)
info.removeAddresses(excludeIPs...)
}
infoCh <- info
}()
select {
case info := <-infoCh:
return info, true
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
log.Warnf("gathering system info with checks timed out after %s", timeout)
} else {
// Parent context canceled (e.g. shutdown), not a timeout.
log.Warnf("gathering system info with checks canceled: %v", ctx.Err())
}
return nil, false
}
}

View File

@@ -50,7 +50,7 @@ func GetInfo(ctx context.Context) *Info {
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
return []File{}, nil
}

View File

@@ -32,7 +32,7 @@ func GetInfo(ctx context.Context) *Info {
sysName := string(bytes.Split(utsname.Sysname[:], []byte{0})[0])
machine := string(bytes.Split(utsname.Machine[:], []byte{0})[0])
release := string(bytes.Split(utsname.Release[:], []byte{0})[0])
swVersion, err := exec.Command("sw_vers", "-productVersion").Output()
swVersion, err := exec.CommandContext(ctx, "sw_vers", "-productVersion").Output()
if err != nil {
log.Warnf("got an error while retrieving macOS version with sw_vers, error: %s. Using darwin version instead.\n", err)
swVersion = []byte(release)

View File

@@ -105,7 +105,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
return []File{}, nil
}

View File

@@ -103,7 +103,7 @@ func collectLocationInfo(info *Info) {
}
}
func checkFileAndProcess(_ []string) ([]File, error) {
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
return []File{}, nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
@@ -35,6 +36,20 @@ func Test_CustomHostname(t *testing.T) {
assert.Equal(t, want, got.Hostname)
}
func TestGetInfoWithChecksTimeout_Success(t *testing.T) {
info, ok := GetInfoWithChecksTimeout(context.Background(), 30*time.Second, nil)
assert.True(t, ok, "expected gathering to complete within the timeout")
assert.NotNil(t, info)
}
func TestGetInfoWithChecksTimeout_Timeout(t *testing.T) {
// A 1ns budget expires before the (real) system-info gathering can finish, so the
// caller must get (nil, false) instead of blocking on the in-flight goroutine.
info, ok := GetInfoWithChecksTimeout(context.Background(), time.Nanosecond, nil)
assert.False(t, ok, "expected timeout to be reported")
assert.Nil(t, info)
}
func Test_NetAddresses(t *testing.T) {
addr, err := networkAddresses()
if err != nil {

View File

@@ -3,24 +3,30 @@
package system
import (
"context"
"os"
"slices"
"github.com/shirou/gopsutil/v3/process"
)
// getRunningProcesses returns a list of running process paths.
func getRunningProcesses() ([]string, error) {
processIDs, err := process.Pids()
// getRunningProcesses returns a list of running process paths. The context bounds the work:
// the per-PID loop bails as soon as ctx is done, and the gopsutil calls honor it where they
// can, so a stuck enumeration cannot run unbounded.
func getRunningProcesses(ctx context.Context) ([]string, error) {
processIDs, err := process.PidsWithContext(ctx)
if err != nil {
return nil, err
}
processMap := make(map[string]bool)
for _, pID := range processIDs {
if err := ctx.Err(); err != nil {
return nil, err
}
p := &process.Process{Pid: pID}
path, _ := p.Exe()
path, _ := p.ExeWithContext(ctx)
if path != "" {
processMap[path] = false
}
@@ -35,18 +41,21 @@ func getRunningProcesses() ([]string, error) {
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
func checkFileAndProcess(ctx context.Context, paths []string) ([]File, error) {
files := make([]File, len(paths))
if len(paths) == 0 {
return files, nil
}
runningProcesses, err := getRunningProcesses()
runningProcesses, err := getRunningProcesses(ctx)
if err != nil {
return nil, err
}
for i, path := range paths {
if err := ctx.Err(); err != nil {
return nil, err
}
file := File{Path: path}
_, err := os.Stat(path)

View File

@@ -1,6 +1,7 @@
package system
import (
"context"
"testing"
"github.com/shirou/gopsutil/v3/process"
@@ -9,7 +10,7 @@ import (
func Benchmark_getRunningProcesses(b *testing.B) {
b.Run("getRunningProcesses new", func(b *testing.B) {
for i := 0; i < b.N; i++ {
ps, err := getRunningProcesses()
ps, err := getRunningProcesses(context.Background())
if err != nil {
b.Fatalf("unexpected error: %v", err)
}
@@ -29,12 +30,38 @@ func Benchmark_getRunningProcesses(b *testing.B) {
}
}
})
s, _ := getRunningProcesses()
s, _ := getRunningProcesses(context.Background())
b.Logf("getRunningProcesses returned %d processes", len(s))
s, _ = getRunningProcessesOld()
b.Logf("getRunningProcessesOld returned %d processes", len(s))
}
func TestCheckFileAndProcess_ContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
// With a canceled context and non-empty paths the gathering must bail with an error
// instead of running the (potentially blocking) process scan / stat loop.
if _, err := checkFileAndProcess(ctx, []string{"/does/not/exist"}); err == nil {
t.Fatal("expected error on canceled context, got nil")
}
}
func TestCheckFileAndProcess_EmptyPaths(t *testing.T) {
// No check paths means no work to do: it must return immediately with no error,
// even on a canceled context (nothing to scan or stat).
ctx, cancel := context.WithCancel(context.Background())
cancel()
files, err := checkFileAndProcess(ctx, nil)
if err != nil {
t.Fatalf("unexpected error for empty paths: %v", err)
}
if len(files) != 0 {
t.Fatalf("expected no files, got %d", len(files))
}
}
func getRunningProcessesOld() ([]string, error) {
processes, err := process.Processes()
if err != nil {

View File

@@ -1,259 +0,0 @@
//go:build !(linux && 386)
package main
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"strings"
"time"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/widget"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/approval"
"github.com/netbirdio/netbird/client/proto"
)
// Approval metadata that is remote-peer or dashboard controlled is passed to
// the forked netbird-ui via environment variables rather than argv, so it is
// not exposed to other local users through ps.
const (
envApprovalInitiator = "NB_APPROVAL_INITIATOR"
envApprovalPeerName = "NB_APPROVAL_PEER_NAME"
envApprovalSourceIP = "NB_APPROVAL_SOURCE_IP"
envApprovalUsername = "NB_APPROVAL_USERNAME"
envApprovalKeyFingerprint = "NB_APPROVAL_KEY_FINGERPRINT"
envApprovalSubject = "NB_APPROVAL_SUBJECT"
)
// handleApprovalEvent forks a netbird-ui child process to render the
// dialog on its own fyne main loop. Top-level windows opened from a
// background goroutine of the tray process don't render reliably on
// Linux/GTK, so the rest of the UI (settings, login URL, update) uses
// the same fork pattern.
func (s *serviceClient) handleApprovalEvent(ev *proto.SystemEvent) {
if ev == nil || ev.Category != proto.SystemEvent_APPROVAL {
return
}
requestID := ev.Metadata["request_id"]
if requestID == "" {
log.Warnf("approval event missing request_id: %v", ev.Metadata)
return
}
// Only the request id, kind, and deadline stay on argv: they are
// daemon-issued and non-sensitive. The remote-influenced fields go
// through the child's environment.
args := []string{
"--approval-request-id=" + requestID,
"--approval-kind=" + ev.Metadata["kind"],
"--approval-expires-at=" + ev.Metadata["expires_at"],
}
env := append(os.Environ(),
envApprovalInitiator+"="+ev.Metadata["initiator"],
envApprovalPeerName+"="+ev.Metadata["peer_name"],
envApprovalSourceIP+"="+ev.Metadata["source_ip"],
envApprovalUsername+"="+ev.Metadata["username"],
envApprovalKeyFingerprint+"="+ev.Metadata["peer_pubkey"],
envApprovalSubject+"="+ev.UserMessage,
)
go s.runApprovalCommand(s.ctx, env, args)
}
// runApprovalCommand forks netbird-ui to render the approval dialog,
// inheriting the parent environment plus the approval-specific variables. It
// mirrors runSelfCommand but sets cmd.Env so the sensitive metadata never
// appears on the child's argv.
func (s *serviceClient) runApprovalCommand(ctx context.Context, env, args []string) {
proc, err := os.Executable()
if err != nil {
log.Errorf("get executable path: %v", err)
return
}
cmdArgs := append([]string{"--approval=true", "--daemon-addr=" + s.addr}, args...)
cmd := exec.CommandContext(ctx, proc, cmdArgs...)
cmd.Env = env
if out := s.attachOutput(cmd); out != nil {
defer func() {
if err := out.Close(); err != nil {
log.Errorf("close log file %s: %v", s.logFile, err)
}
}()
}
log.Printf("running approval command: %s", cmd.String())
if err := cmd.Run(); err != nil {
var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
log.Printf("approval command failed with exit code %d", exitErr.ExitCode())
}
}
}
// showApprovalUI runs the dialog on the forked process's fyne main loop
// and forwards the user's decision to the daemon via RespondApproval.
func (s *serviceClient) showApprovalUI(req approvalRequest) {
w := s.app.NewWindow(approvalTitle(req.kind))
w.Resize(fyne.NewSize(480, 260))
w.CenterOnScreen()
w.RequestFocus()
var rows []string
if req.initiator != "" {
// The display name comes from the management dashboard and is
// not cryptographically asserted by the connecting client. The
// key fingerprint that follows IS: it's the Noise_IK static
// public key the client just proved possession of. Show both
// so the user can sanity-check that "Alice" is really the
// Alice they trust.
rows = append(rows, "From user: "+req.initiator)
}
if fp := approval.ShortKeyFingerprint(req.keyFingerprint); fp != "" {
rows = append(rows, "Key fp: "+fp)
}
if req.peerName != "" {
rows = append(rows, "Via peer: "+req.peerName)
}
if req.sourceIP != "" && req.sourceIP != req.peerName {
rows = append(rows, "Source IP: "+req.sourceIP)
}
if req.username != "" {
rows = append(rows, "OS user: "+req.username)
}
if len(rows) == 0 {
rows = []string{"Remote: " + req.displayPeer()}
}
body := strings.Join(rows, "\n")
bodyLabel := widget.NewLabel(body)
bodyLabel.Wrapping = fyne.TextWrapWord
countdown := widget.NewLabel("")
deadline := req.deadline()
updateCountdown := func() {
remaining := time.Until(deadline).Round(time.Second)
if remaining < 0 {
remaining = 0
}
countdown.SetText(fmt.Sprintf("Auto-deny in %s", remaining))
}
updateCountdown()
type outcome struct {
accept bool
viewOnly bool
}
decided := make(chan outcome, 1)
decide := func(o outcome) {
select {
case decided <- o:
default:
}
}
allow := widget.NewButton("Allow", func() { decide(outcome{accept: true}) })
allow.Importance = widget.HighImportance
allowView := widget.NewButton("Allow (view only)", func() { decide(outcome{accept: true, viewOnly: true}) })
deny := widget.NewButton("Deny", func() { decide(outcome{accept: false}) })
header := widget.NewLabelWithStyle(req.subject, fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
buttonRow := container.NewGridWithColumns(3, allow, allowView, deny)
info := container.NewVBox(header, widget.NewSeparator(), bodyLabel, widget.NewSeparator(), countdown)
w.SetContent(container.NewPadded(container.NewBorder(nil, buttonRow, nil, nil, info)))
w.SetCloseIntercept(func() { decide(outcome{}) })
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for range ticker.C {
if time.Until(deadline) <= 0 {
decide(outcome{})
return
}
fyne.Do(updateCountdown)
}
}()
go func() {
o := <-decided
s.sendApprovalResponse(req.requestID, o.accept, o.viewOnly)
fyne.Do(func() {
w.Close()
s.app.Quit()
})
}()
w.Show()
}
func (s *serviceClient) sendApprovalResponse(requestID string, accept, viewOnly bool) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Warnf("approval response: get daemon client: %v", err)
return
}
ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout)
defer cancel()
if _, err := conn.RespondApproval(ctx, &proto.RespondApprovalRequest{
RequestId: requestID,
Accept: accept,
ViewOnly: viewOnly,
}); err != nil {
log.Warnf("approval response: %v", err)
}
}
// approvalRequest is the parsed --approval-* CLI args that the forked
// dialog process consumes.
type approvalRequest struct {
requestID string
kind string
initiator string
peerName string
sourceIP string
username string
subject string
expiresAt string
keyFingerprint string
}
func (r approvalRequest) displayPeer() string {
switch {
case r.initiator != "":
return r.initiator
case r.peerName != "":
return r.peerName
case r.sourceIP != "":
return r.sourceIP
default:
return "unknown peer"
}
}
// deadline returns the wall-clock auto-deny moment. Falls back to a short
// local window when the daemon's expires_at is missing/unparsable, so a
// stale value never leaves the dialog open indefinitely.
func (r approvalRequest) deadline() time.Time {
if t, err := time.Parse(time.RFC3339, r.expiresAt); err == nil {
return t
}
return time.Now().Add(13 * time.Second)
}
func approvalTitle(kind string) string {
switch kind {
case "vnc":
return "Allow VNC Connection?"
case "ssh":
return "Allow SSH Connection?"
default:
return "Allow Incoming Connection?"
}
}

View File

@@ -112,25 +112,13 @@ func main() {
showQuickActions: flags.showQuickActions,
showUpdate: flags.showUpdate,
showUpdateVersion: flags.showUpdateVersion,
showApproval: flags.showApproval,
approvalRequest: approvalRequest{
requestID: flags.approvalRequestID,
kind: flags.approvalKind,
initiator: os.Getenv(envApprovalInitiator),
peerName: os.Getenv(envApprovalPeerName),
sourceIP: os.Getenv(envApprovalSourceIP),
username: os.Getenv(envApprovalUsername),
subject: os.Getenv(envApprovalSubject),
expiresAt: flags.approvalExpiresAt,
keyFingerprint: os.Getenv(envApprovalKeyFingerprint),
},
})
// Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set.
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate || flags.showApproval {
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate {
a.Run()
return
}
@@ -167,11 +155,6 @@ type cliFlags struct {
saveLogsInFile bool
showUpdate bool
showUpdateVersion string
showApproval bool
approvalRequestID string
approvalKind string
approvalExpiresAt string
}
// parseFlags reads and returns all needed command-line flags.
@@ -193,10 +176,6 @@ func parseFlags() *cliFlags {
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window")
flag.StringVar(&flags.showUpdateVersion, "update-version", "", "version to update to")
flag.BoolVar(&flags.showApproval, "approval", false, "show inbound-connection approval prompt window")
flag.StringVar(&flags.approvalRequestID, "approval-request-id", "", "approval prompt: daemon-issued request id")
flag.StringVar(&flags.approvalKind, "approval-kind", "", "approval prompt: subsystem kind (vnc, ssh, ...)")
flag.StringVar(&flags.approvalExpiresAt, "approval-expires-at", "", "approval prompt: RFC3339 deadline at which the daemon auto-denies")
flag.Parse()
return &flags
}
@@ -285,7 +264,6 @@ type serviceClient struct {
mQuit *systray.MenuItem
mNetworks *systray.MenuItem
mAllowSSH *systray.MenuItem
mAllowVNC *systray.MenuItem
mAutoConnect *systray.MenuItem
mEnableRosenpass *systray.MenuItem
mLazyConnEnabled *systray.MenuItem
@@ -324,8 +302,6 @@ type serviceClient struct {
sEnableSSHRemotePortForward *widget.Check
sDisableSSHAuth *widget.Check
iSSHJWTCacheTTL *widget.Entry
sServerVNCAllowed *widget.Check
sDisableVNCApproval *widget.Check
// observable settings over corresponding iMngURL and iPreSharedKey values.
managementURL string
@@ -347,8 +323,6 @@ type serviceClient struct {
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
serverVNCAllowed bool
disableVNCApproval bool
connected bool
daemonVersion string
@@ -407,8 +381,6 @@ type newServiceClientArgs struct {
showQuickActions bool
showUpdate bool
showUpdateVersion string
showApproval bool
approvalRequest approvalRequest
}
// newServiceClient instance constructor
@@ -456,8 +428,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
}
case args.showUpdate:
s.showUpdateProgress(ctx, args.showUpdateVersion)
case args.showApproval:
s.showApprovalUI(args.approvalRequest)
}
return s
@@ -538,8 +508,6 @@ func (s *serviceClient) showSettingsUI() {
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
s.iSSHJWTCacheTTL = widget.NewEntry()
s.sServerVNCAllowed = widget.NewCheck("Allow embedded VNC server on this peer", nil)
s.sDisableVNCApproval = widget.NewCheck("Skip per-connection approval prompt for VNC", nil)
s.wSettings.SetContent(s.getSettingsForm())
s.wSettings.Resize(fyne.NewSize(600, 400))
@@ -652,8 +620,7 @@ func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool
s.disableServerRoutes != s.sDisableServerRoutes.Checked ||
s.disableIPv6 != s.sDisableIPv6.Checked ||
s.blockLANAccess != s.sBlockLANAccess.Checked ||
s.hasSSHChanges() ||
s.hasVNCChanges()
s.hasSSHChanges()
}
func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) error {
@@ -712,8 +679,6 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
req.EnableSSHLocalPortForwarding = &s.sEnableSSHLocalPortForward.Checked
req.EnableSSHRemotePortForwarding = &s.sEnableSSHRemotePortForward.Checked
req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
req.ServerVNCAllowed = &s.sServerVNCAllowed.Checked
req.DisableVNCApproval = &s.sDisableVNCApproval.Checked
sshJWTCacheTTLText := strings.TrimSpace(s.iSSHJWTCacheTTL.Text)
if sshJWTCacheTTLText != "" {
@@ -782,12 +747,10 @@ func (s *serviceClient) getSettingsForm() fyne.CanvasObject {
connectionForm := s.getConnectionForm()
networkForm := s.getNetworkForm()
sshForm := s.getSSHForm()
vncForm := s.getVNCForm()
tabs := container.NewAppTabs(
container.NewTabItem("Connection", connectionForm),
container.NewTabItem("Network", networkForm),
container.NewTabItem("SSH", sshForm),
container.NewTabItem("VNC", vncForm),
)
saveButton := widget.NewButtonWithIcon("Save", theme.ConfirmIcon(), s.saveSettings)
saveButton.Importance = widget.HighImportance
@@ -828,15 +791,6 @@ func (s *serviceClient) getSSHForm() *widget.Form {
}
}
func (s *serviceClient) getVNCForm() *widget.Form {
return &widget.Form{
Items: []*widget.FormItem{
{Text: "Allow VNC Server", Widget: s.sServerVNCAllowed},
{Text: "Disable Connection Approval Prompt", Widget: s.sDisableVNCApproval},
},
}
}
func (s *serviceClient) hasSSHChanges() bool {
currentSSHJWTCacheTTL := s.sshJWTCacheTTL
if text := strings.TrimSpace(s.iSSHJWTCacheTTL.Text); text != "" {
@@ -855,11 +809,6 @@ func (s *serviceClient) hasSSHChanges() bool {
s.sshJWTCacheTTL != currentSSHJWTCacheTTL
}
func (s *serviceClient) hasVNCChanges() bool {
return s.serverVNCAllowed != s.sServerVNCAllowed.Checked ||
s.disableVNCApproval != s.sDisableVNCApproval.Checked
}
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
@@ -1143,7 +1092,6 @@ func (s *serviceClient) onTrayReady() {
s.mSettings = systray.AddMenuItem("Settings", disabledMenuDescr)
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
s.mAllowVNC = s.mSettings.AddSubMenuItemCheckbox("Allow VNC", allowVNCMenuDescr, false)
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false)
@@ -1224,7 +1172,6 @@ func (s *serviceClient) onTrayReady() {
s.eventManager = event.NewManager(s.notifier, s.addr)
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
s.eventManager.AddHandler(s.handleApprovalEvent)
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
if event.Category == proto.SystemEvent_SYSTEM {
s.updateExitNodes()
@@ -1507,12 +1454,6 @@ func (s *serviceClient) getSrvConfig() {
if cfg.SSHJWTCacheTTL != nil {
s.sshJWTCacheTTL = *cfg.SSHJWTCacheTTL
}
if cfg.ServerVNCAllowed != nil {
s.serverVNCAllowed = *cfg.ServerVNCAllowed
}
if cfg.DisableVNCApproval != nil {
s.disableVNCApproval = *cfg.DisableVNCApproval
}
if s.showAdvancedSettings {
s.iMngURL.SetText(s.managementURL)
@@ -1568,12 +1509,6 @@ func (s *serviceClient) getSrvConfig() {
if cfg.SSHJWTCacheTTL != nil {
s.iSSHJWTCacheTTL.SetText(strconv.Itoa(*cfg.SSHJWTCacheTTL))
}
if cfg.ServerVNCAllowed != nil {
s.sServerVNCAllowed.SetChecked(*cfg.ServerVNCAllowed)
}
if cfg.DisableVNCApproval != nil {
s.sDisableVNCApproval.SetChecked(*cfg.DisableVNCApproval)
}
}
// MDM locks must run before the mNotifications-nil early return:
@@ -1640,8 +1575,6 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
config.DisableAutoConnect = cfg.DisableAutoConnect
config.ServerSSHAllowed = &cfg.ServerSSHAllowed
config.ServerVNCAllowed = &cfg.ServerVNCAllowed
config.DisableVNCApproval = &cfg.DisableVNCApproval
config.RosenpassEnabled = cfg.RosenpassEnabled
config.RosenpassPermissive = cfg.RosenpassPermissive
config.DisableNotifications = &cfg.DisableNotifications
@@ -1737,12 +1670,6 @@ func (s *serviceClient) loadSettings() {
s.mAllowSSH.Uncheck()
}
if cfg.ServerVNCAllowed {
s.mAllowVNC.Check()
} else {
s.mAllowVNC.Uncheck()
}
if cfg.DisableAutoConnect {
s.mAutoConnect.Uncheck()
} else {
@@ -1905,7 +1832,6 @@ func (s *serviceClient) applyMDMLocksToSettingsForm(set map[string]bool) {
func (s *serviceClient) updateConfig() error {
disableAutoStart := !s.mAutoConnect.Checked()
sshAllowed := s.mAllowSSH.Checked()
vncAllowed := s.mAllowVNC.Checked()
rosenpassEnabled := s.mEnableRosenpass.Checked()
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
blockInbound := s.mBlockInbound.Checked()
@@ -1934,7 +1860,6 @@ func (s *serviceClient) updateConfig() error {
Username: currUser.Username,
DisableAutoConnect: &disableAutoStart,
ServerSSHAllowed: &sshAllowed,
ServerVNCAllowed: &vncAllowed,
RosenpassEnabled: &rosenpassEnabled,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,

View File

@@ -2,7 +2,6 @@ package main
const (
allowSSHMenuDescr = "Allow SSH connections"
allowVNCMenuDescr = "Allow embedded VNC server"
autoConnectMenuDescr = "Connect automatically when the service starts"
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
lazyConnMenuDescr = "[Experimental] Enable lazy connections"

View File

@@ -112,7 +112,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) {
handlers := slices.Clone(e.handlers)
e.mu.Unlock()
if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) && !isV6DefaultRoutePartner(event) && event.Category != proto.SystemEvent_APPROVAL {
if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) && !isV6DefaultRoutePartner(event) {
title := e.getEventTitle(event)
body := event.UserMessage
id := event.Metadata["id"]

View File

@@ -39,8 +39,6 @@ func (h *eventHandler) listen(ctx context.Context) {
h.handleDisconnectClick()
case <-h.client.mAllowSSH.ClickedCh:
h.handleAllowSSHClick()
case <-h.client.mAllowVNC.ClickedCh:
h.handleAllowVNCClick()
case <-h.client.mAutoConnect.ClickedCh:
h.handleAutoConnectClick()
case <-h.client.mEnableRosenpass.ClickedCh:
@@ -136,15 +134,6 @@ func (h *eventHandler) handleAllowSSHClick() {
}
func (h *eventHandler) handleAllowVNCClick() {
h.toggleCheckbox(h.client.mAllowVNC)
if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mAllowVNC) // revert checkbox state on error
log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update VNC settings")
}
}
func (h *eventHandler) handleAutoConnectClick() {
h.toggleCheckbox(h.client.mAutoConnect)
if err := h.updateConfigWithErr(); err != nil {

View File

@@ -1,31 +0,0 @@
// Package vnc holds shared constants for the NetBird embedded VNC stack
// so non-server consumers (CLI capture, debug tooling) can refer to the
// well-known ports without depending on internal engine packages.
package vnc
// External and internal listen ports for the embedded VNC server.
// ExternalPort is what dashboard / browser clients see; the daemon
// DNATs it to InternalPort, where the in-process VNC server actually
// listens. Both flow over the WireGuard interface. AgentLegacyPort is
// the TCP port the per-session agent used before it switched to Unix
// sockets; kept here so packet captures from older builds still get
// tagged, and so any future on-wire agent variant has a reserved port.
const (
ExternalPort uint16 = 5900
InternalPort uint16 = 25900
AgentLegacyPort uint16 = 15900
)
// WellKnownPorts is the unordered set of ports a packet capture should
// treat as carrying NetBird VNC traffic.
var WellKnownPorts = [...]uint16{ExternalPort, InternalPort, AgentLegacyPort}
// IsWellKnownPort reports whether port matches any of WellKnownPorts.
func IsWellKnownPort(port uint16) bool {
for _, p := range WellKnownPorts {
if port == p {
return true
}
}
return false
}

View File

@@ -1,434 +0,0 @@
//go:build darwin && !ios
package server
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"os"
"os/exec"
"strconv"
"sync"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/configs"
)
// darwinAgentManager spawns a per-user VNC agent on demand and keeps it
// alive across multiple client connections within the same console-user
// session. A new agent is spawned the first time a client connects, or
// whenever the console user changes underneath us.
//
// Lifecycle is lazy by design: a daemon that never receives a VNC
// connection never spawns anything. The trade-off versus an eager spawn
// (the Windows model) is that the first VNC client pays the launchctl
// asuser + listen-readiness wait, ~hundreds of milliseconds in practice.
// That cost only repeats on user switch.
type darwinAgentManager struct {
mu sync.Mutex
authToken string
socketPath string
uid uint32
running bool
}
func newDarwinAgentManager(ctx context.Context) *darwinAgentManager {
m := &darwinAgentManager{}
go m.watchConsoleUser(ctx)
return m
}
// agentSocketName is the file name inside the per-uid socket directory
// the agent binds. The directory itself is created and chowned by the
// daemon (see prepareAgentSocketDir) so a non-root local user cannot
// pre-create or symlink the path before the agent listens.
const agentSocketName = "agent.sock"
// watchConsoleUser kills the cached agent whenever the console user
// changes (logout, fast user switch, login window). Without it the daemon
// keeps proxying to an agent whose TCC grant and WindowServer access
// belong to a user who is no longer at the screen, so the new user only
// ever sees the locked-screen wallpaper. Killing the agent breaks the
// loopback TCP that the daemon proxies into, the client disconnects, and
// the next reconnect runs ensure() against the new console uid.
func (m *darwinAgentManager) watchConsoleUser(ctx context.Context) {
t := time.NewTicker(2 * time.Second)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
uid, err := consoleUserID()
m.mu.Lock()
if !m.running {
m.mu.Unlock()
continue
}
if err != nil || uid != m.uid {
prev := m.uid
m.killLocked()
m.mu.Unlock()
if err != nil {
log.Infof("console user gone (was uid=%d): %v; agent stopped", prev, err)
} else {
log.Infof("console user changed %d -> %d; agent stopped, will respawn on next connect", prev, uid)
}
continue
}
m.mu.Unlock()
}
}
}
// Resolve spawns or respawns the per-user agent process as needed and
// returns its Unix-socket path, shared token, and the uid the agent was
// spawned under (so the daemon can validate peer credentials before
// dispatching the token). Each call is serialized so concurrent VNC
// clients share the same agent.
func (m *darwinAgentManager) Resolve(ctx context.Context) (string, string, uint32, error) {
consoleUID, err := consoleUserID()
if err != nil {
return "", "", 0, fmt.Errorf("no console user: %w", err)
}
m.mu.Lock()
defer m.mu.Unlock()
if m.running && m.uid == consoleUID && vncAgentRunning() {
return m.socketPath, m.authToken, m.uid, nil
}
m.killLocked()
// Reap stray agents so the new token is the only accepted one.
killAllVNCAgents()
socketDir, err := prepareAgentSocketDir(consoleUID)
if err != nil {
return "", "", 0, fmt.Errorf("prepare agent socket dir: %w", err)
}
socketPath := socketDir + "/" + agentSocketName
if err := os.Remove(socketPath); err != nil && !errors.Is(err, os.ErrNotExist) {
log.Debugf("clear stale agent socket %s: %v", socketPath, err)
}
token, err := generateAuthToken()
if err != nil {
return "", "", 0, fmt.Errorf("generate agent auth token: %w", err)
}
if err := spawnAgentForUser(consoleUID, socketPath, token); err != nil {
return "", "", 0, err
}
if err := waitForAgent(ctx, socketPath, 5*time.Second); err != nil {
killAllVNCAgents()
return "", "", 0, fmt.Errorf("agent did not start listening: %w", err)
}
m.authToken = token
m.socketPath = socketPath
m.uid = consoleUID
m.running = true
log.Infof("spawned VNC agent for console uid=%d on %s", consoleUID, socketPath)
return socketPath, token, consoleUID, nil
}
// prepareAgentSocketDir creates a per-uid subdirectory under the netbird
// runtime directory where the agent will bind its Unix socket. The leaf is
// owned by uid with mode 0700, so only the target user and root can write
// there. The parent is created root-owned with mode 0755 if missing.
// Symlinks at the per-uid level are refused (replaced with a fresh
// directory) so a low-priv user cannot redirect the chown that follows.
func prepareAgentSocketDir(uid uint32) (string, error) {
parent := configs.RuntimeDir
if err := ensureAgentSocketParent(parent); err != nil {
return "", err
}
subdir := fmt.Sprintf("%s/vnc-%d", parent, uid)
if err := purgeStaleAgentSubdir(subdir, uid); err != nil {
return "", err
}
if err := os.Mkdir(subdir, 0o700); err != nil && !errors.Is(err, os.ErrExist) {
return "", fmt.Errorf("mkdir %s: %w", subdir, err)
}
if err := os.Chmod(subdir, 0o700); err != nil {
return "", fmt.Errorf("chmod %s: %w", subdir, err)
}
if err := os.Chown(subdir, int(uid), -1); err != nil {
return "", fmt.Errorf("chown %s -> uid %d: %w", subdir, uid, err)
}
return subdir, nil
}
// ensureAgentSocketParent verifies the runtime parent dir exists, is not a
// symlink, and is owned by root.
func ensureAgentSocketParent(parent string) error {
if parent == "" {
return fmt.Errorf("no runtime directory configured for this platform")
}
if err := os.MkdirAll(parent, 0o755); err != nil {
return fmt.Errorf("mkdir %s: %w", parent, err)
}
info, err := os.Lstat(parent)
if err != nil {
return fmt.Errorf("lstat %s: %w", parent, err)
}
if info.Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("%s is a symlink", parent)
}
if st, ok := info.Sys().(*syscall.Stat_t); ok && st.Uid != 0 {
return fmt.Errorf("%s not owned by root (uid=%d)", parent, st.Uid)
}
return nil
}
// purgeStaleAgentSubdir removes a leftover subdir unless it is a real dir
// owned by uid with mode 0700. Lstat (not Stat) so a symlink is detected.
func purgeStaleAgentSubdir(subdir string, uid uint32) error {
info, err := os.Lstat(subdir)
if errors.Is(err, os.ErrNotExist) {
return nil
}
if err != nil {
return fmt.Errorf("lstat %s: %w", subdir, err)
}
if agentSubdirOK(info, uid) {
return nil
}
if err := os.RemoveAll(subdir); err != nil {
return fmt.Errorf("remove stale %s: %w", subdir, err)
}
return nil
}
func agentSubdirOK(info os.FileInfo, uid uint32) bool {
if info.Mode()&os.ModeSymlink != 0 || !info.IsDir() {
return false
}
st, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return false
}
return st.Uid == uid && info.Mode().Perm() == 0o700
}
// stop terminates the spawned agent, if any. Intended for daemon shutdown.
func (m *darwinAgentManager) stop() {
m.mu.Lock()
defer m.mu.Unlock()
m.killLocked()
}
func (m *darwinAgentManager) killLocked() {
if !m.running {
return
}
killAllVNCAgents()
if m.socketPath != "" {
if err := os.Remove(m.socketPath); err != nil && !errors.Is(err, os.ErrNotExist) {
log.Debugf("remove agent socket %s: %v", m.socketPath, err)
}
}
m.running = false
m.authToken = ""
m.socketPath = ""
m.uid = 0
}
// consoleUserID returns the uid of the user currently sitting at the
// console (the one whose Aqua session is active). Returns
// errNoConsoleUser when nobody is logged in: at the login window
// /dev/console is owned by root.
func consoleUserID() (uint32, error) {
info, err := os.Stat("/dev/console")
if err != nil {
return 0, fmt.Errorf("stat /dev/console: %w", err)
}
st, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return 0, fmt.Errorf("/dev/console stat has unexpected type")
}
if st.Uid == 0 {
return 0, errNoConsoleUser
}
return st.Uid, nil
}
// spawnAgentForUser uses launchctl asuser to start a netbird vnc-agent
// process inside the target user's launchd bootstrap namespace. That is
// the only spawn mode on macOS that gives the child access to the user's
// WindowServer. The agent's stderr is relogged into the daemon log so
// startup failures are not silently lost when the readiness check times
// out.
func spawnAgentForUser(uid uint32, socketPath, token string) error {
exe, err := os.Executable()
if err != nil {
return fmt.Errorf("resolve own executable: %w", err)
}
cmd := exec.Command(
"/bin/launchctl", "asuser", strconv.FormatUint(uint64(uid), 10),
exe, vncAgentSubcommand,
"--socket", socketPath,
// Drop privs inside the agent: launchctl asuser preserves the
// daemon's uid (root), so without this the capture/input/
// encoder paths would run as root for the lifetime of the
// session. validateAgentPeer on the daemon side also relies on
// the agent's effective uid matching consoleUID.
"--target-uid", strconv.FormatUint(uint64(uid), 10),
)
cmd.Env = append(os.Environ(), agentTokenEnvVar+"="+token)
stderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("agent stderr pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("launchctl asuser: %w", err)
}
go func() {
defer stderr.Close()
relogAgentStream(stderr)
}()
go func() { _ = cmd.Wait() }()
return nil
}
// waitForAgent dials the agent's Unix socket until it answers. Used to
// gate proxy attempts until the spawned process has finished its Start.
func waitForAgent(ctx context.Context, socketPath string, wait time.Duration) error {
var d net.Dialer
deadline := time.Now().Add(wait)
for time.Now().Before(deadline) {
if ctx.Err() != nil {
return ctx.Err()
}
dialCtx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
c, err := d.DialContext(dialCtx, "unix", socketPath)
cancel()
if err == nil {
_ = c.Close()
return nil
}
time.Sleep(100 * time.Millisecond)
}
return fmt.Errorf("timeout dialing %s", socketPath)
}
// vncAgentRunning reports whether any vnc-agent process exists on the
// system. There is at most one agent per machine, so any match is "the"
// agent.
func vncAgentRunning() bool {
pids, err := vncAgentPIDs()
if err != nil {
log.Debugf("scan for vnc-agent: %v", err)
return false
}
return len(pids) > 0
}
// killAllVNCAgents sends SIGTERM to every process whose argv contains
// "vnc-agent", waits briefly for them to exit, and escalates to SIGKILL
// for any that remain. We enumerate kern.proc.all rather than
// kern.proc.uid because launchctl asuser preserves the caller's uid
// (root) on the spawned child, so a uid-scoped filter would never match.
func killAllVNCAgents() {
pids, err := vncAgentPIDs()
if err != nil {
log.Debugf("scan for vnc-agent: %v", err)
return
}
for _, pid := range pids {
_ = syscall.Kill(pid, syscall.SIGTERM)
}
if len(pids) == 0 {
return
}
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
remaining, _ := vncAgentPIDs()
if len(remaining) == 0 {
return
}
time.Sleep(100 * time.Millisecond)
}
leftover, _ := vncAgentPIDs()
for _, pid := range leftover {
_ = syscall.Kill(pid, syscall.SIGKILL)
}
}
// vncAgentPIDs returns the pids of vnc-agent subprocesses spawned from
// this binary. Matches exactly on argv[0] == our own executable path
// AND argv[1] == "vnc-agent" so unrelated processes that happen to have
// the same name elsewhere in argv are not targeted. Skips pid 0 and 1
// defensively.
func vncAgentPIDs() ([]int, error) {
procs, err := unix.SysctlKinfoProcSlice("kern.proc.all")
if err != nil {
return nil, fmt.Errorf("sysctl kern.proc.all: %w", err)
}
ownExe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("resolve own executable: %w", err)
}
var out []int
for i := range procs {
pid := int(procs[i].Proc.P_pid)
if pid <= 1 {
continue
}
argv, err := procArgv(pid)
if err != nil || !argvIsVNCAgent(argv, ownExe) {
continue
}
out = append(out, pid)
}
return out, nil
}
// procArgv reads the kernel's stored argv for pid via the kern.procargs2
// sysctl. Format: 4-byte argc, then argv[0..argc) each NUL-terminated,
// then envp, then padding. We only need argv so we stop after argc.
func procArgv(pid int) ([]string, error) {
raw, err := unix.SysctlRaw("kern.procargs2", pid)
if err != nil {
return nil, err
}
if len(raw) < 4 {
return nil, fmt.Errorf("procargs2 truncated")
}
argc := int(raw[0]) | int(raw[1])<<8 | int(raw[2])<<16 | int(raw[3])<<24
body := raw[4:]
// Skip the executable path (NUL-terminated) and any zero padding that
// follows before argv[0].
end := bytes.IndexByte(body, 0)
if end < 0 {
return nil, fmt.Errorf("procargs2 path unterminated")
}
body = body[end+1:]
for len(body) > 0 && body[0] == 0 {
body = body[1:]
}
args := make([]string, 0, argc)
for i := 0; i < argc; i++ {
end := bytes.IndexByte(body, 0)
if end < 0 {
break
}
args = append(args, string(body[:end]))
body = body[end+1:]
}
return args, nil
}
// argvIsVNCAgent reports whether argv belongs to a vnc-agent subprocess
// spawned from our binary. Requires argv[0] to match ownExe exactly and
// argv[1] to be the vnc-agent subcommand. Matches the spawn shape in
// spawnAgentForUser and rejects anything else.
func argvIsVNCAgent(argv []string, ownExe string) bool {
if len(argv) < 2 || ownExe == "" {
return false
}
return argv[0] == ownExe && argv[1] == vncAgentSubcommand
}

View File

@@ -1,305 +0,0 @@
//go:build darwin || windows
package server
import (
"bufio"
"bytes"
"context"
crand "crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"time"
log "github.com/sirupsen/logrus"
)
// errNoConsoleUser is the sentinel returned by sessionAgent.Resolve when
// the platform has no interactive user to attach a capture agent to (the
// macOS loginwindow state). Mapped to a distinct RFB reject code so the
// browser can show a meaningful message.
var errNoConsoleUser = errors.New("no user logged into console")
// sessionAgent abstracts the per-platform manager that spawns and tracks
// the user-session VNC agent. Resolve returns the agent's Unix-socket
// path, the shared per-spawn token, and the uid the agent was spawned
// under (used to validate peer credentials before the daemon hands the
// token to whoever is on the other end of the socket). Resolve may spawn
// the agent lazily.
type sessionAgent interface {
Resolve(ctx context.Context) (socketPath, token string, peerUID uint32, err error)
}
// prefixConn replays already-consumed header bytes ahead of the proxy
// stream by swapping in a different Reader on the same underlying Conn.
type prefixConn struct {
io.Reader
net.Conn
}
func (p *prefixConn) Read(b []byte) (int, error) { return p.Reader.Read(b) }
// handleServiceConnection runs the connection-header handshake (source
// check, Noise_IK auth) on conn, resolves the right per-session agent
// via sa, and proxies to it. Every accepted connection emits exactly one
// outcome line on the daemon log.
func (s *Server) handleServiceConnection(conn net.Conn, sa sessionAgent) {
start := time.Now()
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
if !s.isAllowedSource(conn.RemoteAddr()) {
connLog.Info("VNC connection rejected: source not allowed")
_ = conn.Close()
return
}
var headerBuf bytes.Buffer
tee := io.TeeReader(conn, &headerBuf)
teeConn := &prefixConn{Reader: tee, Conn: conn}
header, err := s.readConnectionHeader(teeConn)
if err != nil {
connLog.Infof("VNC connection rejected: header read failed: %v", err)
_ = conn.Close()
return
}
authedLog, sessionUserID, ok := s.authorizeSession(conn, header, connLog)
if !ok {
authedLog.Info("VNC connection rejected: auth failed")
return
}
if err := s.registerConnAuth(conn, header); err != nil {
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
authedLog.Warnf("VNC connection rejected: %v", err)
return
}
decision, err := s.gateApproval(conn, header)
if err != nil {
authedLog.Infof("VNC connection rejected: %v", err)
return
}
if decision.ViewOnly {
authedLog.Info("VNC connection approved by user (view-only)")
} else if s.requireApproval {
authedLog.Info("VNC connection approved by user")
}
socketPath, token, peerUID, err := sa.Resolve(s.ctx)
if err != nil {
code := RejectCodeCapturerError
if errors.Is(err, errNoConsoleUser) {
code = RejectCodeNoConsoleUser
}
rejectConnection(conn, codeMessage(code, err.Error()))
authedLog.Warnf("VNC connection rejected: agent unavailable: %v", err)
return
}
var initiator string
if s.authorizer != nil {
initiator = s.authorizer.LookupSessionDisplayName(header.clientStatic)
}
sessionID := s.addSession(ActiveSessionInfo{
RemoteAddress: conn.RemoteAddr().String(),
Mode: modeString(header.mode),
Username: header.username,
UserID: sessionUserID,
Initiator: initiator,
}, conn)
defer s.removeSession(sessionID)
replayConn := &prefixConn{
Reader: io.MultiReader(&headerBuf, conn),
Conn: conn,
}
if err := proxyToAgent(s.ctx, replayConn, socketPath, token, peerUID, decision.ViewOnly, authedLog); err != nil {
rejectConnection(conn, codeMessage(RejectCodeCapturerError, err.Error()))
authedLog.Warnf("VNC connection rejected: agent unreachable: %v", err)
return
}
authedLog.Infof("VNC connection closed (%dms)", time.Since(start).Milliseconds())
}
const (
// agentTokenLen is the size of the random per-spawn token in bytes.
agentTokenLen = 32
// agentTokenEnvVar names the environment variable the daemon uses to
// hand the per-spawn token to the agent child. Out-of-band channels
// like this keep the secret out of the command line, where listings
// such as `ps` or Windows tasklist would expose it.
agentTokenEnvVar = "NB_VNC_AGENT_TOKEN" // #nosec G101 -- env var name, not a credential
// vncAgentSubcommand is the CLI subcommand the daemon invokes to start
// the per-session agent process. Must match cmd.vncAgentCmd.Use in
// client/cmd/vnc_agent.go.
vncAgentSubcommand = "vnc-agent"
)
// generateAuthToken returns a fresh hex-encoded random token for one
// daemon→agent session. The daemon hands this to the spawned agent
// out-of-band (env var on Windows) and verifies it on every connection
// the agent accepts.
func generateAuthToken() (string, error) {
b := make([]byte, agentTokenLen)
if _, err := crand.Read(b); err != nil {
return "", fmt.Errorf("read random: %w", err)
}
return hex.EncodeToString(b), nil
}
// proxyToAgent dials the per-session agent's Unix socket, validates the
// peer's kernel-asserted uid (so the daemon never hands its per-spawn
// token to an impostor that won the listen race), writes the raw token
// bytes plus a single view-only flag byte, then copies bytes both ways
// until either side closes. The token + flag prefix must precede any RFB
// byte so the agent's verifyAgentToken can run first. Returns nil once a
// stream is established; the caller is responsible for sending an
// RFB-level rejection on error so the client sees a reason instead of a
// bare timeout. authedLog receives one audit line per dispatched
// preamble so an operator can correlate daemon→agent traffic with the
// remote session that triggered it.
func proxyToAgent(ctx context.Context, client net.Conn, socketPath, authToken string, peerUID uint32, viewOnly bool, authedLog *log.Entry) error {
tokenBytes, err := hex.DecodeString(authToken)
if err != nil || len(tokenBytes) != agentTokenLen {
return fmt.Errorf("invalid auth token (len=%d): %w", len(tokenBytes), err)
}
agentConn, err := dialAgentWithRetry(ctx, socketPath)
if err != nil {
return fmt.Errorf("dial agent at %s: %w", socketPath, err)
}
if err := validateAgentPeer(agentConn, peerUID); err != nil {
_ = agentConn.Close()
return fmt.Errorf("agent peer validation failed: %w", err)
}
preamble := make([]byte, len(tokenBytes)+1)
copy(preamble, tokenBytes)
if viewOnly {
preamble[len(tokenBytes)] = 1
}
if _, err := agentConn.Write(preamble); err != nil {
_ = agentConn.Close()
return fmt.Errorf("send auth preamble to agent: %w", err)
}
// Audit: one line per successfully-dispatched daemon→agent preamble.
// Token printed as its first 8 hex chars (enough to correlate, not
// enough to use). Kept at Info so the default deployment captures it.
tokenFp := authToken
if len(tokenFp) > 8 {
tokenFp = tokenFp[:8]
}
if authedLog != nil {
authedLog.Infof("VNC IPC: dispatched preamble to agent socket=%s peer_uid=%d view_only=%v token_fp=%s", socketPath, peerUID, viewOnly, tokenFp)
}
defer client.Close()
defer agentConn.Close()
log.Debugf("proxy connected to agent, starting bidirectional copy")
done := make(chan struct{}, 2)
cp := func(label string, dst, src net.Conn) {
n, err := io.Copy(dst, src)
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
done <- struct{}{}
}
go cp("client->agent", agentConn, client)
go cp("agent->client", client, agentConn)
<-done
return nil
}
// relogAgentStream reads log lines from the agent's stderr and re-emits
// them through the daemon's logrus, so the merged log keeps a single
// format. JSON lines (the agent's normal output) are parsed and dispatched
// by level; plain-text lines (cobra errors, panic traces) are forwarded
// verbatim so early-startup failures stay visible.
func relogAgentStream(r io.Reader) {
entry := log.WithField("component", "vnc-agent")
scanner := bufio.NewScanner(r)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Bytes()
if len(line) == 0 {
continue
}
if line[0] != '{' {
entry.Warn(string(line))
continue
}
var m map[string]any
if err := json.Unmarshal(line, &m); err != nil {
entry.Warn(string(line))
continue
}
msg, _ := m["msg"].(string)
if msg == "" {
continue
}
fields := make(log.Fields)
for k, v := range m {
switch k {
case "msg", "level", "time", "func":
continue
case "caller":
fields["source"] = v
default:
fields[k] = v
}
}
e := entry.WithFields(fields)
switch m["level"] {
case "error":
e.Error(msg)
case "warning":
e.Warn(msg)
case "debug":
e.Debug(msg)
case "trace":
e.Trace(msg)
default:
e.Info(msg)
}
}
}
// dialAgentWithRetry retries the loopback connect for up to ~10 s so the
// daemon does not race the agent's first listen. Returns the live conn or
// the final error. Aborts early when ctx is cancelled so a Stop() during
// service-mode startup doesn't leave a goroutine sleeping for 10 s.
func dialAgentWithRetry(ctx context.Context, addr string) (net.Conn, error) {
var d net.Dialer
var lastErr error
for range 50 {
if err := ctx.Err(); err != nil {
if lastErr == nil {
lastErr = err
}
return nil, lastErr
}
dialCtx, cancel := context.WithTimeout(ctx, time.Second)
c, err := d.DialContext(dialCtx, "unix", addr)
cancel()
if err == nil {
return c, nil
}
lastErr = err
select {
case <-ctx.Done():
if errors.Is(lastErr, context.Canceled) || errors.Is(lastErr, context.DeadlineExceeded) {
lastErr = ctx.Err()
}
return nil, lastErr
case <-time.After(200 * time.Millisecond):
}
}
return nil, lastErr
}

View File

@@ -1,46 +0,0 @@
//go:build darwin && !ios
package server
import (
"fmt"
"net"
"golang.org/x/sys/unix"
)
// validateAgentPeer enforces that the peer behind the just-connected Unix
// socket is the agent we expect it to be: a process running under
// expectedUID, with the right effective uid stamped by the kernel on the
// socket. Refuses (with a non-nil error) if anything else is listening on
// the path (an unrelated local process that won the listen race or
// squatted the path before us). Defends against the daemon shipping its
// per-spawn auth token to a process that isn't the spawned agent.
func validateAgentPeer(conn net.Conn, expectedUID uint32) error {
uconn, ok := conn.(*net.UnixConn)
if !ok {
return fmt.Errorf("peer cred: expected *net.UnixConn, got %T", conn)
}
raw, err := uconn.SyscallConn()
if err != nil {
return fmt.Errorf("peer cred: syscall conn: %w", err)
}
var cred *unix.Xucred
var inner error
ctlErr := raw.Control(func(fd uintptr) {
cred, inner = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
})
if ctlErr != nil {
return fmt.Errorf("peer cred: control: %w", ctlErr)
}
if inner != nil {
return fmt.Errorf("peer cred: getsockopt LOCAL_PEERCRED: %w", inner)
}
if cred == nil {
return fmt.Errorf("peer cred: nil xucred")
}
if cred.Uid != expectedUID {
return fmt.Errorf("peer cred: agent uid %d does not match expected %d", cred.Uid, expectedUID)
}
return nil
}

View File

@@ -1,115 +0,0 @@
//go:build darwin && !ios
package server
import (
"net"
"os"
"path/filepath"
"strings"
"sync"
"testing"
)
// TestValidateAgentPeerAcceptsOwnUID confirms the happy path: a Unix
// socket whose peer is the current process must validate when the
// expected uid matches the process's own. Both sides of a unix-socket
// pair share the same kernel cred, so this exercises the real getsockopt
// LOCAL_PEERCRED path.
func TestValidateAgentPeerAcceptsOwnUID(t *testing.T) {
dir := t.TempDir()
sockPath := filepath.Join(dir, "test.sock")
ln, err := net.Listen("unix", sockPath)
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := ln.Accept()
if err == nil {
_ = c.Close()
}
}()
c, err := net.Dial("unix", sockPath)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer c.Close()
if err := validateAgentPeer(c, uint32(os.Getuid())); err != nil {
t.Fatalf("validateAgentPeer rejected own uid: %v", err)
}
wg.Wait()
}
// TestValidateAgentPeerRejectsWrongUID ensures the validator fails when
// the expected uid differs from the kernel-reported peer uid. This is
// the path that catches a hostile process that won the listen race.
func TestValidateAgentPeerRejectsWrongUID(t *testing.T) {
dir := t.TempDir()
sockPath := filepath.Join(dir, "test.sock")
ln, err := net.Listen("unix", sockPath)
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := ln.Accept()
if err == nil {
_ = c.Close()
}
}()
c, err := net.Dial("unix", sockPath)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer c.Close()
// Pick a uid the test process certainly isn't running as.
wrongUID := uint32(os.Getuid()) + 1
err = validateAgentPeer(c, wrongUID)
if err == nil {
t.Fatal("expected mismatch error, got nil")
}
if !strings.Contains(err.Error(), "does not match expected") {
t.Fatalf("error should mention uid mismatch, got: %v", err)
}
wg.Wait()
}
// TestValidateAgentPeerRejectsNonUnix protects against being handed a
// non-Unix-socket connection (the validator can't enforce anything on
// e.g. a *net.TCPConn so it must refuse rather than silently pass).
func TestValidateAgentPeerRejectsNonUnix(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp: %v", err)
}
defer ln.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := ln.Accept()
if err == nil {
_ = c.Close()
}
}()
c, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial tcp: %v", err)
}
defer c.Close()
if err := validateAgentPeer(c, 0); err == nil {
t.Fatal("expected refusal on non-unix conn, got nil")
}
wg.Wait()
}

Some files were not shown because too many files have changed in this diff Show More