Compare commits

104 Commits

Author SHA1 Message Date
Viktor Liu
bf2fb2fd44 Address CodeRabbit review on embedded VNC PR 2026-05-24 18:52:57 +02:00
Viktor Liu
4e3e3ce6d3 Surface VNC initiator in status, clarify proxy logs, dampen capture noise 2026-05-24 17:07:59 +02:00
Viktor Liu
5e2830be8a Harden VNC server, IPC, and management plumbing 2026-05-24 16:02:36 +02:00
Viktor Liu
f557e665a5 Return error from gateApproval and log at the caller 2026-05-23 19:50:27 +02:00
Viktor Liu
fa57eedaf5 Address CodeRabbit review and fix CI on embedded-vnc 2026-05-23 19:44:21 +02:00
Viktor Liu
7cb6388349 Decline VNC approval early when no console user is logged in 2026-05-23 19:15:01 +02:00
Viktor Liu
1f912be673 Address codespell and Sonar findings on embedded-vnc 2026-05-23 19:06:02 +02:00
Viktor Liu
8d329da591 Evict orphaned packet captures and annotate VNC streams 2026-05-23 18:33:55 +02:00
Viktor Liu
8e72967bbe Add per-connection user-approval prompts for VNC 2026-05-23 18:33:55 +02:00
Viktor Liu
c29ef638f4 Switch VNC daemon-to-agent IPC to Unix sockets and audit-log every connection 2026-05-22 15:32:35 +02:00
Viktor Liu
97b7b010f5 Fold init-only VNC and SSH setters into Config-struct constructors 2026-05-22 13:32:25 +02:00
Viktor Liu
030c57150f Signal Zlib encode failure and fall back to Raw 2026-05-22 12:06:52 +02:00
Viktor Liu
0f03c612d1 Lower CreateTemporaryAccess complexity and emit VncAuth for session pubkeys 2026-05-22 12:01:18 +02:00
Viktor Liu
1cc5967198 Address follow-up CodeRabbit VNC findings 2026-05-22 11:35:16 +02:00
Viktor Liu
412193c602 Address CodeRabbit VNC review feedback 2026-05-21 18:09:07 +02:00
Viktor Liu
5e67febf57 Address Sonar findings and move noise to direct dependency 2026-05-21 17:55:27 +02:00
Viktor Liu
ee348ba007 Abort VNC agent dial retry loop on server shutdown 2026-05-21 17:44:22 +02:00
Viktor Liu
3d3055dc7f Replace VNC JWT auth with a Noise_IK handshake bound to ACL-pushed pubkeys 2026-05-21 17:36:15 +02:00
Viktor Liu
2f4ddf0796 Emit explicit Fn flagsChanged transitions around macOS navigation keys 2026-05-21 12:30:14 +02:00
Viktor Liu
98d533c8e8 Address CodeRabbit feedback on VNC server agent matching and session lifecycle 2026-05-21 12:01:45 +02:00
Viktor Liu
ef4ea2e311 Set Fn flag on macOS navigation keycodes so the next key isn't treated as Fn-modified 2026-05-20 18:03:38 +02:00
Viktor Liu
b41d11bbbe Allow Cursor pseudo-encoding in session mode and cache last XFixes sprite 2026-05-20 17:39:07 +02:00
Viktor Liu
f37e228cc2 Replace magic env-var and subcommand strings with named constants 2026-05-20 17:22:02 +02:00
Viktor Liu
640a267556 Address CodeRabbit feedback on VNC server 2026-05-20 17:16:55 +02:00
Viktor Liu
17359cdc1e Fix VNC lint, 386 atomic alignment, and Sonar code smells 2026-05-20 16:34:29 +02:00
Viktor Liu
7e5846a1ee Resolve merge conflicts with main 2026-05-20 15:38:01 +02:00
Viktor Liu
517bea0daf Collapse X11 DISPLAY/XAUTHORITY auto-detect logs into one line 2026-05-20 15:36:26 +02:00
Viktor Liu
9192b4f029 [client] Bump macOS sleep callback timeout to 20s (#6220) 2026-05-20 13:09:22 +02:00
Maycon Santos
c784b02550 [misc] Update contribution guidelines (#6219)
Update contribution guidelines and PR template to require discussing impactful changes with the team
2026-05-20 12:21:03 +02:00
Viktor Liu
896530fd82 Add ExtendedMouseButtons for back/forward mouse buttons 2026-05-20 12:15:00 +02:00
Viktor Liu
354fd004c7 Enable IdP JWKS refresh in VNC JWT validator 2026-05-20 12:15:00 +02:00
Viktor Liu
c28e41e82b Track macOS click count and pixel-scale wheel scroll 2026-05-20 12:14:53 +02:00
Viktor Liu
02b9fe704b Use pixel-mode scroll on macOS for smoother wheel events 2026-05-20 12:14:45 +02:00
Viktor Liu
5e200fa571 Drop unreliable Sequoia preflight from macOS Screen Recording check 2026-05-20 12:14:37 +02:00
Viktor Liu
7d61975f6c Proxy macOS VNC connections from the LaunchDaemon to a per-user agent via launchctl asuser 2026-05-20 12:12:20 +02:00
Viktor Liu
62b36112ea Extract daemon-to-agent loopback proxy and token helpers into a platform-neutral file 2026-05-20 12:11:15 +02:00
Viktor Liu
df9a6fb020 Drop pbpaste trace log that fires whenever the macOS pasteboard is empty 2026-05-20 12:11:15 +02:00
Viktor Liu
b1b04f9ec6 Composite remote cursor into the framebuffer when the dashboard toggles it on 2026-05-20 12:11:15 +02:00
Viktor Liu
fe15688f20 Emit Cursor pseudo-encoding on Linux, Windows, and macOS 2026-05-20 12:11:15 +02:00
Viktor Liu
2285db2b62 Treat ExtendedClipboard messages with the Caps bit as Caps regardless of co-set action bits 2026-05-20 12:11:15 +02:00
Viktor Liu
b3f0f53a23 Collapse dirty rects to their bounding box when the bbox is densely dirty 2026-05-20 12:11:15 +02:00
Viktor Liu
5eec9962ba Honour client JPEG quality fully now that backpressure caps it dynamically 2026-05-20 12:11:15 +02:00
Viktor Liu
393c102f45 Throttle VNC encoder JPEG quality and skip frames under write backpressure 2026-05-20 12:11:15 +02:00
Viktor Liu
b41fbad5e1 Surface DXGI fallback to GDI at warn level on Windows 2026-05-20 12:11:15 +02:00
Viktor Liu
24a5f2252c Accept any RGB shift permutation as Tight-compatible per RFB 7.7.6 2026-05-20 12:11:15 +02:00
Viktor Liu
9d189bb3e8 Restore Hextile SolidFill and Zlib encoding paths 2026-05-20 12:11:15 +02:00
Maycon Santos
8e2505b59c [management] Add metrics for peer status updates and ephemeral cleanup (#6196)
* [management] Add metrics for peer status updates and ephemeral cleanup

The session-fenced MarkPeerConnected / MarkPeerDisconnected path and
the ephemeral peer cleanup loop both run silently today: when fencing
rejects a stale stream, when a cleanup tick deletes peers, or when a
batch delete fails, we have no operational signal beyond log lines.

Add OpenTelemetry counters and a histogram so the same SLO-style
dashboards that already exist for the network-map controller can cover
peer connect/disconnect and ephemeral cleanup too.

All new attributes are bounded enums: operation in {connect,disconnect}
and outcome in {applied,stale,error,peer_not_found}. No account, peer,
or user ID is ever written as a metric label — total cardinality is
fixed at compile time (8 counter series, 2 histogram series, 4 unlabeled
ephemeral series).

Metric methods are nil-receiver safe so test composition that doesn't
wire telemetry (the bulk of the existing tests) works unchanged. The
ephemeral manager exposes a SetMetrics setter rather than taking the
collector through its constructor, keeping the constructor signature
stable across all test call sites.

* [management] Add OpenTelemetry metrics for ephemeral peer cleanup

Introduce counters for tracking ephemeral peer cleanup, including peers pending deletion, cleanup runs, successful deletions, and failed batches. Metrics are nil-receiver safe to ensure compatibility with test setups without telemetry.
2026-05-20 12:11:15 +02:00
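The nil-receiver-safe metric methods and the `SetMetrics` setter described in commit 8e2505b59c can be sketched roughly as follows. This is a minimal illustration, not the code from the commit: `cleanupMetrics`, `ephemeralManager`, `CountDeleted`, and `Deleted` are hypothetical names, and a plain atomic counter stands in for the OpenTelemetry instruments.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// cleanupMetrics is a hypothetical stand-in for the telemetry collector.
// A real implementation would hold OpenTelemetry counters instead.
type cleanupMetrics struct {
	deleted atomic.Int64
}

// CountDeleted records n deleted peers. On a nil receiver it does nothing,
// so callers never need to check whether telemetry was wired.
func (m *cleanupMetrics) CountDeleted(n int64) {
	if m == nil {
		return
	}
	m.deleted.Add(n)
}

// Deleted returns the running total, or 0 on a nil receiver.
func (m *cleanupMetrics) Deleted() int64 {
	if m == nil {
		return 0
	}
	return m.deleted.Load()
}

// ephemeralManager is a hypothetical manager whose constructor does not
// take a collector; metrics are attached afterwards through SetMetrics.
type ephemeralManager struct {
	metrics *cleanupMetrics
}

// SetMetrics attaches a collector after construction.
func (e *ephemeralManager) SetMetrics(m *cleanupMetrics) { e.metrics = m }

// cleanup pretends every listed peer was deleted and records the count.
func (e *ephemeralManager) cleanup(peers []string) {
	e.metrics.CountDeleted(int64(len(peers)))
}

func main() {
	// No telemetry wired: the nil collector is safe to call.
	bare := &ephemeralManager{}
	bare.cleanup([]string{"a", "b"})

	// Telemetry wired through the setter.
	m := &cleanupMetrics{}
	wired := &ephemeralManager{}
	wired.SetMetrics(m)
	wired.cleanup([]string{"a", "b", "c"})

	fmt.Println(bare.metrics.Deleted(), m.Deleted()) // prints "0 3"
}
```

The design choice here is that tests which never call `SetMetrics` keep working without a mock collector, which matches the motivation given in the commit message.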
Maycon Santos
97bc1eebde [management] Fence peer status updates with a session token (#6193)
* [management] Fence peer status updates with a session token

The connect/disconnect path used a best-effort LastSeen-after-streamStart
comparison to decide whether a status update should land. Under contention
— a re-sync arriving while the previous stream's disconnect was still in
flight, or two management replicas seeing the same peer at once — the
check was a read-then-decide-then-write window: any UPDATE in between
caused the wrong row to be written. The Go-side time.Now() that fed the
comparison also drifted under lock contention, since it was captured
seconds before the write actually committed.

Replace it with an integer-nanosecond fencing token stored alongside the
status. Every gRPC sync stream uses its open time (UnixNano) as its token.
Connects only land when the incoming token is strictly greater than the
stored one; disconnects only land when the incoming token equals the
stored one (i.e. we're the stream that owns the current session). Both
are single optimistic-locked UPDATEs — no read-then-write, no transaction
wrapper.

LastSeen is now written by the database itself (CURRENT_TIMESTAMP). The
caller never supplies it, so the value always reflects the real moment
of the UPDATE rather than the moment the caller queued the work — which
was already off by minutes under heavy lock contention.

Side effects (geo lookup, peer-login-expiration scheduling, network-map
fan-out) are explicitly documented as running after the fence UPDATE
commits, never inside it. Geo also skips the update when realIP equals
the stored ConnectionIP, dropping a redundant SavePeerLocation call on
same-IP reconnects.

Tests cover the three semantic cases (matched disconnect lands, stale
disconnect dropped, stale connect dropped) plus a 16-goroutine race test
that asserts the highest token always wins.

* [management] Add SessionStartedAt to peer status updates

Stored `SessionStartedAt` for fencing token propagation across goroutines and updated database queries/functions to handle the new field. Removed outdated geolocation handling logic and adjusted tests for concurrency safety.

* Rename `peer_status_required_approval` to `peer_status_requires_approval` in SQL store fields
2026-05-20 12:11:15 +02:00
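The fencing rule from commit 97bc1eebde (connects land only with a strictly greater token, disconnects only with an equal token) can be illustrated with a small in-memory model. This is a sketch under stated assumptions: `peerStatus`, `markConnected`, and `markDisconnected` are hypothetical names, and a mutex stands in for the atomicity that the commit gets from a single conditional UPDATE in the database.

```go
package main

import (
	"fmt"
	"sync"
)

// peerStatus is a hypothetical in-memory stand-in for the stored status row.
type peerStatus struct {
	mu               sync.Mutex
	connected        bool
	sessionStartedAt int64 // fencing token: stream open time in UnixNano
}

// markConnected applies only when the incoming token is strictly greater
// than the stored one, so an older stream can never overwrite a newer one.
func (p *peerStatus) markConnected(token int64) bool {
	p.mu.Lock()
	defer p.mu.Unlock()
	if token <= p.sessionStartedAt {
		return false // stale connect, dropped
	}
	p.connected = true
	p.sessionStartedAt = token
	return true
}

// markDisconnected applies only when the incoming token equals the stored
// one, i.e. the caller is the stream that owns the current session.
func (p *peerStatus) markDisconnected(token int64) bool {
	p.mu.Lock()
	defer p.mu.Unlock()
	if token != p.sessionStartedAt {
		return false // stale disconnect, dropped
	}
	p.connected = false
	return true
}

func main() {
	p := &peerStatus{}
	fmt.Println(p.markConnected(100))    // true: first stream
	fmt.Println(p.markConnected(200))    // true: re-sync with a newer token
	fmt.Println(p.markDisconnected(100)) // false: old stream no longer owns the session
	fmt.Println(p.markConnected(150))    // false: token older than the stored one
	fmt.Println(p.markDisconnected(200)) // true: owning stream disconnects
}
```

The sequence in `main` mirrors the three cases the commit message says are tested: matched disconnect lands, stale disconnect dropped, stale connect dropped.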
Nicolas Frati
32a5a061b8 [management] fix: device redirect uri wasn't registered (#6191)
* fix: device redirect uri wasn't registered

* fix lint
2026-05-20 12:11:15 +02:00
Viktor Liu
d927ef468a Clean up legacy 32-bit and HKCU registry entries on Windows install (#6176) 2026-05-20 12:11:15 +02:00
Maycon Santos
d3f3e08035 Avoid context cancellation in cancelPeerRoutines (#6175)
When closing goroutines and handling peer disconnect, we should avoid canceling the flow due to parent gRPC context cancellation.

This change triggers disconnection handling with a context that is not bound to the parent gRPC cancellation.
2026-05-20 12:11:15 +02:00
Viktor Liu
6bb66e0fad [management] Avoid peer IP reallocation when account settings update preserves the network range (#6173) 2026-05-20 12:11:15 +02:00
Maycon Santos
d250f92c43 feat(reverse-proxy): clusters API surfaces type, online status, and capability flags (#6148)
The cluster listing now answers three questions in one round-trip
instead of forcing the dashboard to cross-reference the domains API:
which clusters can this account see, are they currently up, and what
do they support. The ProxyCluster wire type drops the boolean
self_hosted in favour of a `type` enum (`account` / `shared`) plus
explicit `online`, `supports_custom_ports`, `require_subdomain`, and
`supports_crowdsec` fields.

Store query reworked so offline clusters still appear (no last_seen
WHERE), with online and connected_proxies both derived from the
existing 2-min active window via portable CASE expressions; the
1-hour heartbeat reaper still removes long-stale rows. Service
manager enriches each cluster with the capability flags via the
existing per-cluster lookups (CapabilityProvider now also exposes
ClusterSupportsCrowdSec).

GetActiveClusterAddresses* keep their tight 2-min filter so service
routing and domain enumeration aren't pulled into the wider window.

The hard cut removes self_hosted from the response — the dashboard is
the only consumer and is updated in the matching PR; no transitional
field is shipped.

Adds a cross-engine regression test asserting offline clusters
surface, connected_proxies counts only fresh proxies, and
account-scoped BYOP clusters never leak across accounts.
2026-05-20 10:08:34 +02:00
Maycon Santos
80966ab1b0 [management] Ensure SessionStartedAt has a default value (#6211)
* [management] Ensure SessionStartedAt has a default value

Avoid null values for the new column

* [management] Add PeerStatus with LastSeen in peer_test

* [management] Add migration for PeerStatusSessionStartedAt default value

* [management] Add PeerStatus with LastSeen in migration tests
2026-05-20 08:25:30 +02:00
Maycon Santos
af24fd7796 [management] Add metrics for peer status updates and ephemeral cleanup (#6196)
* [management] Add metrics for peer status updates and ephemeral cleanup

The session-fenced MarkPeerConnected / MarkPeerDisconnected path and
the ephemeral peer cleanup loop both run silently today: when fencing
rejects a stale stream, when a cleanup tick deletes peers, or when a
batch delete fails, we have no operational signal beyond log lines.

Add OpenTelemetry counters and a histogram so the same SLO-style
dashboards that already exist for the network-map controller can cover
peer connect/disconnect and ephemeral cleanup too.

All new attributes are bounded enums: operation in {connect,disconnect}
and outcome in {applied,stale,error,peer_not_found}. No account, peer,
or user ID is ever written as a metric label — total cardinality is
fixed at compile time (8 counter series, 2 histogram series, 4 unlabeled
ephemeral series).

Metric methods are nil-receiver safe so test composition that doesn't
wire telemetry (the bulk of the existing tests) works unchanged. The
ephemeral manager exposes a SetMetrics setter rather than taking the
collector through its constructor, keeping the constructor signature
stable across all test call sites.

* [management] Add OpenTelemetry metrics for ephemeral peer cleanup

Introduce counters for tracking ephemeral peer cleanup, including peers pending deletion, cleanup runs, successful deletions, and failed batches. Metrics are nil-receiver safe to ensure compatibility with test setups without telemetry.
2026-05-18 22:55:19 +02:00
Maycon Santos
13d32d274f [management] Fence peer status updates with a session token (#6193)
* [management] Fence peer status updates with a session token

The connect/disconnect path used a best-effort LastSeen-after-streamStart
comparison to decide whether a status update should land. Under contention
— a re-sync arriving while the previous stream's disconnect was still in
flight, or two management replicas seeing the same peer at once — the
check was a read-then-decide-then-write window: any UPDATE in between
caused the wrong row to be written. The Go-side time.Now() that fed the
comparison also drifted under lock contention, since it was captured
seconds before the write actually committed.

Replace it with an integer-nanosecond fencing token stored alongside the
status. Every gRPC sync stream uses its open time (UnixNano) as its token.
Connects only land when the incoming token is strictly greater than the
stored one; disconnects only land when the incoming token equals the
stored one (i.e. we're the stream that owns the current session). Both
are single optimistic-locked UPDATEs — no read-then-write, no transaction
wrapper.

LastSeen is now written by the database itself (CURRENT_TIMESTAMP). The
caller never supplies it, so the value always reflects the real moment
of the UPDATE rather than the moment the caller queued the work — which
was already off by minutes under heavy lock contention.

Side effects (geo lookup, peer-login-expiration scheduling, network-map
fan-out) are explicitly documented as running after the fence UPDATE
commits, never inside it. Geo also skips the update when realIP equals
the stored ConnectionIP, dropping a redundant SavePeerLocation call on
same-IP reconnects.

Tests cover the three semantic cases (matched disconnect lands, stale
disconnect dropped, stale connect dropped) plus a 16-goroutine race test
that asserts the highest token always wins.

* [management] Add SessionStartedAt to peer status updates

Stored `SessionStartedAt` for fencing token propagation across goroutines and updated database queries/functions to handle the new field. Removed outdated geolocation handling logic and adjusted tests for concurrency safety.

* Rename `peer_status_required_approval` to `peer_status_requires_approval` in SQL store fields
2026-05-18 20:25:12 +02:00
Viktor Liu
bc407527f4 Register VNC netstack service only when netstack is active 2026-05-18 14:50:10 +02:00
Viktor Liu
5543404188 Cap honored VNC client JPEG quality at 50 2026-05-18 14:50:10 +02:00
Viktor Liu
c2fdf62f1f Detect dead VNC peers on both ends and report session stats 2026-05-18 14:50:10 +02:00
Viktor Liu
b9f5264e36 Restore createRDPProxy wasm entry point for dashboard RDP 2026-05-18 14:50:10 +02:00
Nicolas Frati
705f87fc20 [management] fix: device redirect uri wasn't registered (#6191)
* fix: device redirect uri wasn't registered

* fix lint
2026-05-18 12:57:59 +02:00
Viktor Liu
97d0a6776f Release sticky modifiers and mouse buttons on client disconnect 2026-05-18 08:55:27 +02:00
Viktor Liu
7e7e056f3a Reset Tight zlib stream when deflater is recreated mid-session
Also scrub brand-name references from comments.
2026-05-18 07:54:21 +02:00
Viktor Liu
785f94d13f Guard buildExtClipProvideText against oversized input 2026-05-18 07:42:24 +02:00
Viktor Liu
bfb6750b13 Reset encoding capability flags on each SetEncodings 2026-05-18 07:41:42 +02:00
Viktor Liu
f5e1057127 Latin-1 round-trip for legacy CutText and soft-fail ext clipboard errors 2026-05-18 07:41:12 +02:00
Viktor Liu
ee393d0e62 Clamp Tight length to 22 bits and fall back to Raw on overflow 2026-05-17 21:27:13 +02:00
Viktor Liu
0b8fc5da59 Split session.go: encoder pipeline and clipboard handling into separate files 2026-05-17 17:32:01 +02:00
Viktor Liu
2d0a54f31a Fix golangci-lint and Sonar: drop newZlibState, extract applyEncoding, inline stub comment 2026-05-17 17:16:10 +02:00
Viktor Liu
61ec8d67de Honor QualityLevel and CompressLevel pseudo-encodings 2026-05-17 16:52:57 +02:00
Viktor Liu
76add0b9b2 Fix ExtendedClipboard auto-request by advertising all actions in Caps 2026-05-17 16:47:53 +02:00
Viktor Liu
a11341f57a Add ExtendedClipboard pseudo-encoding for UTF-8 bidirectional clipboard 2026-05-17 16:34:14 +02:00
Viktor Liu
b135d462d6 Drop unused zlibState.scratch field 2026-05-17 16:33:48 +02:00
Viktor Liu
da37a28951 Exclude VNC server from js, ios, and android builds 2026-05-17 15:48:15 +02:00
Viktor Liu
4f884d9f30 Add QEMU Extended Key Event for layout-independent input 2026-05-17 15:48:15 +02:00
Viktor Liu
2bed8b641b Lock pixel format to 32bpp little-endian truecolour and reject other formats 2026-05-17 15:48:15 +02:00
Viktor Liu
b4f696272a Drop unused VNC DES auth path 2026-05-17 15:48:15 +02:00
Viktor Liu
6d937af7a0 Drop dead Hextile and standalone Zlib encoding paths 2026-05-17 15:48:15 +02:00
Viktor Liu
db5b6cfbb7 Add DesktopSize, DesktopName, LastRect pseudo-encodings with resize detection 2026-05-17 15:48:15 +02:00
Viktor Liu
e75948753a Prompt for macOS Accessibility and Screen Recording at VNC enable time 2026-05-17 15:48:15 +02:00
Viktor Liu
047cc958b5 Throttle capture-failure log to once per 5s while capturer is down 2026-05-17 08:23:34 +02:00
Viktor Liu
cd005ef9a9 Add CopyRect detection and emission for tile-aligned moves 2026-05-17 08:13:52 +02:00
Viktor Liu
44ed0c1992 Drop xclip-no-selection trace log that fires every 2s on Xvfb 2026-05-17 08:13:46 +02:00
Viktor Liu
d6d3fa95c7 Drop unused getPeerFromResource helper 2026-05-17 06:48:46 +02:00
Viktor Liu
fa90283781 Extract wildcard user merge helper to satisfy case-clause length 2026-05-17 06:37:42 +02:00
Viktor Liu
8bf13b0d0c Merge SSH wildcard authorized users across matching rules 2026-05-17 06:33:27 +02:00
Viktor Liu
a8541a1529 Apply posture and validated-peers filtering on ResourceTypePeer policy resolution 2026-05-17 06:33:23 +02:00
Viktor Liu
94068d3ebc Drop -ac from Xvfb/Xorg invocations to keep xhost localuser grant authoritative 2026-05-17 06:32:50 +02:00
Viktor Liu
738c585ee7 Guard VNC session negotiated encoding state with RWMutex 2026-05-17 06:32:31 +02:00
Viktor Liu
9b5541d17d Extract session-address anonymization helper to lower status complexity 2026-05-16 22:11:28 +02:00
Viktor Liu
7123e6d1f4 Fix Windows lint errcheck/unused and Linux nilerr in console VNC fallback 2026-05-16 17:23:36 +02:00
Viktor Liu
62cf9e873b Track active VNC sessions in status and address CodeRabbit findings 2026-05-16 17:06:19 +02:00
Viktor Liu
3f91f49277 Clean up legacy 32-bit and HKCU registry entries on Windows install (#6176) 2026-05-16 16:52:57 +02:00
Viktor Liu
9f0aa1ce26 Add embedded VNC server with JWT auth and per-peer toggle 2026-05-16 16:49:14 +02:00
Maycon Santos
347c5bf317 Avoid context cancellation in cancelPeerRoutines (#6175)
When closing goroutines and handling peer disconnect, we should avoid canceling the flow due to parent gRPC context cancellation.

This change triggers disconnection handling with a context that is not bound to the parent gRPC cancellation.
2026-05-16 16:29:01 +02:00
Viktor Liu
22e2519d71 [management] Avoid peer IP reallocation when account settings update preserves the network range (#6173) 2026-05-16 15:51:48 +02:00
Vlad
e916f12cca [proxy] auth token generation on mapping (#6157)
* [management / proxy] auth token generation on mapping

* fix tests
2026-05-15 19:13:44 +02:00
Viktor Liu
9ed2e2a5b4 [client] Drop DNS probes for passive health projection (#5971) 2026-05-15 17:07:38 +02:00
Viktor Liu
2ccae7ec47 [client] Mirror v4 exit selection onto v6 pair and honour SkipAutoApply per route (#6150) 2026-05-15 16:58:47 +02:00
Viktor Liu
07e5450117 [management] Bracket IPv6 reverse-proxy target hosts when building URL Host field (#6141) 2026-05-14 16:42:40 +02:00
Viktor Liu
3f914090cb [client] Bracket IPv6 in embed listeners, expand debug bundle (#6134) 2026-05-14 16:22:53 +02:00
Viktor Liu
ea9fab4396 [management] Allocate and preserve IPv6 overlay addresses for embedded proxy peers (#6132) 2026-05-14 16:05:33 +02:00
Vlad
77b479286e [management] fix offline statuses for public proxy clusters (#6133) 2026-05-14 13:27:50 +02:00
Maycon Santos
ab2a8794e7 [client] Add short flags for status command options (#6137)
* [client] Add short flags for status command options

* uppercase filters
2026-05-14 12:30:42 +02:00
205 changed files with 26325 additions and 2986 deletions

@@ -12,6 +12,7 @@
- [ ] Is a feature enhancement
- [ ] It is a refactor
- [ ] Created tests that fail without the change (if possible)
- [ ] This change does **not** modify the public API, gRPC protocols, functionality behavior, CLI / service flags, or introduce a new feature — **OR** I have discussed it with the NetBird team beforehand (link the issue / Slack thread in the description). See [CONTRIBUTING.md](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTING.md#discuss-changes-with-the-netbird-team-first).
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).

@@ -61,8 +61,8 @@ jobs:
 echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
-if [ ${SIZE} -gt 58720256 ]; then
-echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
+if [ ${SIZE} -gt 62914560 ]; then
+echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
 exit 1
 fi

@@ -15,6 +15,7 @@ If you haven't already, join our slack workspace [here](https://docs.netbird.io/
- [Contributing to NetBird](#contributing-to-netbird)
- [Contents](#contents)
- [Code of conduct](#code-of-conduct)
- [Discuss changes with the NetBird team first](#discuss-changes-with-the-netbird-team-first)
- [Directory structure](#directory-structure)
- [Development setup](#development-setup)
- [Requirements](#requirements)
@@ -33,6 +34,14 @@ Conduct which can be found in the file [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
By participating, you are expected to uphold this code. Please report
unacceptable behavior to community@netbird.io.
## Discuss changes with the NetBird team first
Changes to the **public API**, **gRPC protocols**, **functionality behavior**, **CLI / service flags**, or **new features** should be discussed with the NetBird team before you start the work. These surfaces are part of NetBird's contract with operators, self-hosters, and downstream integrators, and changes to them have compatibility, security, and release-planning implications that benefit from an early conversation.
Open an issue or reach out on [Slack](https://docs.netbird.io/slack-url) to talk through what you have in mind. We'll help shape the change, flag any constraints we know about, and confirm the direction so the PR review can focus on implementation rather than design.
Typical bug fixes, internal refactors, documentation updates, and tests do not need pre-discussion — open the PR directly.
## Directory structure
The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages.

@@ -43,16 +43,16 @@ func init() {
 ipsFilterMap = make(map[string]struct{})
 prefixNamesFilterMap = make(map[string]struct{})
 statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
-statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
-statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
-statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
-statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
+statusCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display detailed status information in json format")
+statusCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display detailed status information in yaml format")
+statusCmd.PersistentFlags().BoolVarP(&ipv4Flag, "ipv4", "4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
+statusCmd.PersistentFlags().BoolVarP(&ipv6Flag, "ipv6", "6", false, "display only NetBird IPv6 of this peer")
 statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
-statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
-statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
-statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
-statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
-statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
+statusCmd.PersistentFlags().StringSliceVarP(&ipsFilter, "filter-by-ips", "I", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
+statusCmd.PersistentFlags().StringSliceVarP(&prefixNamesFilter, "filter-by-names", "N", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
+statusCmd.PersistentFlags().StringVarP(&statusFilter, "filter-by-status", "S", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
+statusCmd.PersistentFlags().StringVarP(&connectionTypeFilter, "filter-by-connection-type", "T", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
+statusCmd.PersistentFlags().StringVarP(&checkFlag, "check", "C", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
 }
 func statusFunc(cmd *cobra.Command, args []string) error {

@@ -361,6 +361,12 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
req.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(disableVNCApprovalFlag).Changed {
req.DisableVNCApproval = &disableVNCApproval
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
@@ -467,30 +473,14 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
if cmd.Flag(serverVNCAllowedFlag).Changed {
ic.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(disableVNCApprovalFlag).Changed {
ic.DisableVNCApproval = &disableVNCApproval
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
ic.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
applySSHFlagsToConfig(cmd, &ic)
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
@@ -566,6 +556,49 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
return &ic, nil
}
func applySSHFlagsToConfig(cmd *cobra.Command, ic *profilemanager.ConfigInput) {
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
ic.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
}
func applySSHFlagsToLogin(cmd *cobra.Command, req *proto.LoginRequest) {
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
req.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ttl := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &ttl
}
}
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
@@ -595,31 +628,14 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(serverSSHAllowedFlag).Changed {
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
- if cmd.Flag(enableSSHRootFlag).Changed {
- loginRequest.EnableSSHRoot = &enableSSHRoot
+ if cmd.Flag(serverVNCAllowedFlag).Changed {
+ loginRequest.ServerVNCAllowed = &serverVNCAllowed
}
+ if cmd.Flag(disableVNCApprovalFlag).Changed {
+ loginRequest.DisableVNCApproval = &disableVNCApproval
+ }
- if cmd.Flag(enableSSHSFTPFlag).Changed {
- loginRequest.EnableSSHSFTP = &enableSSHSFTP
- }
- if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
- loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
- }
- if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
- loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
- }
- if cmd.Flag(disableSSHAuthFlag).Changed {
- loginRequest.DisableSSHAuth = &disableSSHAuth
- }
- if cmd.Flag(sshJWTCacheTTLFlag).Changed {
- sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
- loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
- }
+ applySSHFlagsToLogin(cmd, &loginRequest)
if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled

100
client/cmd/vnc_agent.go Normal file

@@ -0,0 +1,100 @@
//go:build windows || (darwin && !ios)
package cmd
import (
"fmt"
"net"
"net/netip"
"os"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
var (
vncAgentSocket string
vncAgentTargetUID uint32
)
func init() {
vncAgentCmd.Flags().StringVar(&vncAgentSocket, "socket", "", "Unix-domain socket path the agent listens on (required)")
vncAgentCmd.Flags().Uint32Var(&vncAgentTargetUID, "target-uid", 0, "uid the agent should drop privileges to before listening (darwin only; 0 = stay as current uid)")
rootCmd.AddCommand(vncAgentCmd)
}
// vncAgentCmd runs a VNC server inside the user's interactive session,
// listening on a Unix-domain socket. The NetBird service spawns it: on
// Windows via CreateProcessAsUser into the console session, on macOS via
// launchctl asuser into the Aqua session.
var vncAgentCmd = &cobra.Command{
Use: "vnc-agent",
Short: "Run VNC capture agent (internal, spawned by service)",
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
log.SetReportCaller(true)
log.SetFormatter(&log.JSONFormatter{})
log.SetOutput(os.Stderr)
if vncAgentSocket == "" {
return fmt.Errorf("--socket is required")
}
token := os.Getenv("NB_VNC_AGENT_TOKEN")
if token == "" {
return fmt.Errorf("NB_VNC_AGENT_TOKEN not set; agent requires a token from the service")
}
// Remove the token from this process's environment so that processes
// started by the agent do not inherit it.
if err := os.Unsetenv("NB_VNC_AGENT_TOKEN"); err != nil {
log.Debugf("unset NB_VNC_AGENT_TOKEN: %v", err)
}
// Drop root privileges to the target console user BEFORE creating
// the listening socket: keeps a post-auth bug in the encoder /
// input / capture paths confined to the user's own privileges
// rather than escalating to host root, and makes the daemon's
// LOCAL_PEERCRED check see the right uid. No-op on Windows
// (both processes run as SYSTEM) and when --target-uid is 0.
if vncAgentTargetUID != 0 {
if err := dropAgentPrivileges(vncAgentTargetUID); err != nil {
return fmt.Errorf("drop privileges to uid %d: %w", vncAgentTargetUID, err)
}
}
if err := os.Remove(vncAgentSocket); err != nil && !os.IsNotExist(err) {
log.Debugf("remove stale socket %s: %v", vncAgentSocket, err)
}
ln, err := net.Listen("unix", vncAgentSocket)
if err != nil {
return fmt.Errorf("listen on %s: %w", vncAgentSocket, err)
}
if err := os.Chmod(vncAgentSocket, 0o600); err != nil {
log.Debugf("chmod %s: %v", vncAgentSocket, err)
}
capturer, injector, err := newAgentResources()
if err != nil {
_ = ln.Close()
return err
}
srv := vncserver.New(vncserver.Config{
Capturer: capturer,
Injector: injector,
DisableAuth: true,
AgentTokenHex: token,
Listener: ln,
})
if err := srv.Start(cmd.Context(), netip.AddrPort{}, netip.Prefix{}); err != nil {
return fmt.Errorf("start vnc server: %w", err)
}
log.Infof("vnc-agent listening on %s, ready", vncAgentSocket)
<-cmd.Context().Done()
log.Info("vnc-agent context cancelled, shutting down")
return srv.Stop()
},
SilenceUsage: true,
}


@@ -0,0 +1,18 @@
//go:build darwin && !ios
package cmd
import (
"fmt"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
capturer := vncserver.NewMacPoller()
injector, err := vncserver.NewMacInputInjector()
if err != nil {
return nil, nil, fmt.Errorf("macOS input injector: %w", err)
}
return capturer, injector, nil
}


@@ -0,0 +1,50 @@
//go:build darwin && !ios
package cmd
import (
"fmt"
"os"
"syscall"
)
// dropAgentPrivileges drops the vnc-agent process from root (its
// launchctl-asuser-inherited starting uid) to the target console user
// before any other initialisation runs. Without this the agent runs as
// root for the lifetime of the session; any post-auth memory-safety
// issue in the capture/input/encode paths would then be a root-level
// RCE on the host instead of a user-level one. Also makes the daemon's
// LOCAL_PEERCRED check correctly identify the agent as the console user,
// not as root.
//
// Returns an error when the agent is running as a non-root uid that
// differs from targetUID: non-root can only setuid to itself, so a
// mismatch here means the spawn went to the wrong session.
func dropAgentPrivileges(targetUID uint32) error {
if targetUID == 0 {
return fmt.Errorf("refusing to keep agent running as root (target uid 0)")
}
cur := uint32(os.Getuid())
if cur == targetUID {
return nil
}
if cur != 0 {
return fmt.Errorf("agent uid %d does not match expected %d and we lack root to fix it", cur, targetUID)
}
// Drop supplementary groups first: setgid alone doesn't touch the
// auxiliary group list, and leaving root's groups attached would let
// the dropped process write to files that are writable by those groups.
if err := syscall.Setgroups([]int{}); err != nil {
return fmt.Errorf("setgroups([]): %w", err)
}
if err := syscall.Setgid(int(targetUID)); err != nil {
return fmt.Errorf("setgid(%d): %w", targetUID, err)
}
if err := syscall.Setuid(int(targetUID)); err != nil {
return fmt.Errorf("setuid(%d): %w", targetUID, err)
}
if uint32(os.Getuid()) != targetUID || uint32(os.Geteuid()) != targetUID {
return fmt.Errorf("setuid verification: uid=%d euid=%d, expected %d", os.Getuid(), os.Geteuid(), targetUID)
}
return nil
}


@@ -0,0 +1,55 @@
//go:build darwin && !ios
package cmd
import (
"strings"
"testing"
)
// TestDropAgentPrivileges_RefusesRootTarget locks in the contract that
// dropAgentPrivileges must never be a no-op when asked to keep the
// agent as root (target uid 0). A future caller that passes 0 by
// mistake would otherwise leave the post-auth attack surface running
// with full root privileges.
func TestDropAgentPrivileges_RefusesRootTarget(t *testing.T) {
err := dropAgentPrivileges(0)
if err == nil {
t.Fatal("expected refusal for target uid 0, got nil")
}
if !strings.Contains(err.Error(), "root") {
t.Fatalf("error should mention root, got: %v", err)
}
}
// TestDropAgentPrivileges_NoOpWhenAlreadyTarget covers the dev path
// where the agent is launched by hand as the target user (no root
// available, no setuid needed). The helper must succeed silently
// instead of trying (and failing) a setuid to its current uid.
func TestDropAgentPrivileges_NoOpWhenAlreadyTarget(t *testing.T) {
// Skip when running as root: the early-return path we want to
// cover only fires when current uid == target uid.
uid := currentUIDForTest()
if uid == 0 {
t.Skip("test must not run as root; cannot exercise the no-op early-return")
}
if err := dropAgentPrivileges(uid); err != nil {
t.Fatalf("expected no-op when current uid == target, got: %v", err)
}
}
// TestDropAgentPrivileges_RefusesMismatchedNonRoot guards the "non-root
// caller tries to setuid to a different uid" path: setuid would fail
// with EPERM anyway, but the helper should surface a clear error
// before issuing the syscall so a misconfigured spawn (wrong --target-uid
// flag) is debuggable.
func TestDropAgentPrivileges_RefusesMismatchedNonRoot(t *testing.T) {
uid := currentUIDForTest()
if uid == 0 {
t.Skip("test must not run as root; covered case requires non-root caller")
}
err := dropAgentPrivileges(uid + 1)
if err == nil {
t.Fatal("expected refusal when non-root caller asks to setuid elsewhere")
}
}


@@ -0,0 +1,11 @@
//go:build darwin && !ios
package cmd
import "os"
// currentUIDForTest exposes os.Getuid for the darwin dropprivs tests
// without leaking an os import into the test file itself.
func currentUIDForTest() uint32 {
return uint32(os.Getuid())
}


@@ -0,0 +1,14 @@
//go:build windows
package cmd
// dropAgentPrivileges is a no-op on Windows: the agent and the daemon
// both run as SYSTEM (the daemon spawns the agent into the interactive
// session via CreateProcessAsUser with an impersonation token, but the
// resulting process still runs under SYSTEM, not under the user's
// account). The Windows path relies on the C:\Windows\Temp socket
// location (admin/SYSTEM-write-only) and the per-spawn token for
// integrity instead.
func dropAgentPrivileges(_ uint32) error {
return nil
}


@@ -0,0 +1,15 @@
//go:build windows
package cmd
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
sessionID := vncserver.GetCurrentSessionID()
log.Infof("VNC agent running in Windows session %d", sessionID)
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), nil
}

16
client/cmd/vnc_flags.go Normal file

@@ -0,0 +1,16 @@
package cmd
const (
serverVNCAllowedFlag = "allow-server-vnc"
disableVNCApprovalFlag = "disable-vnc-approval"
)
var (
serverVNCAllowed bool
disableVNCApproval bool
)
func init() {
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
upCmd.PersistentFlags().BoolVar(&disableVNCApproval, disableVNCApprovalFlag, false, "Disable per-connection user approval prompts for the embedded VNC server")
}
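
The login request only carries a flag when the user passed it explicitly (the `.Changed` checks), so an omitted flag stays nil rather than overwriting stored configuration with the default. A minimal sketch of the same tri-state idea follows. It uses the standard library `flag` package instead of cobra, and the `request` type and `buildRequest` function are illustrative assumptions.

```go
package main

import (
	"flag"
	"fmt"
)

// request mirrors the optional-field idea: nil means "not specified".
// The type and its fields are illustrative.
type request struct {
	ServerVNCAllowed   *bool
	DisableVNCApproval *bool
}

// buildRequest parses args and copies only the flags that were passed
// explicitly into the request.
func buildRequest(args []string) (*request, error) {
	fs := flag.NewFlagSet("up", flag.ContinueOnError)
	allow := fs.Bool("allow-server-vnc", false, "allow embedded VNC server")
	noApproval := fs.Bool("disable-vnc-approval", false, "disable approval prompts")
	if err := fs.Parse(args); err != nil {
		return nil, err
	}

	req := &request{}
	// Visit calls the function only for flags that were actually set.
	fs.Visit(func(f *flag.Flag) {
		switch f.Name {
		case "allow-server-vnc":
			req.ServerVNCAllowed = allow
		case "disable-vnc-approval":
			req.DisableVNCApproval = noApproval
		}
	})
	return req, nil
}

func main() {
	req, err := buildRequest([]string{"--allow-server-vnc=false"})
	if err != nil {
		panic(err)
	}
	// An explicit "false" is still forwarded; the omitted flag stays nil.
	fmt.Println(req.ServerVNCAllowed != nil, req.DisableVNCApproval != nil) // true false
}
```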


@@ -336,7 +336,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
- listenAddr := fmt.Sprintf("%s:%s", addr, port)
+ listenAddr := net.JoinHostPort(addr.String(), port)
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
@@ -357,7 +357,7 @@ func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
- listenAddr := fmt.Sprintf("%s:%s", addr, port)
+ listenAddr := net.JoinHostPort(addr.String(), port)
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {


@@ -260,15 +260,23 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
; or HKCU by legacy installers.
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
@@ -299,11 +307,16 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
; Remove autostart entries from every view a previous installer may have used.
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."


@@ -0,0 +1,219 @@
// Package approval brokers per-attempt user-accept prompts for inbound
// remote access (VNC today, SSH and others in the future). A caller pushes
// a Prompt; the broker emits a SystemEvent on the daemon→UI stream and
// blocks until the UI calls the daemon's RespondApproval RPC, the
// per-request timeout fires, or the caller's context is cancelled. When no
// subscriber is connected the broker returns immediately and fails closed,
// so a backgrounded UI cannot silently bypass the gate.
package approval
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
)
// Metadata keys the broker reserves on the emitted SystemEvent. Callers
// should not set these themselves; values in Prompt.Metadata that collide
// are overwritten by the broker.
const (
MetaRequestID = "request_id"
MetaKind = "kind"
MetaExpiresAt = "expires_at"
)
// ShortKeyFingerprint formats a hex-encoded Noise_IK static pubkey as a
// short, eyeball-able fingerprint to display in the approval dialog.
// The dashboard-supplied display name attached to a SessionPubKey isn't
// cryptographically asserted by the connecting client, so the prompt
// must also show something that IS: the key fingerprint, taken from the
// leading hex characters of the static public key the client just proved
// possession of during the Noise handshake. Returns the empty string when
// the input is shorter than 8 characters, too short to plausibly be a hex
// pubkey, so the row is omitted rather than rendered as a misleading
// partial. The input is not validated as hex.
//
// Output format: the first 16 hex chars grouped as XXXX-XXXX-XXXX-XXXX
// (64 bits of fingerprint, resistant to random-prefix collisions and easy
// for a human to compare with an out-of-band reference). Inputs of 8 to
// 15 characters yield a correspondingly shorter string.
func ShortKeyFingerprint(hexKey string) string {
if len(hexKey) < 8 {
return ""
}
src := hexKey
if len(src) > 16 {
src = src[:16]
}
var out []byte
for i, c := range src {
if i > 0 && i%4 == 0 {
out = append(out, '-')
}
out = append(out, byte(c))
}
return string(out)
}
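
ShortKeyFingerprint formats the leading hex characters of the key itself. For comparison, a digest-based variant would hash the decoded key first and format the start of the digest. The sketch below is a hypothetical alternative, not what this change implements; the name `digestFingerprint` and the choice of SHA-256 are assumptions.

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strings"
)

// digestFingerprint is a hypothetical variant: it decodes the hex key,
// hashes the raw bytes with SHA-256, and formats the first 8 digest bytes
// as XXXX-XXXX-XXXX-XXXX.
func digestFingerprint(hexKey string) (string, error) {
	raw, err := hex.DecodeString(hexKey)
	if err != nil {
		return "", fmt.Errorf("decode key: %w", err)
	}
	if len(raw) == 0 {
		return "", fmt.Errorf("empty key")
	}
	sum := sha256.Sum256(raw)
	digits := hex.EncodeToString(sum[:8]) // 16 hex characters

	var sb strings.Builder
	for i := 0; i < len(digits); i += 4 {
		if i > 0 {
			sb.WriteByte('-')
		}
		sb.WriteString(digits[i : i+4])
	}
	return sb.String(), nil
}

func main() {
	fp, err := digestFingerprint("0123456789abcdeffedcba9876543210ffeeddccbbaa99887766554433221100")
	if err != nil {
		panic(err)
	}
	fmt.Println(fp)
}
```

With a digest, every byte of the key influences the displayed value, whereas a prefix only reflects the first 8 bytes. The trade-off is that the displayed value can no longer be compared by eye against the raw key.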
// Kind values for the well-known prompt subjects. New subsystems should
// add a constant here so the UI can dispatch on a known string.
const (
KindVNC = "vnc"
KindSSH = "ssh"
)
// DefaultTimeout is the wall-clock window the user has to accept or deny a
// pending approval before the broker fails closed and returns ErrTimeout.
// Kept well under typical VNC client and dashboard connection timeouts so
// the RFB rejection actually reaches the browser instead of racing the
// browser's own "connection timed out" message.
const DefaultTimeout = 15 * time.Second
// timeoutValue returns the active timeout. It's a var so tests in this
// package can shorten the wait without exposing a setter on the public
// API. Production code always sees DefaultTimeout.
var timeoutValue = func() time.Duration { return DefaultTimeout }
// ErrNoSubscriber indicates no UI is connected to consume the prompt.
// The caller must reject the underlying connection (fail-closed).
var ErrNoSubscriber = errors.New("no UI subscriber connected for approval")
// ErrTimeout indicates the user did not respond within DefaultTimeout.
var ErrTimeout = errors.New("approval timed out")
// ErrDenied indicates the user explicitly denied the connection.
var ErrDenied = errors.New("approval denied")
// EventPublisher is the subset of peer.Status used to emit prompts.
type EventPublisher interface {
PublishEvent(
severity proto.SystemEvent_Severity,
category proto.SystemEvent_Category,
msg string,
userMsg string,
metadata map[string]string,
)
HasEventSubscribers() bool
}
// Prompt describes the pending request shown to the user. Kind selects
// the UI dispatch path (e.g. "vnc", "ssh"). Subject is the human-readable
// one-liner the UI may show as a title or notification body. Metadata is
// passed through verbatim and is the subsystem-specific payload (peer
// name, source IP, mode, etc.).
type Prompt struct {
Kind string
Subject string
Metadata map[string]string
}
// Decision carries the user's response to an approval prompt. ViewOnly is
// only meaningful when Accept is true; it lets the host grant the
// connection but signal the requester that input control is withheld.
type Decision struct {
Accept bool
ViewOnly bool
}
// Broker holds in-flight approval requests keyed by request ID.
type Broker struct {
pub EventPublisher
mu sync.Mutex
pending map[string]chan Decision
}
// New returns a broker that publishes prompts via pub.
func New(pub EventPublisher) *Broker {
return &Broker{
pub: pub,
pending: make(map[string]chan Decision),
}
}
// Request emits a SystemEvent for p and blocks until the UI calls Respond,
// ctx is cancelled, or DefaultTimeout elapses. Returns a Decision when
// the user replied; ErrDenied / ErrTimeout / ErrNoSubscriber / ctx.Err
// otherwise. Callers must treat any non-nil error as a deny.
func (b *Broker) Request(ctx context.Context, p Prompt) (Decision, error) {
var zero Decision
if b == nil || b.pub == nil {
return zero, fmt.Errorf("approval broker not configured")
}
if !b.pub.HasEventSubscribers() {
return zero, ErrNoSubscriber
}
id := uuid.NewString()
resp := make(chan Decision, 1)
b.mu.Lock()
b.pending[id] = resp
b.mu.Unlock()
defer b.dropPending(id)
timeout := timeoutValue()
expiresAt := time.Now().Add(timeout)
meta := make(map[string]string, len(p.Metadata)+3)
for k, v := range p.Metadata {
meta[k] = v
}
meta[MetaRequestID] = id
meta[MetaKind] = p.Kind
meta[MetaExpiresAt] = expiresAt.UTC().Format(time.RFC3339)
subject := p.Subject
if subject == "" {
subject = fmt.Sprintf("%s connection requires approval", p.Kind)
}
b.pub.PublishEvent(proto.SystemEvent_INFO, proto.SystemEvent_APPROVAL, subject, subject, meta)
log.Debugf("approval request %s (%s) emitted: %s", id, p.Kind, subject)
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case d := <-resp:
if !d.Accept {
return zero, ErrDenied
}
return d, nil
case <-timer.C:
return zero, ErrTimeout
case <-ctx.Done():
return zero, ctx.Err()
}
}
// Respond delivers the user's decision for id. Returns true when a pending
// request matched and was woken, false when id was unknown or already done.
func (b *Broker) Respond(id string, d Decision) bool {
if b == nil {
return false
}
b.mu.Lock()
ch, ok := b.pending[id]
if ok {
delete(b.pending, id)
}
b.mu.Unlock()
if !ok {
return false
}
select {
case ch <- d:
default:
}
return true
}
func (b *Broker) dropPending(id string) {
b.mu.Lock()
delete(b.pending, id)
b.mu.Unlock()
}


@@ -0,0 +1,434 @@
package approval
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/proto"
)
// fakePublisher records published events and reports whether subscribers
// are connected. The subscribers flag is the security-critical signal:
// when false the broker must refuse to emit and the gate must fail closed.
type fakePublisher struct {
mu sync.Mutex
subscribers bool
events []*proto.SystemEvent
}
func (p *fakePublisher) PublishEvent(
severity proto.SystemEvent_Severity,
category proto.SystemEvent_Category,
msg string,
userMsg string,
metadata map[string]string,
) {
p.mu.Lock()
p.events = append(p.events, &proto.SystemEvent{
Severity: severity,
Category: category,
Message: msg,
UserMessage: userMsg,
Metadata: metadata,
})
p.mu.Unlock()
}
func (p *fakePublisher) HasEventSubscribers() bool {
p.mu.Lock()
defer p.mu.Unlock()
return p.subscribers
}
func (p *fakePublisher) lastEvent(t *testing.T) *proto.SystemEvent {
t.Helper()
p.mu.Lock()
defer p.mu.Unlock()
require.NotEmpty(t, p.events, "publisher saw no events")
return p.events[len(p.events)-1]
}
func (p *fakePublisher) eventCount() int {
p.mu.Lock()
defer p.mu.Unlock()
return len(p.events)
}
// TestRequestNoSubscriberFailsClosed is the core fail-closed invariant:
// when the UI is not subscribed, the broker must refuse without emitting
// an event or arming a waiter. A regression here is a silent bypass.
func TestRequestNoSubscriberFailsClosed(t *testing.T) {
pub := &fakePublisher{subscribers: false}
b := New(pub)
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
assert.ErrorIs(t, err, ErrNoSubscriber)
assert.Equal(t, 0, pub.eventCount(), "no event must be emitted when fail-closed")
b.mu.Lock()
pending := len(b.pending)
b.mu.Unlock()
assert.Equal(t, 0, pending, "no waiter must be registered on fail-closed")
}
// TestRequestTimeoutDenies verifies that a request without a UI response
// returns ErrTimeout (deny) rather than nil (silent accept). It shortens
// the broker timeout through the defaultTimeout helper to keep the test fast.
func TestRequestTimeoutDenies(t *testing.T) {
// Shorten the active timeout for the lifetime of this test.
orig := DefaultTimeout
defaultTimeout(t, 60*time.Millisecond)
defer defaultTimeout(t, orig)
pub := &fakePublisher{subscribers: true}
b := New(pub)
start := time.Now()
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
assert.ErrorIs(t, err, ErrTimeout, "missing user response must yield ErrTimeout, not nil")
assert.GreaterOrEqual(t, time.Since(start), 50*time.Millisecond, "timeout fired prematurely")
}
// TestRequestDenied verifies that Request returns ErrDenied when the UI
// responds with Accept: false.
func TestRequestDenied(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
var requestID string
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
}()
requestID = waitForRequestID(t, pub)
require.True(t, b.Respond(requestID, Decision{Accept: false}))
select {
case err := <-done:
assert.ErrorIs(t, err, ErrDenied)
case <-time.After(time.Second):
t.Fatal("Request did not return after Respond(false)")
}
}
// TestRequestAccepted is the happy path. Failure here doesn't bypass the
// gate but breaks the feature.
func TestRequestAccepted(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
}()
id := waitForRequestID(t, pub)
require.True(t, b.Respond(id, Decision{Accept: true}))
select {
case err := <-done:
assert.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("Request did not return after Respond(true)")
}
}
// TestRequestCtxCancelDenies verifies that an upstream cancel (e.g. the
// engine shutting down mid-prompt) returns the cancel error rather than
// nil. A nil here would be a silent bypass on shutdown races.
func TestRequestCtxCancelDenies(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan error, 1)
go func() {
done <- requestErr(b, ctx, Prompt{Kind: KindVNC, Subject: "test"})
}()
// Wait until the prompt is in flight so cancel races a live waiter.
_ = waitForRequestID(t, pub)
cancel()
select {
case err := <-done:
assert.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
t.Fatal("Request did not return after ctx cancel")
}
}
// TestRespondUnknownIsNoop ensures a stray RespondApproval RPC cannot
// affect or accidentally accept any in-flight request whose id it doesn't
// match. Also confirms it doesn't panic.
func TestRespondUnknownIsNoop(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
// No in-flight prompts: Respond returns false.
assert.False(t, b.Respond("does-not-exist", Decision{Accept: true}))
// With an in-flight prompt, a wrong id still returns false and the
// prompt remains armed (eventually timing out as a deny).
defaultTimeout(t, 60*time.Millisecond)
defer defaultTimeout(t, DefaultTimeout)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
realID := waitForRequestID(t, pub)
assert.False(t, b.Respond("totally-bogus", Decision{Accept: true}), "unknown id must not match")
assert.NotEqual(t, "totally-bogus", realID)
select {
case err := <-done:
assert.ErrorIs(t, err, ErrTimeout, "armed prompt must still time out, not accept")
case <-time.After(time.Second):
t.Fatal("prompt did not resolve")
}
}
// TestRespondAfterTimeoutNoop confirms a late accept response can't
// retroactively flip a denied (timed-out) request. The dropPending defer
// in Request must have removed the entry by the time Respond races in.
func TestRespondAfterTimeoutNoop(t *testing.T) {
defaultTimeout(t, 30*time.Millisecond)
defer defaultTimeout(t, DefaultTimeout)
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
id := waitForRequestID(t, pub)
select {
case err := <-done:
require.ErrorIs(t, err, ErrTimeout)
case <-time.After(time.Second):
t.Fatal("prompt did not time out")
}
assert.False(t, b.Respond(id, Decision{Accept: true}), "late respond must be no-op")
}
// TestRespondDoubleNoop ensures a duplicate ack from the UI doesn't leak
// past the matched waiter or panic on a closed/full channel.
func TestRespondDoubleNoop(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
id := waitForRequestID(t, pub)
require.True(t, b.Respond(id, Decision{Accept: true}))
assert.False(t, b.Respond(id, Decision{Accept: false}), "second response must be no-op")
select {
case err := <-done:
assert.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("prompt did not resolve")
}
}
// TestNilBrokerRequestErrors guards the engine pre-init path where the
// broker may not yet exist (or its publisher is nil): Request must
// error, never silently accept.
func TestNilBrokerRequestErrors(t *testing.T) {
var b *Broker
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
assert.Error(t, err, "nil broker must error, never silently accept")
b2 := New(nil)
_, err = b2.Request(context.Background(), Prompt{Kind: KindVNC})
assert.Error(t, err, "broker with nil publisher must error, never silently accept")
}
// TestPromptMetadataInjected confirms the broker stamps request_id, kind,
// and expires_at on the emitted event. The UI relies on these keys; if
// they are dropped, the UI cannot route the prompt and the response
// path breaks (which fails closed via timeout).
func TestPromptMetadataInjected(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{
Kind: KindVNC,
Subject: "VNC connection from peerA",
Metadata: map[string]string{"peer_name": "peerA"},
})
}()
id := waitForRequestID(t, pub)
ev := pub.lastEvent(t)
assert.Equal(t, proto.SystemEvent_APPROVAL, ev.Category)
assert.Equal(t, KindVNC, ev.Metadata[MetaKind])
assert.Equal(t, id, ev.Metadata[MetaRequestID])
assert.NotEmpty(t, ev.Metadata[MetaExpiresAt])
assert.Equal(t, "peerA", ev.Metadata["peer_name"], "caller metadata must pass through")
require.True(t, b.Respond(id, Decision{Accept: true}))
<-done
}
// TestConcurrentRequests verifies that two concurrent prompts are tracked
// independently. A bug that aliases ids would let one Respond unblock
// the wrong waiter (a silent accept across prompts).
func TestConcurrentRequests(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
const n = 20
results := make(chan error, n)
for i := 0; i < n; i++ {
go func() {
results <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
}
ids := waitForNRequestIDs(t, pub, n)
require.Len(t, ids, n)
// Deny every other request and accept the rest. Results are not
// correlated back to ids; only the accept/deny totals are checked below.
for i, id := range ids {
deny := i%2 == 0
require.True(t, b.Respond(id, Decision{Accept: !deny}))
}
// Collect all returns and check no nil errors slipped past a deny.
var accepted, denied atomic.Int32
for i := 0; i < n; i++ {
select {
case err := <-results:
if err == nil {
accepted.Add(1)
} else {
assert.ErrorIs(t, err, ErrDenied)
denied.Add(1)
}
case <-time.After(2 * time.Second):
t.Fatalf("only got %d/%d responses", i, n)
}
}
assert.Equal(t, int32(n/2), denied.Load())
assert.Equal(t, int32(n/2), accepted.Load())
}
// waitForRequestID blocks until the publisher sees its next event and
// returns the request_id stamped on it.
func waitForRequestID(t *testing.T, pub *fakePublisher) string {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
pub.mu.Lock()
count := len(pub.events)
var id string
if count > 0 {
id = pub.events[count-1].Metadata[MetaRequestID]
}
pub.mu.Unlock()
if id != "" {
return id
}
time.Sleep(2 * time.Millisecond)
}
t.Fatal("timeout waiting for emitted event")
return ""
}
func waitForNRequestIDs(t *testing.T, pub *fakePublisher, n int) []string {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
pub.mu.Lock()
count := len(pub.events)
pub.mu.Unlock()
if count >= n {
break
}
time.Sleep(2 * time.Millisecond)
}
pub.mu.Lock()
defer pub.mu.Unlock()
out := make([]string, 0, len(pub.events))
seen := make(map[string]struct{}, len(pub.events))
for _, ev := range pub.events {
id := ev.Metadata[MetaRequestID]
if id == "" {
continue
}
if _, dup := seen[id]; dup {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
if len(out) < n {
t.Fatalf("only got %d/%d request ids", len(out), n)
}
return out
}
// defaultTimeout swaps the broker's per-request wall-clock window so the
// timeout tests run quickly. It does not restore the previous value on its
// own; callers defer a second call with the value to restore.
func defaultTimeout(t *testing.T, d time.Duration) {
t.Helper()
if d <= 0 {
t.Fatal("defaultTimeout must be > 0")
}
timeoutValue = func() time.Duration { return d }
}
// requestErr wraps Broker.Request to drop the Decision when tests only
// care about the error path. Keeps the goroutine bodies tight.
func requestErr(b *Broker, ctx context.Context, p Prompt) error {
_, err := b.Request(ctx, p)
return err
}
// TestRequestViewOnly checks the view-only outcome flows through Request's
// Decision return without being silently swallowed.
func TestRequestViewOnly(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
type result struct {
d Decision
err error
}
done := make(chan result, 1)
go func() {
d, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
done <- result{d, err}
}()
id := waitForRequestID(t, pub)
require.True(t, b.Respond(id, Decision{Accept: true, ViewOnly: true}))
select {
case r := <-done:
assert.NoError(t, r.err)
assert.True(t, r.d.Accept)
assert.True(t, r.d.ViewOnly, "ViewOnly must survive the round-trip")
case <-time.After(time.Second):
t.Fatal("view-only request did not resolve")
}
}


@@ -0,0 +1,62 @@
package approval
import "testing"
// TestShortKeyFingerprint locks in the format the VNC approval prompt
// shows to the user. The fingerprint is the user's only cryptographic
// anchor against a malicious management server that pushes a spoofed
// display name, so accidental changes to its format would silently
// undermine that defence.
func TestShortKeyFingerprint(t *testing.T) {
cases := []struct {
name string
in string
want string
}{
{
name: "full_32_byte_pubkey",
in: "0123456789abcdeffedcba9876543210ffeeddccbbaa99887766554433221100",
want: "0123-4567-89ab-cdef",
},
{
name: "exactly_16_chars",
in: "0123456789abcdef",
want: "0123-4567-89ab-cdef",
},
{
name: "borderline_8_chars",
in: "01234567",
want: "0123-4567",
},
{
name: "too_short_returns_empty",
in: "0123",
want: "",
},
{
name: "empty_returns_empty",
in: "",
want: "",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := ShortKeyFingerprint(tc.in)
if got != tc.want {
t.Fatalf("ShortKeyFingerprint(%q) = %q, want %q", tc.in, got, tc.want)
}
})
}
}
// TestShortKeyFingerprint_DistinctKeysDistinctOutputs guards against a
// formatting bug that would collapse different prefixes onto the same
// displayed fingerprint and let an attacker substitute their pubkey for
// a victim's while keeping the prompt visually identical.
func TestShortKeyFingerprint_DistinctKeysDistinctOutputs(t *testing.T) {
a := ShortKeyFingerprint("0123456789abcdef" + "rest_of_pubkey_ignored")
b := ShortKeyFingerprint("0123456789abcde0" + "rest_of_pubkey_ignored")
if a == b {
t.Fatalf("expected distinct outputs for distinct prefixes, both = %q", a)
}
}
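
The table-driven cases above pin down the displayed format but not the implementation. The following is a minimal sketch of a function that would satisfy them; `shortKeyFingerprintSketch` is a hypothetical name, and the handling of a trailing partial group is an assumption, since none of the cases exercises it.

```go
package main

import "strings"

// shortKeyFingerprintSketch is a hypothetical stand-in for ShortKeyFingerprint,
// reconstructed only from the test table: at least 8 characters are required,
// at most the first 16 are shown, and they are grouped in fours with dashes.
func shortKeyFingerprintSketch(pubHex string) string {
	const (
		minChars  = 8  // shorter inputs yield no fingerprint
		maxChars  = 16 // only the leading characters are displayed
		groupSize = 4
	)
	if len(pubHex) < minChars {
		return ""
	}
	if len(pubHex) > maxChars {
		pubHex = pubHex[:maxChars]
	}
	// Assumption: a trailing partial group is dropped. The tests do not cover it.
	pubHex = pubHex[:len(pubHex)-len(pubHex)%groupSize]

	groups := make([]string, 0, len(pubHex)/groupSize)
	for i := 0; i < len(pubHex); i += groupSize {
		groups = append(groups, pubHex[i:i+groupSize])
	}
	return strings.Join(groups, "-")
}
```

Because only the prefix is displayed, two keys that differ solely after the sixteenth character would render identically under this sketch; the second test above only guards against collisions within the displayed prefix.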

@@ -315,6 +315,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.ServerVNCAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,

@@ -116,7 +116,6 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
dnsAddresses []netip.AddrPort,
stateFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
@@ -126,7 +125,6 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
HostDNSAddresses: dnsAddresses,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, "")
@@ -564,6 +562,8 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
DisableVNCApproval: config.DisableVNCApproval,
EnableSSHRoot: config.EnableSSHRoot,
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
@@ -646,6 +646,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.ServerVNCAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,

@@ -45,8 +45,11 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client.
routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided.
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
iptables.txt: Anonymized iptables (IPv4) rules with packet counters, if --system-info flag was provided.
ip6tables.txt: Anonymized ip6tables (IPv6) rules with packet counters, if --system-info flag was provided.
ipset.txt: Anonymized ipset list output, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters across all families (ip, ip6, inet, etc.), if --system-info flag was provided.
sysctls.txt: Forwarding, reverse-path filter, source-validation, and conntrack accounting sysctl values that the NetBird client may read or modify, if --system-info flag was provided (Linux only).
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
@@ -165,22 +168,33 @@ The config.txt file contains anonymized configuration information of the NetBird
Other non-sensitive configuration options are included without anonymization.
Firewall Rules (Linux only)
The bundle includes two separate firewall rule files:
The bundle includes the following firewall-related files:
iptables.txt:
- Complete iptables ruleset with packet counters using 'iptables -v -n -L'
- IPv4 iptables ruleset with packet counters using 'iptables-save' and 'iptables -v -n -L'
- Includes all tables (filter, nat, mangle, raw, security)
- Shows packet and byte counters for each rule
- All IP addresses are anonymized
- Chain names, table names, and other non-sensitive information remain unchanged
ip6tables.txt:
- IPv6 ip6tables ruleset with packet counters using 'ip6tables-save' and 'ip6tables -v -n -L'
- Same table coverage and anonymization as iptables.txt
- Omitted when ip6tables is not installed or no IPv6 rules are present
ipset.txt:
- Output of 'ipset list' (family-agnostic)
- IP addresses are anonymized; set names and types remain unchanged
nftables.txt:
- Complete nftables ruleset obtained via 'nft -a list ruleset'
- Complete nftables ruleset across all families (ip, ip6, inet, arp, bridge, netdev) via 'nft -a list ruleset'
- Includes rule handle numbers and packet counters
- All tables, chains, and rules are included
- Shows packet and byte counters for each rule
- All IP addresses are anonymized
- Chain names, table names, and other non-sensitive information remain unchanged
- All IP addresses are anonymized; chain/table names remain unchanged
sysctls.txt:
- Forwarding (IPv4 + IPv6, global and per-interface), reverse-path filter, source-validation, conntrack accounting, and TCP-related sysctls that NetBird may read or modify
- Per-interface keys are enumerated from /proc/sys/net/ipv{4,6}/conf
- Interface names anonymized when --anonymize is set
IP Rules (Linux only)
The ip_rules.txt file contains detailed IP routing rule information:
@@ -412,6 +426,10 @@ func (g *BundleGenerator) addSystemInfo() {
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
}
if err := g.addSysctls(); err != nil {
log.Errorf("failed to add sysctls to debug bundle: %v", err)
}
if err := g.addDNSInfo(); err != nil {
log.Errorf("failed to add DNS info to debug bundle: %v", err)
}
@@ -618,6 +636,12 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.SSHJWTCacheTTL != nil {
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
}
if g.internalConfig.ServerVNCAllowed != nil {
configContent.WriteString(fmt.Sprintf("ServerVNCAllowed: %v\n", *g.internalConfig.ServerVNCAllowed))
}
if g.internalConfig.DisableVNCApproval != nil {
configContent.WriteString(fmt.Sprintf("DisableVNCApproval: %v\n", *g.internalConfig.DisableVNCApproval))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))

@@ -124,15 +124,18 @@ func getSystemdLogs(serviceName string) (string, error) {
// addFirewallRules collects and adds firewall rules to the archive
func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules")
iptablesRules, err := collectIPTablesRules()
g.addIPTablesRulesToBundle("iptables-save", "iptables", "iptables.txt")
g.addIPTablesRulesToBundle("ip6tables-save", "ip6tables", "ip6tables.txt")
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect iptables rules: %v", err)
log.Warnf("Failed to collect ipset information: %v", err)
} else {
if g.anonymize {
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
ipsetOutput = g.anonymizer.AnonymizeString(ipsetOutput)
}
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
log.Warnf("Failed to add iptables rules to bundle: %v", err)
if err := g.addFileToZip(strings.NewReader(ipsetOutput), "ipset.txt"); err != nil {
log.Warnf("Failed to add ipset output to bundle: %v", err)
}
}
@@ -151,44 +154,65 @@ func (g *BundleGenerator) addFirewallRules() error {
return nil
}
// collectIPTablesRules collects rules using both iptables-save and verbose listing
func collectIPTablesRules() (string, error) {
var builder strings.Builder
saveOutput, err := collectIPTablesSave()
// addIPTablesRulesToBundle collects iptables/ip6tables rules and writes them to the bundle.
func (g *BundleGenerator) addIPTablesRulesToBundle(saveBin, listBin, filename string) {
rules, err := collectIPTablesRules(saveBin, listBin)
if err != nil {
log.Warnf("Failed to collect iptables rules using iptables-save: %v", err)
} else {
builder.WriteString("=== iptables-save output ===\n")
log.Warnf("Failed to collect %s rules: %v", listBin, err)
return
}
if g.anonymize {
rules = g.anonymizer.AnonymizeString(rules)
}
if err := g.addFileToZip(strings.NewReader(rules), filename); err != nil {
log.Warnf("Failed to add %s rules to bundle: %v", listBin, err)
}
}
// collectIPTablesRules collects rules using both <saveBin> and verbose listing via <listBin>.
// Returns an error when neither command produced any output (e.g. the binary is missing),
// so the caller can skip writing an empty file.
func collectIPTablesRules(saveBin, listBin string) (string, error) {
var builder strings.Builder
var collected bool
var firstErr error
saveOutput, err := runCommand(saveBin)
switch {
case err != nil:
firstErr = err
log.Warnf("Failed to collect %s output: %v", saveBin, err)
case strings.TrimSpace(saveOutput) == "":
log.Debugf("%s produced no output, skipping", saveBin)
default:
builder.WriteString(fmt.Sprintf("=== %s output ===\n", saveBin))
builder.WriteString(saveOutput)
builder.WriteString("\n")
collected = true
}
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect ipset information: %v", err)
} else {
builder.WriteString("=== ipset list output ===\n")
builder.WriteString(ipsetOutput)
builder.WriteString("\n")
}
builder.WriteString("=== iptables -v -n -L output ===\n")
listHeader := fmt.Sprintf("=== %s -v -n -L output ===\n", listBin)
builder.WriteString(listHeader)
tables := []string{"filter", "nat", "mangle", "raw", "security"}
for _, table := range tables {
builder.WriteString(fmt.Sprintf("*%s\n", table))
stats, err := getTableStatistics(table)
stats, err := runCommand(listBin, "-v", "-n", "-L", "-t", table)
if err != nil {
log.Warnf("Failed to get statistics for table %s: %v", table, err)
if firstErr == nil {
firstErr = err
}
log.Warnf("Failed to get %s statistics for table %s: %v", listBin, table, err)
continue
}
builder.WriteString(fmt.Sprintf("*%s\n", table))
builder.WriteString(stats)
builder.WriteString("\n")
collected = true
}
if !collected {
return "", fmt.Errorf("collect %s rules: %w", listBin, firstErr)
}
return builder.String(), nil
}
@@ -214,34 +238,15 @@ func collectIPSets() (string, error) {
return ipsets, nil
}
// collectIPTablesSave uses iptables-save to get rule definitions
func collectIPTablesSave() (string, error) {
cmd := exec.Command("iptables-save")
// runCommand executes a command and returns its stdout, wrapping stderr in the error on failure.
func runCommand(name string, args ...string) (string, error) {
cmd := exec.Command(name, args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String())
}
rules := stdout.String()
if strings.TrimSpace(rules) == "" {
return "", fmt.Errorf("no iptables rules found")
}
return rules, nil
}
// getTableStatistics gets verbose statistics for an entire table using iptables command
func getTableStatistics(table string) (string, error) {
cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String())
return "", fmt.Errorf("execute %s: %w (stderr: %s)", name, err, stderr.String())
}
return stdout.String(), nil
@@ -804,3 +809,91 @@ func formatSetKeyType(keyType nftables.SetDatatype) string {
return fmt.Sprintf("type-%v", keyType)
}
}
// addSysctls collects forwarding and netbird-managed sysctl values and writes them to the bundle.
func (g *BundleGenerator) addSysctls() error {
log.Info("Collecting sysctls")
content := collectSysctls()
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "sysctls.txt"); err != nil {
return fmt.Errorf("add sysctls to bundle: %w", err)
}
return nil
}
// collectSysctls reads every sysctl that the netbird client may modify, plus
// global IPv4/IPv6 forwarding, and returns a formatted dump grouped by topic.
// Per-interface values are enumerated by listing /proc/sys/net/ipv{4,6}/conf.
func collectSysctls() string {
var builder strings.Builder
writeSysctlGroup(&builder, "forwarding", []string{
"net.ipv4.ip_forward",
"net.ipv6.conf.all.forwarding",
"net.ipv6.conf.default.forwarding",
})
writeSysctlGroup(&builder, "ipv4 per-interface forwarding", listInterfaceSysctls("ipv4", "forwarding"))
writeSysctlGroup(&builder, "ipv6 per-interface forwarding", listInterfaceSysctls("ipv6", "forwarding"))
writeSysctlGroup(&builder, "rp_filter", append(
[]string{"net.ipv4.conf.all.rp_filter", "net.ipv4.conf.default.rp_filter"},
listInterfaceSysctls("ipv4", "rp_filter")...,
))
writeSysctlGroup(&builder, "src_valid_mark", append(
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
listInterfaceSysctls("ipv4", "src_valid_mark")...,
))
writeSysctlGroup(&builder, "conntrack", []string{
"net.netfilter.nf_conntrack_acct",
"net.netfilter.nf_conntrack_tcp_loose",
})
writeSysctlGroup(&builder, "tcp", []string{
"net.ipv4.tcp_tw_reuse",
})
return builder.String()
}
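// writeSysctlGroup writes a titled section with one "key = value" line per key.
// A key that cannot be read is recorded inline as an error instead of aborting the dump.
func writeSysctlGroup(builder *strings.Builder, title string, keys []string) {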
builder.WriteString(fmt.Sprintf("=== %s ===\n", title))
for _, key := range keys {
value, err := readSysctl(key)
if err != nil {
builder.WriteString(fmt.Sprintf("%s = <error: %v>\n", key, err))
continue
}
builder.WriteString(fmt.Sprintf("%s = %s\n", key, value))
}
builder.WriteString("\n")
}
// listInterfaceSysctls returns net.ipvX.conf.<iface>.<leaf> keys for every
// interface present in /proc/sys/net/ipvX/conf, skipping "all" and "default"
// (callers add those explicitly so they appear first).
func listInterfaceSysctls(family, leaf string) []string {
dir := fmt.Sprintf("/proc/sys/net/%s/conf", family)
entries, err := os.ReadDir(dir)
if err != nil {
return nil
}
var keys []string
for _, e := range entries {
name := e.Name()
if name == "all" || name == "default" {
continue
}
keys = append(keys, fmt.Sprintf("net.%s.conf.%s.%s", family, name, leaf))
}
sort.Strings(keys)
return keys
}
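// readSysctl reads a dotted sysctl key from /proc/sys by turning every dot into
// a path separator. Interface names that contain a dot (for example a VLAN
// named eth0.100) therefore resolve to a non-existent path and return an error.
func readSysctl(key string) (string, error) {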
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
value, err := os.ReadFile(path)
if err != nil {
return "", err
}
return strings.TrimSpace(string(value)), nil
}
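
Since `readSysctl` replaces every dot in the key, a per-interface key for an interface such as `eth0.100` would be looked up under `conf/eth0/100/` and show up in the dump as an `<error: ...>` line. One possible way around this, sketched here under the assumption that per-interface values may be read from path components rather than from a dotted key, is shown below; `interfaceSysctlPath` and `readInterfaceSysctl` are hypothetical helpers and not part of this change.

```go
package main

import (
	"os"
	"path"
	"strings"
)

// interfaceSysctlPath builds the /proc/sys path for a per-interface sysctl
// from its components, so dots inside the interface name are left untouched.
// Hypothetical helper, not taken from the diff.
func interfaceSysctlPath(family, iface, leaf string) string {
	return path.Join("/proc/sys/net", family, "conf", iface, leaf)
}

// readInterfaceSysctl reads and trims the value at that path.
func readInterfaceSysctl(family, iface, leaf string) (string, error) {
	value, err := os.ReadFile(interfaceSysctlPath(family, iface, leaf))
	if err != nil {
		return "", err
	}
	return strings.TrimSpace(string(value)), nil
}
```

For example, `interfaceSysctlPath("ipv4", "eth0.100", "rp_filter")` yields `/proc/sys/net/ipv4/conf/eth0.100/rp_filter`, whereas the dotted-key conversion would split the interface name in two.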

@@ -17,3 +17,8 @@ func (g *BundleGenerator) addIPRules() error {
// IP rules are only supported on Linux
return nil
}
func (g *BundleGenerator) addSysctls() error {
// Sysctl collection is only supported on Linux
return nil
}

@@ -862,6 +862,8 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
RosenpassEnabled: true,
RosenpassPermissive: true,
ServerSSHAllowed: &bTrue,
ServerVNCAllowed: &bTrue,
DisableVNCApproval: &bTrue,
EnableSSHRoot: &bTrue,
EnableSSHSFTP: &bTrue,
EnableSSHLocalPortForwarding: &bTrue,

@@ -16,6 +16,10 @@ type hostManager interface {
restoreHostDNS() error
supportCustomPort() bool
string() string
// getOriginalNameservers returns the OS-side resolvers used as PriorityFallback
// upstreams: pre-takeover snapshots on desktop, the OS-pushed list on Android,
// hardcoded Quad9 on iOS, nil for noop / mock.
getOriginalNameservers() []netip.Addr
}
type SystemDNSSettings struct {
@@ -131,3 +135,11 @@ func (n noopHostConfigurator) supportCustomPort() bool {
func (n noopHostConfigurator) string() string {
return "noop"
}
func (n noopHostConfigurator) getOriginalNameservers() []netip.Addr {
return nil
}
func (m *mockHostConfigurator) getOriginalNameservers() []netip.Addr {
return nil
}

@@ -1,14 +1,20 @@
package dns
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// androidHostManager is a noop on the OS side (Android's VPN service handles
// DNS for us) but tracks the OS-reported resolver list pushed via
// OnUpdatedHostDNSServer so it can serve as the fallback nameserver source.
type androidHostManager struct {
holder *hostsDNSHolder
}
func newHostManager() (*androidHostManager, error) {
return &androidHostManager{}, nil
func newHostManager(holder *hostsDNSHolder) (*androidHostManager, error) {
return &androidHostManager{holder: holder}, nil
}
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
@@ -26,3 +32,12 @@ func (a androidHostManager) supportCustomPort() bool {
func (a androidHostManager) string() string {
return "none"
}
func (a androidHostManager) getOriginalNameservers() []netip.Addr {
hosts := a.holder.get()
out := make([]netip.Addr, 0, len(hosts))
for ap := range hosts {
out = append(out, ap.Addr())
}
return out
}
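
The Android `getOriginalNameservers` flattens a set of `netip.AddrPort` values by ranging over a map, so the result order varies between calls and the same address can appear twice when it is registered on two ports. If a stable, duplicate-free list were wanted, it could be sketched as follows; `uniqueSortedAddrs` is a hypothetical helper, and whether callers need this ordering is an assumption.

```go
package main

import (
	"net/netip"
	"slices"
)

// uniqueSortedAddrs drops the port from every entry, sorts the addresses and
// removes adjacent duplicates. Hypothetical helper, not taken from the diff.
func uniqueSortedAddrs(hosts map[netip.AddrPort]struct{}) []netip.Addr {
	out := make([]netip.Addr, 0, len(hosts))
	for ap := range hosts {
		out = append(out, ap.Addr())
	}
	slices.SortFunc(out, func(a, b netip.Addr) int { return a.Compare(b) })
	return slices.Compact(out)
}

// demoUniqueSortedAddrs shows the effect on a set holding one resolver on two ports.
func demoUniqueSortedAddrs() []string {
	hosts := map[netip.AddrPort]struct{}{
		netip.MustParseAddrPort("9.9.9.9:53"):  {},
		netip.MustParseAddrPort("9.9.9.9:853"): {},
		netip.MustParseAddrPort("1.1.1.1:53"):  {},
	}
	var out []string
	for _, a := range uniqueSortedAddrs(hosts) {
		out = append(out, a.String())
	}
	return out
}
```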

@@ -3,6 +3,7 @@ package dns
import (
"encoding/json"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
@@ -20,6 +21,14 @@ func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
}, nil
}
func (a iosHostManager) getOriginalNameservers() []netip.Addr {
// Quad9 v4+v6: 9.9.9.9, 2620:fe::fe.
return []netip.Addr{
netip.AddrFrom4([4]byte{9, 9, 9, 9}),
netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}),
}
}
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
jsonData, err := json.Marshal(config)
if err != nil {
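
The iOS fallback spells the Quad9 addresses as byte arrays. As a quick cross-check of the IPv6 literal against the textual form given in the comment, the two spellings can be compared directly. This is only an illustrative sketch; `quad9LiteralsMatch` is a hypothetical name.

```go
package main

import "net/netip"

// quad9LiteralsMatch reports whether the byte-array literals used for the iOS
// fallback equal the textual addresses 9.9.9.9 and 2620:fe::fe.
func quad9LiteralsMatch() bool {
	fromBytes := []netip.Addr{
		netip.AddrFrom4([4]byte{9, 9, 9, 9}),
		netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}),
	}
	fromText := []netip.Addr{
		netip.MustParseAddr("9.9.9.9"),
		netip.MustParseAddr("2620:fe::fe"),
	}
	for i := range fromBytes {
		// netip.Addr is comparable, so == checks family, bits and zone.
		if fromBytes[i] != fromText[i] {
			return false
		}
	}
	return true
}
```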

@@ -7,6 +7,7 @@ import (
"io"
"net/netip"
"os/exec"
"slices"
"strings"
"syscall"
"time"
@@ -44,9 +45,11 @@ const (
nrptMaxDomainsPerRule = 50
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigPathV6 = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigDhcpNameSrvKey = "DhcpNameServer"
interfaceConfigSearchListKey = "SearchList"
// Network interface DNS registration settings
disableDynamicUpdateKey = "DisableDynamicUpdate"
@@ -67,10 +70,11 @@ const (
)
type registryConfigurator struct {
guid string
routingAll bool
gpo bool
nrptEntryCount int
guid string
routingAll bool
gpo bool
nrptEntryCount int
origNameservers []netip.Addr
}
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@@ -94,6 +98,17 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
gpo: useGPO,
}
origNameservers, err := configurator.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from non-WG adapters: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from non-WG adapters; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from non-WG adapters: %v", len(origNameservers), origNameservers)
}
configurator.origNameservers = origNameservers
if err := configurator.configureInterface(); err != nil {
log.Errorf("failed to configure interface settings: %v", err)
}
@@ -101,6 +116,98 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
return configurator, nil
}
// captureOriginalNameservers reads DNS addresses from every Tcpip(6) interface
// registry key except the WG adapter. v4 and v6 servers live under separate
// registry keys (Tcpip vs Tcpip6), each with subkeys named by the same interface GUID.
func (r *registryConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
var merr *multierror.Error
for _, root := range []string{interfaceConfigPath, interfaceConfigPathV6} {
addrs, err := r.captureFromTcpipRoot(root)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("%s: %w", root, err))
continue
}
for _, addr := range addrs {
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nberrors.FormatErrorOrNil(merr)
}
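// captureFromTcpipRoot collects nameservers from every interface subkey under
// rootPath, skipping the subkey that belongs to the WG adapter.
func (r *registryConfigurator) captureFromTcpipRoot(rootPath string) ([]netip.Addr, error) {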
root, err := registry.OpenKey(registry.LOCAL_MACHINE, rootPath, registry.READ)
if err != nil {
return nil, fmt.Errorf("open key: %w", err)
}
defer closer(root)
guids, err := root.ReadSubKeyNames(-1)
if err != nil {
return nil, fmt.Errorf("read subkeys: %w", err)
}
var out []netip.Addr
for _, guid := range guids {
if strings.EqualFold(guid, r.guid) {
continue
}
out = append(out, readInterfaceNameservers(rootPath, guid)...)
}
return out, nil
}
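// readInterfaceNameservers returns the first non-empty nameserver list found
// for the interface, or nil when the key cannot be opened or holds no usable address.
func readInterfaceNameservers(rootPath, guid string) []netip.Addr {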
keyPath := rootPath + "\\" + guid
k, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
if err != nil {
return nil
}
defer closer(k)
// Static NameServer wins over DhcpNameServer for actual resolution.
for _, name := range []string{interfaceConfigNameServerKey, interfaceConfigDhcpNameSrvKey} {
raw, _, err := k.GetStringValue(name)
if err != nil || raw == "" {
continue
}
if out := parseRegistryNameservers(raw); len(out) > 0 {
return out
}
}
return nil
}
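// parseRegistryNameservers splits a comma- or whitespace-separated registry
// value into addresses, skipping fields that do not parse.
func parseRegistryNameservers(raw string) []netip.Addr {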
var out []netip.Addr
for _, field := range strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }) {
addr, err := netip.ParseAddr(strings.TrimSpace(field))
if err != nil {
continue
}
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
continue
}
// Drop unzoned link-local: not routable without a scope id. If
// the user wrote "fe80::1%eth0" ParseAddr preserves the zone.
if addr.IsLinkLocalUnicast() && addr.Zone() == "" {
continue
}
out = append(out, addr)
}
return out
}
func (r *registryConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(r.origNameservers)
}
func (r *registryConfigurator) supportCustomPort() bool {
return false
}
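
To make the filtering rules in `parseRegistryNameservers` concrete, here is a standalone sketch that applies the same steps to a sample value. `parseNameserverList` is a hypothetical copy written for illustration, and the sample string is invented, not taken from a real registry.

```go
package main

import (
	"net/netip"
	"strings"
)

// parseNameserverList mirrors the steps shown in the diff: split on comma,
// space or tab, unmap IPv4-in-IPv6, then drop unspecified addresses and
// link-local addresses that carry no zone.
func parseNameserverList(raw string) []netip.Addr {
	isSep := func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }
	var out []netip.Addr
	for _, field := range strings.FieldsFunc(raw, isSep) {
		addr, err := netip.ParseAddr(field)
		if err != nil {
			continue // not an IP address
		}
		addr = addr.Unmap()
		if addr.IsUnspecified() {
			continue
		}
		if addr.IsLinkLocalUnicast() && addr.Zone() == "" {
			continue // unusable without a scope id
		}
		out = append(out, addr)
	}
	return out
}

// demoParseNameserverList runs the parser on an invented sample value.
func demoParseNameserverList() []string {
	sample := "8.8.8.8,1.1.1.1 ::ffff:10.0.0.1\tfe80::1 fe80::1%eth0 0.0.0.0 bogus"
	var out []string
	for _, a := range parseNameserverList(sample) {
		out = append(out, a.String())
	}
	return out
}
```

With this sample, the mapped address is reduced to `10.0.0.1`, the zoned link-local entry is kept, and the unzoned link-local, unspecified and unparsable fields are dropped.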

@@ -25,6 +25,7 @@ func (h *hostsDNSHolder) set(list []netip.AddrPort) {
h.mutex.Unlock()
}
//nolint:unused
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
h.mutex.RLock()
l := h.unprotectedDNSList

@@ -76,8 +76,6 @@ func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}
func (d *Resolver) ProbeAvailability(context.Context) {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithFields(log.Fields{

@@ -9,6 +9,7 @@ import (
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string {
return make([]string, 0)
}
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
func (m *MockServer) ProbeAvailability() {
}
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
if m.UpdateServerConfigFunc != nil {
return m.UpdateServerConfigFunc(domains)
@@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
// SetRouteSources mock implementation of SetRouteSources from Server interface
func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) {
// Mock implementation - no-op
}

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"net/netip"
"slices"
"strings"
"time"
@@ -32,6 +33,15 @@ const (
networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection"
networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply"
networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete"
networkManagerDbusDeviceIp4ConfigProperty = networkManagerDbusDeviceInterface + ".Ip4Config"
networkManagerDbusDeviceIp6ConfigProperty = networkManagerDbusDeviceInterface + ".Ip6Config"
networkManagerDbusDeviceIfaceProperty = networkManagerDbusDeviceInterface + ".Interface"
networkManagerDbusGetDevicesMethod = networkManagerDest + ".GetDevices"
networkManagerDbusIp4ConfigInterface = "org.freedesktop.NetworkManager.IP4Config"
networkManagerDbusIp6ConfigInterface = "org.freedesktop.NetworkManager.IP6Config"
networkManagerDbusIp4ConfigNameserverDataProperty = networkManagerDbusIp4ConfigInterface + ".NameserverData"
networkManagerDbusIp4ConfigNameserversProperty = networkManagerDbusIp4ConfigInterface + ".Nameservers"
networkManagerDbusIp6ConfigNameserversProperty = networkManagerDbusIp6ConfigInterface + ".Nameservers"
networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0
networkManagerDbusIPv4Key = "ipv4"
networkManagerDbusIPv6Key = "ipv6"
@@ -51,9 +61,10 @@ var supportedNetworkManagerVersionConstraints = []string{
}
type networkManagerDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
origNameservers []netip.Addr
}
// the types below are based on dbus specification, each field is mapped to a dbus type
@@ -92,10 +103,200 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
return &networkManagerDbusConfigurator{
c := &networkManagerDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil
}
origNameservers, err := c.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from NetworkManager: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from non-WG NetworkManager devices; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from non-WG NetworkManager devices: %v", len(origNameservers), origNameservers)
}
c.origNameservers = origNameservers
return c, nil
}
// captureOriginalNameservers reads DNS servers from every NM device's
// IP4Config / IP6Config except our WG device.
func (n *networkManagerDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
devices, err := networkManagerListDevices()
if err != nil {
return nil, fmt.Errorf("list devices: %w", err)
}
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
for _, dev := range devices {
if dev == n.dbusLinkObject {
continue
}
ifaceName := readNetworkManagerDeviceInterface(dev)
for _, addr := range readNetworkManagerDeviceDNS(dev) {
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
continue
}
// IP6Config.Nameservers is a byte slice without zone info;
// reattach the device's interface name so a captured fe80::…
// stays routable.
if addr.IsLinkLocalUnicast() && ifaceName != "" {
addr = addr.WithZone(ifaceName)
}
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nil
}
func readNetworkManagerDeviceInterface(devicePath dbus.ObjectPath) string {
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
if err != nil {
return ""
}
defer closeConn()
v, err := obj.GetProperty(networkManagerDbusDeviceIfaceProperty)
if err != nil {
return ""
}
s, _ := v.Value().(string)
return s
}
func networkManagerListDevices() ([]dbus.ObjectPath, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil {
return nil, fmt.Errorf("dbus NetworkManager: %w", err)
}
defer closeConn()
var devs []dbus.ObjectPath
if err := obj.Call(networkManagerDbusGetDevicesMethod, dbusDefaultFlag).Store(&devs); err != nil {
return nil, err
}
return devs, nil
}
func readNetworkManagerDeviceDNS(devicePath dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
if err != nil {
return nil
}
defer closeConn()
var out []netip.Addr
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp4ConfigProperty); path != "" {
out = append(out, readIPv4ConfigDNS(path)...)
}
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp6ConfigProperty); path != "" {
out = append(out, readIPv6ConfigDNS(path)...)
}
return out
}
func readNetworkManagerConfigPath(obj dbus.BusObject, property string) dbus.ObjectPath {
v, err := obj.GetProperty(property)
if err != nil {
return ""
}
path, ok := v.Value().(dbus.ObjectPath)
if !ok || path == "/" {
return ""
}
return path
}
func readIPv4ConfigDNS(path dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, path)
if err != nil {
return nil
}
defer closeConn()
// NameserverData (NM 1.13+) carries strings; older NMs only expose the
// legacy uint32 Nameservers property.
if out := readIPv4NameserverData(obj); len(out) > 0 {
return out
}
return readIPv4LegacyNameservers(obj)
}
func readIPv4NameserverData(obj dbus.BusObject) []netip.Addr {
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserverDataProperty)
if err != nil {
return nil
}
entries, ok := v.Value().([]map[string]dbus.Variant)
if !ok {
return nil
}
var out []netip.Addr
for _, entry := range entries {
addrVar, ok := entry["address"]
if !ok {
continue
}
s, ok := addrVar.Value().(string)
if !ok {
continue
}
if a, err := netip.ParseAddr(s); err == nil {
out = append(out, a)
}
}
return out
}
func readIPv4LegacyNameservers(obj dbus.BusObject) []netip.Addr {
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserversProperty)
if err != nil {
return nil
}
raw, ok := v.Value().([]uint32)
if !ok {
return nil
}
out := make([]netip.Addr, 0, len(raw))
for _, n := range raw {
var b [4]byte
binary.LittleEndian.PutUint32(b[:], n)
out = append(out, netip.AddrFrom4(b))
}
return out
}
func readIPv6ConfigDNS(path dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, path)
if err != nil {
return nil
}
defer closeConn()
v, err := obj.GetProperty(networkManagerDbusIp6ConfigNameserversProperty)
if err != nil {
return nil
}
raw, ok := v.Value().([][]byte)
if !ok {
return nil
}
out := make([]netip.Addr, 0, len(raw))
for _, b := range raw {
if a, ok := netip.AddrFromSlice(b); ok {
out = append(out, a)
}
}
return out
}
func (n *networkManagerDbusConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(n.origNameservers)
}
func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
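
The legacy IPv4 path decodes each `uint32` by writing it out in little-endian byte order and reading the four bytes as an address. A worked example of that arithmetic follows; `decodeLegacyNameserver` is a hypothetical name, the sample value is invented, and the sketch simply mirrors the decoding in the diff without addressing big-endian hosts.

```go
package main

import (
	"encoding/binary"
	"net/netip"
)

// decodeLegacyNameserver converts one entry of the legacy Nameservers
// property into an address, using the same little-endian decoding as the diff.
func decodeLegacyNameserver(n uint32) netip.Addr {
	var b [4]byte
	// 0x0101A8C0 becomes the bytes c0 a8 01 01, i.e. 192.168.1.1.
	binary.LittleEndian.PutUint32(b[:], n)
	return netip.AddrFrom4(b)
}
```

The least significant byte of the integer becomes the first octet, which is why `0x0101A8C0` reads back as `192.168.1.1` rather than `1.1.168.192`.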

File diff suppressed because it is too large

@@ -1,5 +1,5 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager()
return newHostManager(s.hostsDNSHolder)
}

@@ -6,7 +6,7 @@ import (
"net"
"net/netip"
"os"
"strings"
"runtime"
"testing"
"time"
@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -31,8 +32,10 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/proto"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -101,16 +104,17 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger())
}
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []netip.AddrPort
for _, srv := range servers {
srvs = append(srvs, srv.AddrPort())
}
return &upstreamResolverBase{
domain: domain,
upstreamServers: srvs,
cancel: func() {},
u := &upstreamResolverBase{
domain: domain.Domain(d),
cancel: func() {},
}
u.addRace(srvs)
return u
}
func TestUpdateDNSServer(t *testing.T) {
@@ -653,74 +657,8 @@ func TestDNSServerStartStop(t *testing.T) {
}
}
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}),
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: hostManager,
currentConfig: HostDNSConfig{
Domains: []DomainConfig{
{false, "domain0", false},
{false, "domain1", false},
{false, "domain2", false},
},
},
statusRecorder: peer.NewRecorder("mgm"),
}
var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
domains := []string{}
for _, item := range config.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
domainsUpdate = strings.Join(domains, ",")
return nil
}
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
Domains: []string{"domain1"},
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
},
}, nil, 0)
deactivate(nil)
expected := "domain0,domain2"
domains := []string{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
got := strings.Join(domains, ",")
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, got)
}
reactivate()
expected = "domain0,domain1,domain2"
domains = []string{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
got = strings.Join(domains, ",")
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
}
}
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -748,6 +686,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
}
func TestDNSPermanent_updateUpstream(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -841,6 +780,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
}
func TestDNSPermanent_matchOnly(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -913,6 +853,18 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
}
}
// skipUnlessAndroid marks tests that exercise the mobile-permanent DNS path,
// which only matches a real production setup on android (NewDefaultServerPermanentUpstream
// + androidHostManager). On non-android the desktop host manager replaces it
// during Initialize and the assertion stops making sense. Skipped here until we
// have an android CI runner.
func skipUnlessAndroid(t *testing.T) {
t.Helper()
if runtime.GOOS != "android" {
t.Skip("requires android runner; mobile-permanent path doesn't match production on this OS")
}
}
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
t.Helper()
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
@@ -1065,7 +1017,6 @@ type mockHandler struct {
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) Stop() {}
func (m *mockHandler) ProbeAvailability(context.Context) {}
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
type mockService struct{}
@@ -2085,6 +2036,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}
// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple
// admin-defined nameserver groups targeting the same domain collapse into a
// single handler with each group preserved as a sequential inner list.
func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgInterface,
service: service,
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
groups := []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
{
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
}
muxUpdates, err := server.buildUpstreamHandlerUpdate(groups)
require.NoError(t, err)
require.Len(t, muxUpdates, 1, "same-domain groups should merge into one handler")
assert.Equal(t, "example.com", muxUpdates[0].domain)
assert.Equal(t, PriorityUpstream, muxUpdates[0].priority)
handler := muxUpdates[0].handler.(*upstreamResolver)
require.Len(t, handler.upstreamServers, 2, "handler should have two groups")
assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0])
assert.Equal(t, upstreamRace{
netip.MustParseAddrPort("192.0.2.2:53"),
netip.MustParseAddrPort("192.0.2.3:53"),
}, handler.upstreamServers[1])
}
// TestEvaluateNSGroupHealth covers the records-only verdict. The gate
// (overlay route selected-but-no-active-peer) is intentionally NOT an
// input to the evaluator anymore: the verdict drives the Enabled flag,
// which must always reflect what we actually observed. Gate-aware event
// suppression is tested separately in the projection test.
//
// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail,
// stale Ok, Ok newer than Fail, Fail newer than Ok}.
// Group verdict: any fresh-working → Healthy; any fresh-broken with no
// fresh-working → Unhealthy; otherwise Undecided.
func TestEvaluateNSGroupHealth(t *testing.T) {
now := time.Now()
a := netip.MustParseAddrPort("192.0.2.1:53")
b := netip.MustParseAddrPort("192.0.2.2:53")
recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)}
recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"}
staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)}
staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"}
okThenFail := UpstreamHealth{
LastOk: now.Add(-10 * time.Second),
LastFail: now.Add(-1 * time.Second),
LastErr: "timeout",
}
failThenOk := UpstreamHealth{
LastOk: now.Add(-1 * time.Second),
LastFail: now.Add(-10 * time.Second),
LastErr: "timeout",
}
tests := []struct {
name string
health map[netip.AddrPort]UpstreamHealth
servers []netip.AddrPort
wantVerdict nsGroupVerdict
wantErrSubst string
}{
{
name: "no record, undecided",
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "fresh success, healthy",
health: map[netip.AddrPort]UpstreamHealth{a: recentOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictHealthy,
},
{
name: "fresh failure, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{a: recentFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "only stale success, undecided",
health: map[netip.AddrPort]UpstreamHealth{a: staleOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "only stale failure, undecided",
health: map[netip.AddrPort]UpstreamHealth{a: staleFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "both fresh, fail newer, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{a: okThenFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "both fresh, ok newer, healthy",
health: map[netip.AddrPort]UpstreamHealth{a: failThenOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictHealthy,
},
{
name: "two upstreams, one success wins",
health: map[netip.AddrPort]UpstreamHealth{
a: recentFail,
b: recentOk,
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictHealthy,
},
{
name: "two upstreams, one fail one unseen, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{
a: recentFail,
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "two upstreams, all recent failures, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{
a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"},
b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"},
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "SERVFAIL",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now)
assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch")
if tc.wantErrSubst != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.wantErrSubst)
} else {
assert.NoError(t, err)
}
})
}
}
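// Illustrative sketch only, not the evaluateNSGroupHealth implementation from
// this change: a minimal evaluator that satisfies the matrix described above.
// The freshWindow value, the verdict type, and the verdictFor helper are
// assumptions made for this example; the real freshness threshold and error
// formatting are not shown in this diff.

```go
package main

import (
	"errors"
	"fmt"
	"net/netip"
	"time"
)

// UpstreamHealth mirrors the three fields exercised by the test matrix.
type UpstreamHealth struct {
	LastOk   time.Time
	LastFail time.Time
	LastErr  string
}

type verdict string

const (
	undecided verdict = "undecided"
	healthy   verdict = "healthy"
	unhealthy verdict = "unhealthy"
)

// freshWindow is an assumed value chosen so that "seconds ago" counts as
// fresh and "10 minutes ago" counts as stale.
const freshWindow = time.Minute

// evaluate is a hypothetical stand-in for evaluateNSGroupHealth.
func evaluate(health map[netip.AddrPort]UpstreamHealth, servers []netip.AddrPort, now time.Time) (verdict, error) {
	var lastFail time.Time
	var lastErr string
	for _, s := range servers {
		h, seen := health[s]
		if !seen {
			continue // never queried: contributes nothing to the verdict
		}
		okFresh := !h.LastOk.IsZero() && now.Sub(h.LastOk) <= freshWindow
		failFresh := !h.LastFail.IsZero() && now.Sub(h.LastFail) <= freshWindow
		switch {
		case okFresh && (!failFresh || h.LastOk.After(h.LastFail)):
			return healthy, nil // any fresh-working upstream wins
		case failFresh && h.LastFail.After(lastFail):
			lastFail, lastErr = h.LastFail, h.LastErr // remember the newest failure
		}
	}
	if !lastFail.IsZero() {
		return unhealthy, errors.New(lastErr)
	}
	return undecided, nil
}

// verdictFor evaluates a single upstream whose last success and failure
// happened the given number of seconds ago (negative means "never").
func verdictFor(okAgo, failAgo int) verdict {
	now := time.Now()
	srv := netip.MustParseAddrPort("192.0.2.1:53")
	var h UpstreamHealth
	if okAgo >= 0 {
		h.LastOk = now.Add(-time.Duration(okAgo) * time.Second)
	}
	if failAgo >= 0 {
		h.LastFail = now.Add(-time.Duration(failAgo) * time.Second)
		h.LastErr = "timeout"
	}
	v, _ := evaluate(map[netip.AddrPort]UpstreamHealth{srv: h}, []netip.AddrPort{srv}, now)
	return v
}

func main() {
	fmt.Println(verdictFor(2, -1))   // fresh success only
	fmt.Println(verdictFor(10, 1))   // failure newer than success
	fmt.Println(verdictFor(600, -1)) // stale success only
}
```

// In this sketch the reported error is the most recent fresh failure, which
// matches the "SERVFAIL" expectation in the last table case; the real
// evaluator may combine errors differently.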
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
// without spinning up real handlers.
type healthStubHandler struct {
health map[netip.AddrPort]UpstreamHealth
}
func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (h *healthStubHandler) Stop() {}
func (h *healthStubHandler) ID() types.HandlerID { return "health-stub" }
func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
return h.health
}
// TestProjection_SteadyStateIsSilent guards against duplicate events:
// while a group stays Unhealthy tick after tick, only the first
// Unhealthy transition may emit. Same for staying Healthy.
func TestProjection_SteadyStateIsSilent(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "first fail emits warning")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.tick()
fx.expectNoEvent("staying unhealthy must not re-emit")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "recovery on transition")
fx.tick()
fx.tick()
fx.expectNoEvent("staying healthy must not re-emit")
}
// projTestFixture is the common setup for the projection tests: a
// single-upstream group whose route classification the test can flip by
// assigning to selected/active. Callers drive failures/successes through
// setHealth (which replaces stub.health) and tick (which calls refreshHealth).
type projTestFixture struct {
t *testing.T
recorder *peer.Status
events <-chan *proto.SystemEvent
server *DefaultServer
stub *healthStubHandler
group *nbdns.NameServerGroup
srv netip.AddrPort
selected route.HAMap
active route.HAMap
}
func newProjTestFixture(t *testing.T) *projTestFixture {
t.Helper()
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
srv := netip.MustParseAddrPort("100.64.0.1:53")
fx := &projTestFixture{
t: t,
recorder: recorder,
events: sub.Events(),
stub: &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}},
srv: srv,
group: &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
},
}
fx.server = &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return fx.selected },
activeRoutes: func() route.HAMap { return fx.active },
warningDelayBase: defaultWarningDelayBase,
}
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
fx.server.mux.Lock()
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
fx.server.mux.Unlock()
return fx
}
func (f *projTestFixture) setHealth(h UpstreamHealth) {
f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h}
}
func (f *projTestFixture) tick() []peer.NSGroupState {
f.server.refreshHealth()
return f.recorder.GetDNSStates()
}
func (f *projTestFixture) expectNoEvent(why string) {
f.t.Helper()
select {
case evt := <-f.events:
f.t.Fatalf("unexpected event (%s): %+v", why, evt)
case <-time.After(100 * time.Millisecond):
}
}
func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent {
f.t.Helper()
select {
case evt := <-f.events:
assert.Contains(f.t, evt.Message, substr, why)
return evt
case <-time.After(time.Second):
f.t.Fatalf("expected event (%s) with %q", why, substr)
return nil
}
}
var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16")
var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}}
// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream
// that is not inside any selected route (public DNS) fires the warning
// on the first Unhealthy tick, no grace period.
func TestProjection_PublicFailEmitsImmediately(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled)
fx.expectEvent("unreachable", "public DNS failure")
}
// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2:
// the upstream is inside a selected route AND the route has a Connected
// peer. Tunnel is up, failure is real, emit immediately.
func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) {
fx := newProjTestFixture(t)
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled)
fx.expectEvent("unreachable", "overlay + connected failure")
}
// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the
// upstream is routed but no peer is Connected (Connecting/Idle/missing).
// First tick: Unhealthy display, no warning. After the grace window
// elapses with no recovery, the warning fires.
func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) {
grace := 50 * time.Millisecond
fx := newProjTestFixture(t)
fx.server.warningDelayBase = grace
fx.selected = overlayMapForTest
// active stays nil: routed but not connected.
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled, "display must reflect failure even during grace window")
fx.expectNoEvent("first fail tick within grace window")
time.Sleep(grace + 10*time.Millisecond)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "warning after grace window")
}
// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream
// whose address is inside the WireGuard overlay range but is not
// covered by any selected route (peer-to-peer DNS without an explicit
// route). Until a peer reports Connected for that address, startup
// failures must be held just like the routed case.
func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
overlayPeer := netip.MustParseAddrPort("100.66.100.5:53")
server := &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: 50 * time.Millisecond,
}
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
select {
case evt := <-sub.Events():
t.Fatalf("unexpected event during grace window: %+v", evt)
case <-time.After(100 * time.Millisecond):
}
time.Sleep(60 * time.Millisecond)
stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}}
server.refreshHealth()
select {
case evt := <-sub.Events():
assert.Contains(t, evt.Message, "unreachable")
case <-time.After(time.Second):
t.Fatal("expected warning after grace window")
}
}
// TestProjection_StopClearsHealthState verifies that Stop wipes the
// per-group projection state so a subsequent Start doesn't inherit
// sticky flags (notably everHealthy) that would bypass the grace
// window during the next peer handshake.
func TestProjection_StopClearsHealthState(t *testing.T) {
wgIface := &mocWGIface{}
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgIface,
service: NewServiceViaMemory(wgIface),
hostManager: &noopHostConfigurator{},
extraDomains: map[domain.Domain]int{},
dnsMuxMap: make(registeredHandlerMap),
statusRecorder: peer.NewRecorder("mgm"),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: defaultWarningDelayBase,
currentConfigHash: ^uint64(0),
}
server.ctx, server.ctxCancel = context.WithCancel(context.Background())
srv := netip.MustParseAddrPort("8.8.8.8:53")
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
server.healthProjectMu.Lock()
p, ok := server.nsGroupProj[generateGroupKey(group)]
server.healthProjectMu.Unlock()
require.True(t, ok, "projection state should exist after tick")
require.True(t, p.everHealthy, "tick with success must set everHealthy")
server.Stop()
server.healthProjectMu.Lock()
cleared := server.nsGroupProj == nil
server.healthProjectMu.Unlock()
assert.True(t, cleared, "Stop must clear nsGroupProj")
}
// TestProjection_OverlayRecoversDuringGrace covers the happy path of
// rule 3: startup failures while the peer is handshaking, then the peer
// comes up and a query succeeds before the grace window elapses. No
// warning should ever have fired, and no recovery either.
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
fx := newProjTestFixture(t)
fx.server.warningDelayBase = 200 * time.Millisecond
fx.selected = overlayMapForTest
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectNoEvent("fail within grace, warning suppressed")
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
states := fx.tick()
require.Len(t, states, 1)
assert.True(t, states[0].Enabled)
fx.expectNoEvent("recovery without prior warning must not emit")
}
// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the
// whole design leans on: recovery events only appear when a warning
// event was actually emitted for the current streak. A Healthy verdict
// without a prior warning is silent, so the user never sees "recovered"
// out of thin air.
func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
states := fx.tick()
require.Len(t, states, 1)
assert.True(t, states[0].Enabled)
fx.expectNoEvent("first healthy tick should not recover anything")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "public fail emits immediately")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "recovery follows real warning")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "second cycle warning")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "second cycle recovery")
}
// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group
// has ever been Healthy, subsequent failures skip the grace window even
// if classification says "routed + not connected". The system has
// proved it can work, so any new failure is real.
func TestProjection_EverHealthyOverridesDelay(t *testing.T) {
fx := newProjTestFixture(t)
// Large base so any emission must come from the everHealthy bypass, not elapsed time.
fx.server.warningDelayBase = time.Hour
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
// Establish "ever healthy".
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectNoEvent("first healthy tick")
// Peer drops. Query fails. Routed + not connected → normally grace,
// but everHealthy flag bypasses it.
fx.active = nil
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "failure after ever-healthy must be immediate")
}
// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff
// from the design discussion: once a group has been healthy, a brief
// reconnect that produces a failing tick will fire warning + recovery.
// This is by design: user-visible blips are accurate signal, not noise.
func TestProjection_ReconnectBlipEmitsPair(t *testing.T) {
fx := newProjTestFixture(t)
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "blip warning")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "blip recovery")
}
// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream
// rule: a group with at least one public upstream is in the "immediate"
// category regardless of the other upstreams' routing, because the
// public one has no peer-startup excuse. Prevents public-DNS failures
// from being hidden behind a routed sibling.
func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
events := sub.Events()
public := netip.MustParseAddrPort("8.8.8.8:53")
overlay := netip.MustParseAddrPort("100.64.0.1:53")
overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}}
server := &DefaultServer{
ctx: context.Background(),
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return overlayMap },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: time.Hour,
}
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{
{IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())},
{IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())},
},
}
stub := &healthStubHandler{
health: map[netip.AddrPort]UpstreamHealth{
public: {LastFail: time.Now(), LastErr: "servfail"},
overlay: {LastFail: time.Now(), LastErr: "timeout"},
},
}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
select {
case evt := <-events:
assert.Contains(t, evt.Message, "unreachable")
case <-time.After(time.Second):
t.Fatal("expected immediate warning because group contains a public upstream")
}
}
func TestDNSLoopPrevention(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
@@ -2183,17 +2726,18 @@ func TestDNSLoopPrevention(t *testing.T) {
if tt.expectedHandlers > 0 {
handler := muxUpdates[0].handler.(*upstreamResolver)
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
flat := handler.flatUpstreams()
assert.Len(t, flat, len(tt.expectedServers))
if tt.shouldFilterOwnIP {
for _, upstream := range handler.upstreamServers {
for _, upstream := range flat {
assert.NotEqual(t, dnsServerIP, upstream.Addr())
}
}
for _, expected := range tt.expectedServers {
found := false
for _, upstream := range handler.upstreamServers {
for _, upstream := range flat {
if upstream.Addr() == expected {
found = true
break


@@ -8,6 +8,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"time"
"github.com/godbus/dbus/v5"
@@ -40,10 +41,17 @@ const (
)
type systemdDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
ifaceName string
dbusLinkObject dbus.ObjectPath
ifaceName string
wgIndex int
origNameservers []netip.Addr
}
const (
systemdDbusLinkDNSProperty = systemdDbusLinkInterface + ".DNS"
systemdDbusLinkDefaultRouteProperty = systemdDbusLinkInterface + ".DefaultRoute"
)
// the types below are based on dbus specification, each field is mapped to a dbus type
// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types
// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types
@@ -79,10 +87,145 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
return &systemdDbusConfigurator{
c := &systemdDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil
wgIndex: iface.Index,
}
origNameservers, err := c.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from systemd-resolved: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from systemd-resolved default-route links; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from systemd-resolved default-route links: %v", len(origNameservers), origNameservers)
}
c.origNameservers = origNameservers
return c, nil
}
// captureOriginalNameservers reads per-link DNS from systemd-resolved for
// every default-route link except our own WG link. Non-default-route links
// (VPNs, docker bridges) are skipped because their upstreams wouldn't
// actually serve host queries.
func (s *systemdDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("list interfaces: %w", err)
}
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
for _, iface := range ifaces {
if !s.isCandidateLink(iface) {
continue
}
linkPath, err := getSystemdLinkPath(iface.Index)
if err != nil || !isSystemdLinkDefaultRoute(linkPath) {
continue
}
for _, addr := range readSystemdLinkDNS(linkPath) {
addr = normalizeSystemdAddr(addr, iface.Name)
if !addr.IsValid() {
continue
}
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nil
}
func (s *systemdDbusConfigurator) isCandidateLink(iface net.Interface) bool {
if iface.Index == s.wgIndex {
return false
}
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
return false
}
return true
}
// normalizeSystemdAddr unmaps v4-mapped-v6, drops unspecified, and reattaches
// the link's iface name as zone for link-local v6 (Link.DNS strips it).
// Returns the zero Addr to signal "skip this entry".
func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr {
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
return netip.Addr{}
}
if addr.IsLinkLocalUnicast() {
return addr.WithZone(ifaceName)
}
return addr
}
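// Illustrative sketch: the normalization rules above, run against a few
// sample addresses. normalizeSystemdAddr is copied from this diff; the show
// helper and the "eth0" interface name are assumptions made for the example.

```go
package main

import (
	"fmt"
	"net/netip"
)

// normalizeSystemdAddr is reproduced from the change above.
func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr {
	addr = addr.Unmap()
	if !addr.IsValid() || addr.IsUnspecified() {
		return netip.Addr{}
	}
	if addr.IsLinkLocalUnicast() {
		return addr.WithZone(ifaceName)
	}
	return addr
}

// show is a hypothetical helper that renders the normalized address, or
// "skip" when the zero Addr signals that the entry should be dropped.
func show(raw, ifaceName string) string {
	addr := normalizeSystemdAddr(netip.MustParseAddr(raw), ifaceName)
	if !addr.IsValid() {
		return "skip"
	}
	return addr.String()
}

func main() {
	fmt.Println(show("::ffff:192.0.2.1", "eth0")) // v4-mapped v6 is unmapped: 192.0.2.1
	fmt.Println(show("fe80::1", "eth0"))          // link-local v6 gets the zone: fe80::1%eth0
	fmt.Println(show("::", "eth0"))               // unspecified is dropped: skip
	fmt.Println(show("2001:db8::53", "eth0"))     // global v6 is unchanged: 2001:db8::53
}
```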
func getSystemdLinkPath(ifIndex int) (dbus.ObjectPath, error) {
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
if err != nil {
return "", fmt.Errorf("dbus resolve1: %w", err)
}
defer closeConn()
var p string
if err := obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, int32(ifIndex)).Store(&p); err != nil {
return "", err
}
return dbus.ObjectPath(p), nil
}
func isSystemdLinkDefaultRoute(linkPath dbus.ObjectPath) bool {
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
if err != nil {
return false
}
defer closeConn()
v, err := obj.GetProperty(systemdDbusLinkDefaultRouteProperty)
if err != nil {
return false
}
b, ok := v.Value().(bool)
return ok && b
}
func readSystemdLinkDNS(linkPath dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
if err != nil {
return nil
}
defer closeConn()
v, err := obj.GetProperty(systemdDbusLinkDNSProperty)
if err != nil {
return nil
}
entries, ok := v.Value().([][]any)
if !ok {
return nil
}
var out []netip.Addr
for _, entry := range entries {
if len(entry) < 2 {
continue
}
raw, ok := entry[1].([]byte)
if !ok {
continue
}
addr, ok := netip.AddrFromSlice(raw)
if !ok {
continue
}
out = append(out, addr)
}
return out
}
func (s *systemdDbusConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(s.origNameservers)
}
func (s *systemdDbusConfigurator) supportCustomPort() bool {


@@ -1,3 +1,32 @@
// Package dns implements the client-side DNS stack: listener/service on the
// peer's tunnel address, handler chain that routes questions by domain and
// priority, and upstream resolvers that forward what remains to configured
// nameservers.
//
// # Upstream resolution and the race model
//
// When two or more nameserver groups target the same domain, DefaultServer
// merges them into one upstream handler whose state is:
//
// upstreamResolverBase
// └── upstreamServers []upstreamRace // one entry per source NS group
// └── []netip.AddrPort // primary, fallback, ...
//
// Each source nameserver group contributes one upstreamRace. Within a race
// upstreams are tried in order: the next is used only on failure (timeout,
// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops
// the walk. When more than one race exists, ServeDNS fans out one
// goroutine per race and returns the first valid answer, cancelling the
// rest. A handler with a single race skips the fan-out.
//
// # Health projection
//
// Query outcomes are recorded per-upstream in UpstreamHealth. The server
// periodically merges these snapshots across handlers and projects them
// into peer.NSGroupState. There is no active probing: a group is marked
// unhealthy only when every seen upstream has a recent failure and none
// has a recent success. Healthy→unhealthy fires a single
// SystemEvent_WARNING; steady-state refreshes do not duplicate it.
package dns
import (
@@ -11,11 +40,8 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -25,7 +51,8 @@ import (
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
var currentMTU uint16 = iface.DefaultMTU
@@ -67,15 +94,17 @@ const (
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
ClientTimeout = 5 * time.Second
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
// payload from the tunnel MTU.
ipUDPHeaderSize = 60 + 8
)
const testRecord = "com."
// raceMaxTotalTimeout caps the combined time spent walking all upstreams
// within one race, so a slow primary can't eat the whole race budget.
raceMaxTotalTimeout = 5 * time.Second
// raceMinPerUpstreamTimeout is the floor applied when dividing
// raceMaxTotalTimeout across upstreams within a race.
raceMinPerUpstreamTimeout = 2 * time.Second
)
const (
protoUDP = "udp"
@@ -84,6 +113,69 @@ const (
type dnsProtocolKey struct{}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type UpstreamResolver interface {
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
// upstreamRace is an ordered list of upstreams derived from one configured
// nameserver group. Order matters: the first upstream is tried first, the
// second only on failure, and so on. Multiple upstreamRace values coexist
// inside one resolver when overlapping nameserver groups target the same
// domain; those races run in parallel and the first valid answer wins.
type upstreamRace []netip.AddrPort
// UpstreamHealth is the last query-path outcome for a single upstream,
// consumed by nameserver-group status projection.
type UpstreamHealth struct {
LastOk time.Time
LastFail time.Time
LastErr string
}
type upstreamResolverBase struct {
ctx context.Context
cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []upstreamRace
domain domain.Domain
upstreamTimeout time.Duration
healthMu sync.RWMutex
health map[netip.AddrPort]*UpstreamHealth
statusRecorder *peer.Status
// selectedRoutes returns the current set of client routes the admin
// has enabled. Called lazily from the query hot path when an upstream
// might need a tunnel-bound client (iOS) and from health projection.
selectedRoutes func() route.HAMap
}
type upstreamFailure struct {
upstream netip.AddrPort
reason string
}
type raceResult struct {
msg *dns.Msg
upstream netip.AddrPort
protocol string
ede string
failures []upstreamFailure
}
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
return context.WithValue(ctx, dnsProtocolKey{}, network)
@@ -100,16 +192,8 @@ func dnsProtocolFromContext(ctx context.Context) string {
return ""
}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
r := &upstreamProtocolResult{}
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
@@ -124,67 +208,37 @@ func setUpstreamProtocol(ctx context.Context, protocol string) {
}
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type UpstreamResolver interface {
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type upstreamResolverBase struct {
ctx context.Context
cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []netip.AddrPort
domain string
disabled bool
successCount atomic.Int32
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
wg sync.WaitGroup
deactivate func(error)
reactivate func()
statusRecorder *peer.Status
routeMatch func(netip.Addr) bool
}
type upstreamFailure struct {
upstream netip.AddrPort
reason string
}
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
return &upstreamResolverBase{
ctx: ctx,
cancel: cancel,
domain: domain,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod,
statusRecorder: statusRecorder,
ctx: ctx,
cancel: cancel,
domain: d,
upstreamTimeout: UpstreamTimeout,
statusRecorder: statusRecorder,
}
}
// String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("Upstream %s", u.upstreamServers)
return fmt.Sprintf("Upstream %s", u.flatUpstreams())
}
// ID returns the unique handler ID
// ID returns the unique handler ID. Race groupings and within-race
// ordering are both part of the identity: [[A,B]] and [[A],[B]] query
// the same servers but with different semantics (serial fallback vs
// parallel race), so their handlers must not collide.
func (u *upstreamResolverBase) ID() types.HandlerID {
servers := slices.Clone(u.upstreamServers)
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
hash := sha256.New()
hash.Write([]byte(u.domain + ":"))
for _, s := range servers {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
hash.Write([]byte(u.domain.PunycodeString() + ":"))
for _, race := range u.upstreamServers {
hash.Write([]byte("["))
for _, s := range race {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
}
hash.Write([]byte("]"))
}
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}
@@ -194,13 +248,31 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
}
func (u *upstreamResolverBase) Stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams())
u.cancel()
}
u.mutex.Lock()
u.wg.Wait()
u.mutex.Unlock()
// flatUpstreams is for logging and ID hashing only, not for dispatch.
func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort {
var out []netip.AddrPort
for _, g := range u.upstreamServers {
out = append(out, g...)
}
return out
}
// setSelectedRoutes swaps the accessor used to classify overlay-routed
// upstreams. Called when route sources are wired after the handler was
// built (permanent / iOS constructors).
func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) {
u.selectedRoutes = selected
}
func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) {
if len(servers) == 0 {
return
}
u.upstreamServers = append(u.upstreamServers, slices.Clone(servers))
}
// ServeDNS handles a DNS request
@@ -242,82 +314,201 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
minPerUpstream := 2 * time.Second
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
if scaledTimeout > minPerUpstream {
timeout = scaledTimeout
} else {
timeout = minPerUpstream
}
groups := u.upstreamServers
switch len(groups) {
case 0:
return false, nil
case 1:
return u.tryOnlyRace(ctx, w, r, groups[0], logger)
default:
return u.raceAll(ctx, w, r, groups, logger)
}
}
func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
res := u.tryRace(ctx, r, group)
if res.msg == nil {
return false, res.failures
}
if res.ede != "" {
resutil.SetMeta(w, "ede", res.ede)
}
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
return true, res.failures
}
// raceAll runs one worker per group in parallel, taking the first valid
// answer and cancelling the rest.
func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
raceCtx, cancel := context.WithCancel(ctx)
defer cancel()
// Buffer sized to len(groups) so workers never block on send, even
// after the coordinator has returned.
results := make(chan raceResult, len(groups))
for _, g := range groups {
// tryRace clones the request per attempt, so workers never share
// a *dns.Msg and concurrent EDNS0 mutations can't race.
go func(g upstreamRace) {
results <- u.tryRace(raceCtx, r, g)
}(g)
}
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
for range groups {
select {
case res := <-results:
failures = append(failures, res.failures...)
if res.msg != nil {
if res.ede != "" {
resutil.SetMeta(w, "ede", res.ede)
}
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
return true, failures
}
case <-ctx.Done():
return false, failures
}
}
return false, failures
}
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult {
timeout := u.upstreamTimeout
if len(group) > 1 {
// Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts
// still honor raceMinPerUpstreamTimeout as a floor for correctness
// on slow links, but the outer context ensures the combined walk
// cannot exceed the cap regardless of group size.
timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout)
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout)
defer cancel()
}
var failures []upstreamFailure
for _, upstream := range group {
if ctx.Err() != nil {
return raceResult{failures: failures}
}
// Clone the request per attempt: the exchange path mutates EDNS0
// options in-place, so reusing the same *dns.Msg across sequential
// upstreams would carry those mutations (e.g. a reduced UDP size)
// into the next attempt.
res, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout)
if failure != nil {
failures = append(failures, *failure)
continue
}
res.failures = failures
return res
}
return raceResult{failures: failures}
}
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (raceResult, *upstreamFailure) {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx)
// Advertise EDNS0 so the upstream may include Extended DNS Errors
// (RFC 8914) in failure responses; we use those to short-circuit
// failover for definitive answers like DNSSEC validation failures.
// Operate on a copy so the inbound request is unchanged: a client that
// did not advertise EDNS0 must not see an OPT in the response.
// The caller already passed a per-attempt copy, so we can mutate r
// directly; hadEdns reflects the original client request's state and
// controls whether we strip the OPT from the response.
hadEdns := r.IsEdns0() != nil
reqUp := r
if !hadEdns {
reqUp = r.Copy()
reqUp.SetEdns0(upstreamUDPSize(), false)
r.SetEdns0(upstreamUDPSize(), false)
}
var startTime time.Time
var upstreamProto *upstreamProtocolResult
func() {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp)
}()
startTime := time.Now()
rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r)
if err != nil {
return u.handleUpstreamError(err, upstream, startTime)
// A parent cancellation (e.g., another race won and the coordinator
// cancelled the losers) is not an upstream failure. Check both the
// error chain and the parent context: a transport may surface the
// cancellation as a read/deadline error rather than context.Canceled.
if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) {
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "canceled"}
}
failure := u.handleUpstreamError(err, upstream, startTime)
u.markUpstreamFail(upstream, failure.reason)
return raceResult{}, failure
}
if rm == nil || !rm.Response {
return &upstreamFailure{upstream: upstream, reason: "no response"}
u.markUpstreamFail(upstream, "no response")
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
}
proto := ""
if upstreamProto != nil {
proto = upstreamProto.protocol
}
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
if code, ok := nonRetryableEDE(rm); ok {
resutil.SetMeta(w, "ede", edeName(code))
if !hadEdns {
stripOPT(rm)
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
}
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
reason := dns.RcodeToString[rm.Rcode]
u.markUpstreamFail(upstream, reason)
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
}
if !hadEdns {
stripOPT(rm)
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
}
// healthEntry returns the mutable health record for addr, lazily creating
// the map and the entry. Caller must hold u.healthMu.
func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth {
if u.health == nil {
u.health = make(map[netip.AddrPort]*UpstreamHealth)
}
h := u.health[addr]
if h == nil {
h = &UpstreamHealth{}
u.health[addr] = h
}
return h
}
func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) {
u.healthMu.Lock()
defer u.healthMu.Unlock()
h := u.healthEntry(addr)
h.LastOk = time.Now()
h.LastFail = time.Time{}
h.LastErr = ""
}
func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) {
u.healthMu.Lock()
defer u.healthMu.Unlock()
h := u.healthEntry(addr)
h.LastFail = time.Now()
h.LastErr = reason
}
// UpstreamHealth returns a snapshot of per-upstream query outcomes.
func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
u.healthMu.RLock()
defer u.healthMu.RUnlock()
out := make(map[netip.AddrPort]UpstreamHealth, len(u.health))
for k, v := range u.health {
out[k] = *v
}
return out
}
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
@@ -358,12 +549,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
u.successCount.Add(1)
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) {
resutil.SetMeta(w, "upstream", upstream.String())
if upstreamProto != nil && upstreamProto.protocol != "" {
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
if proto != "" {
resutil.SetMeta(w, "upstream_protocol", proto)
}
// Clear Zero bit from external responses to prevent upstream servers from
@@ -372,14 +574,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
return true
}
return true
}
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
totalUpstreams := len(u.upstreamServers)
totalUpstreams := len(u.flatUpstreams())
failedCount := len(failures)
failureSummary := formatFailures(failures)
@@ -434,119 +633,6 @@ func edeName(code uint16) string {
return fmt.Sprintf("EDE %d", code)
}
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
u.mutex.Lock()
defer u.mutex.Unlock()
// avoid probe if upstreams could resolve at least one query
if u.successCount.Load() > 0 {
return
}
var success bool
var mu sync.Mutex
var wg sync.WaitGroup
var errs *multierror.Error
for _, upstream := range u.upstreamServers {
wg.Add(1)
go func(upstream netip.AddrPort) {
defer wg.Done()
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
if err != nil {
mu.Lock()
errs = multierror.Append(errs, err)
mu.Unlock()
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
return
}
mu.Lock()
success = true
mu.Unlock()
}(upstream)
}
wg.Wait()
select {
case <-ctx.Done():
return
case <-u.ctx.Done():
return
default:
}
// didn't find a working upstream server, let's disable and try later
if !success {
u.disable(errs.ErrorOrNil())
if u.statusRecorder == nil {
return
}
u.statusRecorder.PublishEvent(
proto.SystemEvent_WARNING,
proto.SystemEvent_DNS,
"All upstream servers failed (probe failed)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": u.upstreamServersString()},
)
}
}
// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
func (u *upstreamResolverBase) waitUntilResponse() {
exponentialBackOff := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.1,
MaxInterval: u.reactivatePeriod,
MaxElapsedTime: 0,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}
operation := func() error {
select {
case <-u.ctx.Done():
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
default:
}
for _, upstream := range u.upstreamServers {
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err)
} else {
// at least one upstream server is available, stop probing
return nil
}
}
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
return fmt.Errorf("upstream check call error")
}
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
if err != nil {
if errors.Is(err, context.Canceled) {
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
} else {
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
}
return
}
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.successCount.Add(1)
u.reactivate()
u.mutex.Lock()
u.disabled = false
u.mutex.Unlock()
}
// isTimeout returns true if the given error is a network timeout error.
//
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
@@ -558,45 +644,6 @@ func isTimeout(err error) bool {
return false
}
func (u *upstreamResolverBase) disable(err error) {
if u.disabled {
return
}
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.successCount.Store(0)
u.deactivate(err)
u.disabled = true
u.wg.Add(1)
go func() {
defer u.wg.Done()
u.waitUntilResponse()
}()
}
func (u *upstreamResolverBase) upstreamServersString() string {
var servers []string
for _, server := range u.upstreamServers {
servers = append(servers, server.String())
}
return strings.Join(servers, ", ")
}
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
defer cancel()
if externalCtx != nil {
stop2 := context.AfterFunc(externalCtx, cancel)
defer stop2()
}
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
return err
}
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
func clientUDPMaxSize(r *dns.Msg) int {
if opt := r.IsEdns0(); opt != nil {
@@ -608,13 +655,10 @@ func clientUDPMaxSize(r *dns.Msg) int {
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// If the request came in over TCP, go straight to TCP upstream.
if dnsProtocolFromContext(ctx) == protoTCP {
tcpClient := *client
tcpClient.Net = protoTCP
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -634,18 +678,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
opt.SetUDPSize(maxUDPPayload)
}
var (
rm *dns.Msg
t time.Duration
err error
)
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
rm, t, err := client.ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with udp: %w", err)
}
@@ -659,15 +692,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
// data than the client's buffer, we could truncate locally and skip
// the TCP retry.
tcpClient := *client
tcpClient.Net = protoTCP
if ctx == nil {
rm, t, err = tcpClient.Exchange(r, upstream)
} else {
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
}
rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -681,6 +706,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a
// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on
// the tunnel interface), it is converted to the equivalent *net.TCPAddr
// so net.Dialer doesn't reject the TCP dial with "mismatched local
// address type".
func toTCPClient(c *dns.Client) *dns.Client {
tcp := *c
tcp.Net = protoTCP
if tcp.Dialer == nil {
return &tcp
}
d := *tcp.Dialer
if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
}
tcp.Dialer = &d
return &tcp
}
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
@@ -822,15 +866,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
return bestMatch
}
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
// haMapRouteCount returns the total number of routes across all HA
// groups in the map. route.HAMap is keyed by HAUniqueID with slices of
// routes per key, so len(hm) is the number of HA groups, not routes.
func haMapRouteCount(hm route.HAMap) int {
total := 0
for _, routes := range hm {
total += len(routes)
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
return total
}
// haMapContains checks whether ip is covered by any concrete prefix in
// the HA map. haveDynamic is reported separately: dynamic (domain-based)
// routes carry a placeholder Network that can't be prefix-checked, so we
// can't know at this point whether ip is reached through one. Callers
// decide how to interpret the unknown: health projection treats it as
// "possibly routed" to avoid emitting false-positive warnings during
// startup, while iOS dial selection requires a concrete match before
// binding to the tunnel.
func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) {
for _, routes := range hm {
for _, r := range routes {
if r.IsDynamic() {
haveDynamic = true
continue
}
if r.Network.Contains(ip) {
return true, haveDynamic
}
}
}
return false, haveDynamic
}

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolver struct {
@@ -26,9 +27,9 @@ func newUpstreamResolver(
_ WGIface,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
c := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
hostsDNSHolder: hostsDNSHolder,

@@ -12,6 +12,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolver struct {
@@ -24,9 +25,9 @@ func newUpstreamResolver(
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
nsNet: wgIface.GetNet(),

@@ -15,6 +15,7 @@ import (
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolverIOS struct {
@@ -27,9 +28,9 @@ func newUpstreamResolver(
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
@@ -62,9 +63,16 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP = upstreamIP.Unmap()
}
addr := u.wgIface.Address()
var routed bool
if u.selectedRoutes != nil {
// Only a concrete prefix match binds to the tunnel: dialing
// through a private client for an upstream we can't prove is
// routed would break public resolvers.
routed, _ = haMapContains(u.selectedRoutes(), upstreamIP)
}
needsPrivate := addr.Network.Contains(upstreamIP) ||
addr.IPv6Net.Contains(upstreamIP) ||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
routed
if needsPrivate {
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
@@ -73,8 +81,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
}
}
// Cannot use client.ExchangeContext because it overwrites our Dialer
return ExchangeWithFallback(nil, client, r, upstream)
return ExchangeWithFallback(ctx, client, r, upstream)
}
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"strings"
"sync/atomic"
"testing"
"time"
@@ -73,7 +74,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
}
}
resolver.upstreamServers = servers
resolver.addRace(servers)
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {
cancel()
@@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
}
type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration
err error
}
// exchange mock implementation of exchange from upstreamResolver
func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
return c.r, c.rtt, c.err
}
type mockUpstreamResponse struct {
msg *dns.Msg
err error
msg *dns.Msg
err error
delay time.Duration
}
type mockUpstreamResolverPerServer struct {
@@ -153,63 +144,19 @@ type mockUpstreamResolverPerServer struct {
rtt time.Duration
}
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
if r, ok := c.responses[upstream]; ok {
return r.msg, c.rtt, r.err
func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
r, ok := c.responses[upstream]
if !ok {
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
mockClient := &mockUpstreamResolver{
err: dns.ErrTime,
r: new(dns.Msg),
rtt: time.Millisecond,
}
resolver := &upstreamResolverBase{
ctx: context.TODO(),
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: time.Microsecond * 100,
}
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
failed := false
resolver.deactivate = func(error) {
failed = true
// After deactivation, make the mock client work again
mockClient.err = nil
}
reactivated := false
resolver.reactivate = func() {
reactivated = true
}
resolver.ProbeAvailability(context.TODO())
if !failed {
t.Errorf("expected that resolving was deactivated")
return
}
if !resolver.disabled {
t.Errorf("resolver should be Disabled")
return
}
time.Sleep(time.Millisecond * 200)
if !reactivated {
t.Errorf("expected that resolving was reactivated")
return
}
if resolver.disabled {
t.Errorf("should be enabled")
if r.delay > 0 {
select {
case <-time.After(r.delay):
case <-ctx.Done():
return nil, c.rtt, ctx.Err()
}
}
return r.msg, c.rtt, r.err
}
func TestUpstreamResolver_Failover(t *testing.T) {
@@ -339,9 +286,9 @@ func TestUpstreamResolver_Failover(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: trackingClient,
upstreamServers: []netip.AddrPort{upstream1, upstream2},
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{upstream1, upstream2})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
@@ -421,9 +368,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamServers: []netip.AddrPort{upstream},
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{upstream})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
@@ -440,6 +387,136 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
}
// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups
// configured for the same domain, with one broken group. The merge+race
// path should answer as fast as the working group and not pay the timeout
// of the broken one on every query.
func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
broken := netip.MustParseAddrPort("192.0.2.1:53")
working := netip.MustParseAddrPort("192.0.2.2:53")
successAnswer := "192.0.2.100"
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
// Force the broken upstream to only unblock via timeout /
// cancellation so the assertion below can't pass if races
// were run serially.
broken.String(): {err: timeoutErr, delay: 500 * time.Millisecond},
working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: 250 * time.Millisecond,
}
resolver.addRace([]netip.AddrPort{broken})
resolver.addRace([]netip.AddrPort{working})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
start := time.Now()
resolver.ServeDNS(responseWriter, inputMSG)
elapsed := time.Since(start)
require.NotNil(t, responseMSG, "should write a response")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
require.NotEmpty(t, responseMSG.Answer)
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
// Working group answers in a single RTT; the broken group's
// timeout (250ms) must not block the response.
assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout")
}
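The merge+race behaviour that the test comment describes can be sketched in isolation. This is a minimal illustration, not the resolver's actual implementation: `raceGroups`, `queryFunc`, and `runDemo` are hypothetical names, and the assumption is that each group runs in its own goroutine while upstreams inside one group are tried in order.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// queryFunc is a hypothetical stand-in for the upstream client's exchange call.
type queryFunc func(ctx context.Context, upstream string) (string, error)

// raceGroups queries every group concurrently and returns the first
// successful answer. Upstreams inside one group are tried sequentially.
func raceGroups(ctx context.Context, groups [][]string, timeout time.Duration, query queryFunc) (string, error) {
	if len(groups) == 0 {
		return "", errors.New("no upstream groups configured")
	}
	ctx, cancel := context.WithCancel(ctx)
	defer cancel() // cancels the losing groups once a winner is returned

	type result struct {
		answer string
		err    error
	}
	results := make(chan result, len(groups)) // buffered so losers never block

	for _, group := range groups {
		go func(group []string) {
			lastErr := errors.New("empty group")
			for _, upstream := range group {
				qctx, qcancel := context.WithTimeout(ctx, timeout)
				answer, err := query(qctx, upstream)
				qcancel()
				if err == nil {
					results <- result{answer: answer}
					return
				}
				lastErr = err
			}
			results <- result{err: lastErr}
		}(group)
	}

	var errs []error
	for range groups {
		r := <-results
		if r.err == nil {
			return r.answer, nil
		}
		errs = append(errs, r.err)
	}
	return "", fmt.Errorf("all groups failed: %w", errors.Join(errs...))
}

// runDemo mirrors the test scenario: one slow broken group, one working group.
// It returns the answer and the elapsed time in milliseconds.
func runDemo() (string, int64) {
	query := func(ctx context.Context, upstream string) (string, error) {
		if upstream == "192.0.2.1:53" {
			select {
			case <-time.After(500 * time.Millisecond):
			case <-ctx.Done():
			}
			return "", errors.New("i/o timeout")
		}
		return "192.0.2.100", nil
	}
	start := time.Now()
	groups := [][]string{{"192.0.2.1:53"}, {"192.0.2.2:53"}}
	answer, _ := raceGroups(context.Background(), groups, 250*time.Millisecond, query)
	return answer, time.Since(start).Milliseconds()
}

func main() {
	answer, ms := runDemo()
	fmt.Printf("answer=%s elapsed=%dms\n", answer, ms)
}
```

Because the result channel is buffered to the number of groups, a group that loses the race can still deliver its result and exit without leaking a goroutine.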
// TestUpstreamResolver_AllGroupsFail checks that when every group fails the
// resolver returns SERVFAIL rather than leaking a partial response.
func TestUpstreamResolver_AllGroupsFail(t *testing.T) {
a := netip.MustParseAddrPort("192.0.2.1:53")
b := netip.MustParseAddrPort("192.0.2.2:53")
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{a})
resolver.addRace([]netip.AddrPort{b})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
require.NotNil(t, responseMSG)
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode)
}
// TestUpstreamResolver_HealthTracking verifies that query-path results are
// recorded into per-upstream health, which is what projects back to
// NSGroupState for status reporting.
func TestUpstreamResolver_HealthTracking(t *testing.T) {
ok := netip.MustParseAddrPort("192.0.2.10:53")
bad := netip.MustParseAddrPort("192.0.2.11:53")
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
ok.String(): {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")},
bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{ok, bad})
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
health := resolver.UpstreamHealth()
require.Contains(t, health, ok)
assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set")
assert.Empty(t, health[ok].LastErr)
// bad upstream was never tried because ok answered first; its health
// should remain unset.
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
}
func TestFormatFailures(t *testing.T) {
testCases := []struct {
name string
@@ -665,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
// Verify that a client EDNS0 larger than our MTU-derived limit gets
// capped in the outgoing request so the upstream doesn't send a
// response larger than our read buffer.
var receivedUDPSize uint16
var receivedUDPSize atomic.Uint32
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
if opt := r.IsEdns0(); opt != nil {
receivedUDPSize = opt.UDPSize()
receivedUDPSize.Store(uint32(opt.UDPSize()))
}
m := new(dns.Msg)
m.SetReply(r)
@@ -699,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
require.NotNil(t, rm)
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
assert.Equal(t, expectedMax, receivedUDPSize,
assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()),
"upstream should see capped EDNS0, not the client's 4096")
}
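The capping rule this test checks (advertise at most MTU minus header overhead, never the client's larger value) can be shown as a small pure function. This is a sketch under assumptions: `capUDPSize` is a hypothetical helper, the 68-byte overhead is an assumed value for `ipUDPHeaderSize` (its definition is outside this hunk), and the 512-byte fallback for tiny MTUs is added here only to keep the example safe.

```go
package main

import "fmt"

// ipUDPHeaderSize is an ASSUMED overhead (60-byte maximum IPv4 header plus
// 8-byte UDP header). The real constant is defined outside this diff.
const ipUDPHeaderSize = 60 + 8

// capUDPSize returns the EDNS0 UDP payload size to advertise upstream:
// the client's requested size, capped at what fits in one packet.
func capUDPSize(clientSize, mtu uint16) uint16 {
	if mtu <= ipUDPHeaderSize {
		return 512 // assumed fallback: classic DNS-over-UDP minimum
	}
	limit := mtu - ipUDPHeaderSize
	if clientSize > limit {
		return limit
	}
	return clientSize
}

func main() {
	fmt.Println(capUDPSize(4096, 1280)) // client asks for 4096, MTU is 1280
	fmt.Println(capUDPSize(1000, 1280)) // already below the limit
}
```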
@@ -874,7 +951,7 @@ func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: tracking,
upstreamServers: []netip.AddrPort{upstream1, upstream2},
upstreamServers: []upstreamRace{{upstream1, upstream2}},
upstreamTimeout: UpstreamTimeout,
}


@@ -35,6 +35,7 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/approval"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
@@ -123,6 +124,8 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
DisableVNCApproval *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -204,7 +207,9 @@ type Engine struct {
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
sshServer sshServer
vncSrv vncServer
approvalBroker *approval.Broker
statusRecorder *peer.Status
@@ -285,6 +290,7 @@ func NewEngine(
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
approvalBroker: approval.New(services.StatusRecorder),
stateManager: services.StateManager,
portForwardManager: portforward.NewManager(),
checks: services.Checks,
@@ -320,6 +326,10 @@ func (e *Engine) Stop() error {
log.Warnf("failed to stop SSH server: %v", err)
}
if err := e.stopVNCServer(); err != nil {
log.Warnf("failed to stop VNC server: %v", err)
}
e.cleanupSSHConfig()
if e.ingressGatewayMgr != nil {
@@ -512,16 +522,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
for _, r := range routes {
if r.Network.Contains(ip) {
return true
}
}
}
return false
})
e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes)
if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
@@ -1019,6 +1020,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1066,6 +1068,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
}
if err := e.updateVNC(); err != nil {
log.Warnf("failed handling VNC server setup: %v", err)
}
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.IPv6 = e.wgInterface.Address().IPv6String()
@@ -1191,6 +1197,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1380,15 +1387,17 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// VNC auth: always sync, including nil so cleared auth on the management
// side is applied locally, and so it isn't skipped on the RemotePeersIsEmpty
// cleanup path.
e.updateVNCServerAuth(networkMap.GetVncAuth())
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
e.networkSerial = serial
// Test received (upstream) servers for availability right away instead of upon usage.
// If no server of a server group responds this will disable the respective handler and retry later.
go e.dnsServer.ProbeAvailability()
return nil
}
@@ -1838,6 +1847,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1932,7 +1942,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
return dnsServer, nil
default:
@@ -2602,3 +2612,16 @@ func decodeRelayIP(b []byte) netip.Addr {
}
return ip.Unmap()
}
// RespondApproval relays the user's decision for a pending approval to
// the broker. viewOnly is honoured only when accept is true. Returns
// true when the request_id matched a live prompt.
func (e *Engine) RespondApproval(requestID string, accept, viewOnly bool) bool {
if e == nil || e.approvalBroker == nil {
return false
}
return e.approvalBroker.Respond(requestID, approval.Decision{
Accept: accept,
ViewOnly: accept && viewOnly,
})
}
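How a request/response broker of this shape might pair a blocked `Request` with a later `Respond` can be sketched as follows. The `approval` package itself is not part of this excerpt, so `Broker`, `Decision`, and `demo` below are hypothetical stand-ins. Only the contract stated in the comment above (view-only honoured only on accept, `false` for an unknown request ID) is taken from the diff.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"
)

// Decision is the user's answer to one prompt.
type Decision struct {
	Accept   bool
	ViewOnly bool
}

// Broker pairs pending prompts with responses by request ID.
type Broker struct {
	mu      sync.Mutex
	pending map[string]chan Decision
}

func NewBroker() *Broker {
	return &Broker{pending: make(map[string]chan Decision)}
}

// Request registers a prompt under id and blocks until Respond is called
// for that id or ctx is done.
func (b *Broker) Request(ctx context.Context, id string) (Decision, error) {
	ch := make(chan Decision, 1)
	b.mu.Lock()
	b.pending[id] = ch
	b.mu.Unlock()
	defer func() {
		b.mu.Lock()
		delete(b.pending, id)
		b.mu.Unlock()
	}()

	select {
	case d := <-ch:
		if !d.Accept {
			return Decision{}, errors.New("denied by user")
		}
		return d, nil
	case <-ctx.Done():
		return Decision{}, ctx.Err()
	}
}

// Respond delivers a decision. It returns false when id has no live prompt.
func (b *Broker) Respond(id string, accept, viewOnly bool) bool {
	b.mu.Lock()
	ch, ok := b.pending[id]
	delete(b.pending, id) // at most one response per prompt
	b.mu.Unlock()
	if !ok {
		return false
	}
	ch <- Decision{Accept: accept, ViewOnly: accept && viewOnly}
	return true
}

// demo answers one prompt, then tries to answer it a second time.
func demo() (Decision, bool) {
	b := NewBroker()
	got := make(chan Decision, 1)
	go func() {
		d, _ := b.Request(context.Background(), "req-1")
		got <- d
	}()
	// Retry until the prompt is registered.
	for !b.Respond("req-1", true, true) {
		time.Sleep(time.Millisecond)
	}
	d := <-got
	stale := b.Respond("req-1", true, false)
	return d, stale
}

func main() {
	d, stale := demo()
	fmt.Printf("accept=%v viewOnly=%v staleMatched=%v\n", d.Accept, d.ViewOnly, stale)
}
```

Deleting the entry inside `Respond` while holding the lock is what guarantees a single sender per channel, so the one-slot buffer is enough and `Respond` never blocks.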


@@ -12,7 +12,7 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
@@ -237,22 +237,18 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
return errors.New("wg interface not initialized")
}
wgAddr := e.wgInterface.Address()
serverConfig := &sshserver.Config{
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
NetstackNet: e.wgInterface.GetNet(),
NetworkValidation: wgAddr,
}
server := sshserver.New(serverConfig)
wgAddr := e.wgInterface.Address()
server.SetNetworkValidation(wgAddr)
netbirdIP := wgAddr.IP
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
server.SetNetstackNet(netstackNet)
}
e.configureSSHServer(server)
if err := server.Start(e.ctx, listenAddr); err != nil {


@@ -0,0 +1,303 @@
//go:build !js && !ios && !android
package internal
import (
"context"
"errors"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/approval"
"github.com/netbirdio/netbird/client/internal/metrics"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/vnc"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
type vncServer interface {
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
Stop() error
ActiveSessions() []vncserver.ActiveSessionInfo
}
func (e *Engine) setupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
return fmt.Errorf("add VNC port redirection: %w", err)
}
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vnc.ExternalPort, localAddr, vnc.InternalPort)
return nil
}
func (e *Engine) cleanupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
return fmt.Errorf("remove VNC port redirection: %w", err)
}
return nil
}
// updateVNC handles starting/stopping the VNC server based on the config flag.
func (e *Engine) updateVNC() error {
if !e.config.ServerVNCAllowed {
if e.vncSrv != nil {
log.Info("VNC server disabled, stopping")
}
return e.stopVNCServer()
}
if e.config.BlockInbound {
log.Info("VNC server disabled because inbound connections are blocked")
return e.stopVNCServer()
}
if e.vncSrv != nil {
return nil
}
return e.startVNCServer()
}
func (e *Engine) startVNCServer() error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
capturer, injector, ok := newPlatformVNC()
if !ok {
log.Debug("VNC server not supported on this platform")
return nil
}
netbirdIP := e.wgInterface.Address().IP
var sessionRecorder func(vncserver.SessionTick)
if e.clientMetrics != nil {
sessionRecorder = func(t vncserver.SessionTick) {
e.clientMetrics.RecordVNCSessionTick(e.ctx, metrics.VNCSessionTick{
Period: t.Period,
BytesOut: t.BytesOut,
Writes: t.Writes,
FBUs: t.FBUs,
MaxFBUBytes: t.MaxFBUBytes,
MaxFBURects: t.MaxFBURects,
MaxWriteBytes: t.MaxWriteBytes,
WriteNanos: t.WriteNanos,
})
}
}
serviceMode := vncNeedsServiceMode()
if serviceMode {
log.Info("VNC: running as system service, enabling service mode (per-session agent proxy)")
}
requireApproval := e.config.DisableVNCApproval == nil || !*e.config.DisableVNCApproval
srv := vncserver.New(vncserver.Config{
Capturer: capturer,
Injector: injector,
IdentityKey: e.config.WgPrivateKey[:],
ServiceMode: serviceMode,
SessionRecorder: sessionRecorder,
NetstackNet: e.wgInterface.GetNet(),
RequireApproval: requireApproval,
Approver: &vncApprover{broker: e.approvalBroker, statusRecorder: e.statusRecorder},
})
listenAddr := netip.AddrPortFrom(netbirdIP, vnc.InternalPort)
network := e.wgInterface.Address().Network
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
return fmt.Errorf("start VNC server: %w", err)
}
e.vncSrv = srv
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, vnc.InternalPort)
log.Debugf("registered VNC service with netstack for TCP:%d", vnc.InternalPort)
}
}
if err := e.setupVNCPortRedirection(); err != nil {
log.Warnf("setup VNC port redirection: %v", err)
}
log.Info("VNC server enabled")
return nil
}
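Two small decisions carry most of the logic in `updateVNC` and `startVNCServer`: whether to start, stop, or leave the server alone, and how an unset `DisableVNCApproval` is interpreted. The sketch below restates both as pure functions so the truth table is easy to check. The names `nextAction` and `requireApproval` are hypothetical, not part of the change.

```go
package main

import "fmt"

type action int

const (
	actionNone action = iota
	actionStart
	actionStop
)

// nextAction mirrors the branches of updateVNC: the server runs only when
// VNC is allowed and inbound connections are not blocked.
func nextAction(allowed, blockInbound, running bool) action {
	if !allowed || blockInbound {
		if running {
			return actionStop
		}
		return actionNone
	}
	if running {
		return actionNone
	}
	return actionStart
}

// requireApproval treats an unset flag (nil) as "approval required", so the
// prompt is only skipped when the user explicitly disabled it.
func requireApproval(disable *bool) bool {
	return disable == nil || !*disable
}

func main() {
	yes, no := true, false
	fmt.Println(nextAction(true, false, false) == actionStart)
	fmt.Println(requireApproval(nil), requireApproval(&no), requireApproval(&yes))
}
```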
// updateVNCServerAuth updates VNC fine-grained access control from management.
// A nil vncAuth clears all authorized users and session pubkeys so management
// can revoke access by omitting the field on the next sync.
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
if e.vncSrv == nil {
return
}
vncSrv, ok := e.vncSrv.(*vncserver.Server)
if !ok {
return
}
if vncAuth == nil {
vncSrv.UpdateVNCAuth(&sshauth.Config{})
return
}
protoUsers := vncAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range vncAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
sessionPubKeys := make([]sshauth.SessionPubKey, 0, len(vncAuth.GetSessionPubKeys()))
// entry (not e) avoids shadowing the *Engine receiver inside the loop.
for _, entry := range vncAuth.GetSessionPubKeys() {
pub := entry.GetPubKey()
if len(pub) != 32 {
log.Warnf("VNC session pubkey wrong length %d", len(pub))
continue
}
hash := entry.GetUserIdHash()
if len(hash) != 16 {
log.Warnf("VNC session user id hash wrong length %d", len(hash))
continue
}
sessionPubKeys = append(sessionPubKeys, sshauth.SessionPubKey{
PubKey: pub,
UserIDHash: sshuserhash.UserIDHash(hash),
DisplayName: entry.GetDisplayName(),
})
}
vncSrv.UpdateVNCAuth(&sshauth.Config{
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
SessionPubKeys: sessionPubKeys,
})
}
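The length checks on session pubkeys (32 bytes) and user ID hashes (16 bytes) can be exercised on their own. The following is a simplified sketch: `sessionEntry`, `sessionKey`, and `filterSessionKeys` are hypothetical types standing in for the protobuf and `sessionauth` types, which are not shown in this excerpt.

```go
package main

import "fmt"

const (
	pubKeyLen     = 32 // length checked for session pubkeys in the diff
	userIDHashLen = 16 // length checked for user ID hashes in the diff
)

// sessionEntry stands in for one pushed session pubkey entry.
type sessionEntry struct {
	PubKey      []byte
	UserIDHash  []byte
	DisplayName string
}

// sessionKey is the validated form with a fixed-size hash.
type sessionKey struct {
	PubKey      []byte
	UserIDHash  [userIDHashLen]byte
	DisplayName string
}

// filterSessionKeys keeps well-formed entries and skips malformed ones,
// so one bad entry does not discard the rest of the push.
func filterSessionKeys(entries []sessionEntry) []sessionKey {
	out := make([]sessionKey, 0, len(entries))
	for _, entry := range entries {
		if len(entry.PubKey) != pubKeyLen || len(entry.UserIDHash) != userIDHashLen {
			continue
		}
		var hash [userIDHashLen]byte
		copy(hash[:], entry.UserIDHash)
		out = append(out, sessionKey{PubKey: entry.PubKey, UserIDHash: hash, DisplayName: entry.DisplayName})
	}
	return out
}

// demoFilter feeds one valid and two malformed entries through the filter.
func demoFilter() []sessionKey {
	return filterSessionKeys([]sessionEntry{
		{PubKey: make([]byte, 32), UserIDHash: make([]byte, 16), DisplayName: "alice"},
		{PubKey: make([]byte, 31), UserIDHash: make([]byte, 16), DisplayName: "short key"},
		{PubKey: make([]byte, 32), UserIDHash: make([]byte, 8), DisplayName: "short hash"},
	})
}

func main() {
	for _, k := range demoFilter() {
		fmt.Println("accepted:", k.DisplayName)
	}
}
```

Note the asymmetry in the diff itself: a malformed entry in the authorized-users list aborts the whole update, while a malformed session pubkey is only skipped.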
// GetVNCServerStatus returns whether the VNC server is running and the list
// of active VNC sessions. The pointer is captured under syncMsgMux so a
// concurrent updateVNC/stopVNCServer cannot swap it out between the nil
// check and the ActiveSessions call.
func (e *Engine) GetVNCServerStatus() (enabled bool, sessions []vncserver.ActiveSessionInfo) {
e.syncMsgMux.Lock()
vncSrv := e.vncSrv
e.syncMsgMux.Unlock()
if vncSrv == nil {
return false, nil
}
return true, vncSrv.ActiveSessions()
}
func (e *Engine) stopVNCServer() error {
if e.vncSrv == nil {
return nil
}
if err := e.cleanupVNCPortRedirection(); err != nil {
log.Warnf("cleanup VNC port redirection: %v", err)
}
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, vnc.InternalPort)
}
}
log.Info("stopping VNC server")
err := e.vncSrv.Stop()
e.vncSrv = nil
if err != nil {
return fmt.Errorf("stop VNC server: %w", err)
}
return nil
}
// vncApprover adapts the generic approval.Broker for the VNC server.
type vncApprover struct {
broker *approval.Broker
statusRecorder *peer.Status
}
func (a *vncApprover) Request(ctx context.Context, info vncserver.ApprovalInfo) (vncserver.ApprovalDecision, error) {
// Resolve the source overlay IP to a peer FQDN for the prompt label.
if info.PeerName == "" && info.SourceIP != "" && a.statusRecorder != nil {
if fqdn, ok := a.statusRecorder.PeerByIP(info.SourceIP); ok {
info.PeerName = fqdn
}
}
subject := fmt.Sprintf("VNC connection from %s", displayPeer(info))
meta := map[string]string{
"peer_name": info.PeerName,
"peer_pubkey": info.PeerPubKey,
"source_ip": info.SourceIP,
"mode": info.Mode,
"username": info.Username,
"initiator": info.Initiator,
}
d, err := a.broker.Request(ctx, approval.Prompt{
Kind: approval.KindVNC,
Subject: subject,
Metadata: meta,
})
if err != nil {
return vncserver.ApprovalDecision{}, err
}
return vncserver.ApprovalDecision{ViewOnly: d.ViewOnly}, nil
}
func displayPeer(info vncserver.ApprovalInfo) string {
if info.Initiator != "" {
return info.Initiator
}
if info.PeerName != "" {
return info.PeerName
}
if info.SourceIP != "" {
return info.SourceIP
}
if info.PeerPubKey != "" {
return info.PeerPubKey
}
return "unknown peer"
}


@@ -0,0 +1,31 @@
//go:build freebsd
package internal
import (
"fmt"
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
// newConsoleVNC builds the FreeBSD console fallback: vt(4) framebuffer
// for capture, /dev/uinput for input. The uinput device requires the
// `uinput` kernel module (`kldload uinput`); without it, input init
// fails and we drop to a stub injector so the user still gets a
// view-only screen mirror.
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
poller := vncserver.NewFBPoller("")
w, h := poller.Width(), poller.Height()
if w == 0 || h == 0 {
poller.Close()
return nil, nil, fmt.Errorf("vt framebuffer init failed (vt may not allow mmap on this driver)")
}
inj, err := vncserver.NewUInputInjector(w, h)
if err != nil {
log.Infof("VNC console: uinput unavailable (%v); view-only mode. Run `kldload uinput` to enable input.", err)
return poller, &vncserver.StubInputInjector{}, nil
}
return poller, inj, nil
}


@@ -0,0 +1,30 @@
//go:build linux && !android
package internal
import (
"fmt"
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
// newConsoleVNC builds a framebuffer + uinput VNC backend for boxes
// without a running X server. Used as the auto-fallback when
// newPlatformVNC can't reach X. Returns an error when /dev/fb0 or
// /dev/uinput aren't usable so the caller can drop back to a stub.
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
poller := vncserver.NewFBPoller("")
w, h := poller.Width(), poller.Height()
if w == 0 || h == 0 {
poller.Close()
return nil, nil, fmt.Errorf("framebuffer capturer init failed (is /dev/fb0 readable?)")
}
inj, err := vncserver.NewUInputInjector(w, h)
if err != nil {
log.Debugf("uinput unavailable, falling back to view-only VNC: %v", err)
return poller, &vncserver.StubInputInjector{}, nil
}
return poller, inj, nil
}


@@ -0,0 +1,34 @@
//go:build darwin && !ios
package internal
import (
"os"
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
capturer := vncserver.NewMacPoller()
// Prompt for Screen Recording at server-enable time rather than first
// client-connect. The native prompt is far easier for users to act on
// in the moment they toggled VNC on than later when "the screen looks
// like wallpaper" would otherwise be the only clue.
vncserver.PrimeScreenCapturePermission()
injector, err := vncserver.NewMacInputInjector()
if err != nil {
log.Debugf("VNC: macOS input injector: %v", err)
return capturer, &vncserver.StubInputInjector{}, true
}
return capturer, injector, true
}
// vncNeedsServiceMode reports whether the running process is a system
// LaunchDaemon (root, parented by launchd). Daemons sit in the global
// bootstrap namespace and cannot talk to WindowServer; we route capture
// through a per-user agent in that case.
func vncNeedsServiceMode() bool {
return os.Geteuid() == 0 && os.Getppid() == 1
}


@@ -0,0 +1,23 @@
//go:build js || ios || android
package internal
import (
log "github.com/sirupsen/logrus"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
type vncServer interface{}
func (e *Engine) updateVNC() error { return nil }
func (e *Engine) updateVNCServerAuth(auth *mgmProto.VNCAuth) {
if auth == nil {
return
}
log.Debugf("ignoring VNC auth push on platform without a VNC server: %d session pubkeys, %d authorized users",
len(auth.GetSessionPubKeys()), len(auth.GetAuthorizedUsers()))
}
func (e *Engine) stopVNCServer() error { return nil }


@@ -0,0 +1,13 @@
//go:build windows
package internal
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), true
}
func vncNeedsServiceMode() bool {
return vncserver.GetCurrentSessionID() == 0
}


@@ -0,0 +1,35 @@
//go:build (linux && !android) || freebsd
package internal
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
// Prefer X11 when an X server is reachable. NewX11InputInjector probes
// DISPLAY (and /proc) eagerly, so a non-nil error here means no X.
injector, err := vncserver.NewX11InputInjector("")
if err == nil {
return vncserver.NewX11Poller(""), injector, true
}
log.Debugf("VNC: X11 not available: %v", err)
// Fallback for headless / pre-X states (kernel console, login manager
// without X, physical server in recovery): stream the framebuffer and
// inject input via /dev/uinput.
consoleCap, consoleInj, err := newConsoleVNC()
if err == nil {
log.Infof("VNC: using framebuffer console capture (%dx%d)", consoleCap.Width(), consoleCap.Height())
return consoleCap, consoleInj, true
}
log.Debugf("VNC: framebuffer console fallback unavailable: %v", err)
return &vncserver.StubCapturer{}, &vncserver.StubInputInjector{}, false
}
func vncNeedsServiceMode() bool {
return false
}


@@ -120,6 +120,36 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
m.trimLocked()
}
func (m *influxDBMetrics) RecordVNCSessionTick(_ context.Context, agentInfo AgentInfo, tick VNCSessionTick) {
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s",
agentInfo.DeploymentType.String(),
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
)
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_vnc_traffic",
tags: tags,
fields: map[string]float64{
"period_seconds": tick.Period.Seconds(),
"bytes_out": float64(tick.BytesOut),
"writes": float64(tick.Writes),
"fbus": float64(tick.FBUs),
"max_fbu_bytes": float64(tick.MaxFBUBytes),
"max_fbu_rects": float64(tick.MaxFBURects),
"max_write_bytes": float64(tick.MaxWriteBytes),
"write_time_seconds": float64(tick.WriteNanos) / 1e9,
},
timestamp: time.Now(),
})
m.trimLocked()
}
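For reference, a sample like the one recorded above would typically be serialized as one InfluxDB line protocol record (`measurement,tag_set field_set timestamp`). The `Export` implementation is not part of this hunk, so the formatter below is only a sketch with hypothetical names (`formatLine`, `demoLine`). It sorts field keys for deterministic output and leaves out tag and field escaping.

```go
package main

import (
	"fmt"
	"sort"
	"strconv"
	"strings"
	"time"
)

// formatLine renders one sample as an InfluxDB line protocol record.
// Field keys are sorted because Go map iteration order is random.
// Escaping of special characters is omitted for brevity.
func formatLine(measurement, tags string, fields map[string]float64, ts time.Time) string {
	keys := make([]string, 0, len(fields))
	for k := range fields {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	parts := make([]string, 0, len(keys))
	for _, k := range keys {
		parts = append(parts, k+"="+strconv.FormatFloat(fields[k], 'f', -1, 64))
	}
	return fmt.Sprintf("%s,%s %s %d", measurement, tags, strings.Join(parts, ","), ts.UnixNano())
}

// demoLine formats a reduced VNC traffic sample with a fixed timestamp.
func demoLine() string {
	var writeNanos uint64 = 250_000_000
	fields := map[string]float64{
		"period_seconds":     1.5,
		"bytes_out":          1024,
		"write_time_seconds": float64(writeNanos) / 1e9, // nanoseconds to seconds
	}
	return formatLine("netbird_vnc_traffic", "os=linux,arch=amd64", fields, time.Unix(0, 42))
}

func main() {
	fmt.Println(demoLine())
}
```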
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
result := "success"
if !success {


@@ -59,6 +59,11 @@ type metricsImplementation interface {
// RecordLoginDuration records how long the login to management took
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
// RecordVNCSessionTick records a periodic snapshot of one VNC
// session's wire activity. Called once per metricsConn tick interval
// (and once at session close), only when the tick saw activity.
RecordVNCSessionTick(ctx context.Context, agentInfo AgentInfo, tick VNCSessionTick)
// Export exports metrics in InfluxDB line protocol format
Export(w io.Writer) error
@@ -78,6 +83,21 @@ type ClientMetrics struct {
pushCancel context.CancelFunc
}
// VNCSessionTick is one sampling slice of a VNC session's wire activity.
// BytesOut / Writes / FBUs / WriteNanos are deltas observed during this
// tick; Max* fields are the high-water marks observed during the tick.
// Period is the wall-clock duration the deltas cover.
type VNCSessionTick struct {
Period time.Duration
BytesOut uint64
Writes uint64
FBUs uint64
MaxFBUBytes uint64
MaxFBURects uint64
MaxWriteBytes uint64
WriteNanos uint64
}
// ConnectionStageTimestamps holds timestamps for each connection stage
type ConnectionStageTimestamps struct {
SignalingReceived time.Time // First signal received from remote peer (both initial and reconnection)
@@ -127,6 +147,17 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
}
// RecordVNCSessionTick records a periodic snapshot of one VNC session.
func (c *ClientMetrics) RecordVNCSessionTick(ctx context.Context, tick VNCSessionTick) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
c.impl.RecordVNCSessionTick(ctx, agentInfo, tick)
}
// RecordLoginDuration records how long the login to management server took
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
if c == nil {


@@ -73,6 +73,9 @@ func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
}
func (m *mockMetrics) RecordVNCSessionTick(_ context.Context, _ AgentInfo, _ VNCSessionTick) {
}
func (m *mockMetrics) Export(w io.Writer) error {
if m.exportData != "" {
_, err := w.Write([]byte(m.exportData))


@@ -1191,6 +1191,15 @@ func (d *Status) SubscribeToEvents() *EventSubscription {
}
}
// HasEventSubscribers reports whether any client is currently subscribed
// to the daemon's SystemEvent stream. Used by the VNC approval broker to
// fail closed when no UI is connected to prompt the user.
func (d *Status) HasEventSubscribers() bool {
d.eventMux.Lock()
defer d.eventMux.Unlock()
return len(d.eventStreams) > 0
}
// UnsubscribeFromEvents removes an event subscription
func (d *Status) UnsubscribeFromEvents(sub *EventSubscription) {
if sub == nil {


@@ -65,6 +65,8 @@ type ConfigInput struct {
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
ServerVNCAllowed *bool
DisableVNCApproval *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -116,6 +118,8 @@ type Config struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
ServerVNCAllowed *bool
DisableVNCApproval *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -418,6 +422,33 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.ServerVNCAllowed != nil {
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
if *input.ServerVNCAllowed {
log.Infof("enabling VNC server")
} else {
log.Infof("disabling VNC server")
}
config.ServerVNCAllowed = input.ServerVNCAllowed
updated = true
}
} else if config.ServerVNCAllowed == nil {
config.ServerVNCAllowed = util.False()
updated = true
}
if input.DisableVNCApproval != nil {
if config.DisableVNCApproval == nil || *input.DisableVNCApproval != *config.DisableVNCApproval {
if *input.DisableVNCApproval {
log.Infof("disabling VNC connection approval prompt")
} else {
log.Infof("enabling VNC connection approval prompt")
}
config.DisableVNCApproval = input.DisableVNCApproval
updated = true
}
}
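The two blocks above follow the same pattern for optional booleans: a nil input means "leave unchanged", and the config is only marked as updated when the pointed-to value actually differs. A generic helper illustrating that pattern might look like the sketch below. `applyOptionalBool` is a hypothetical name. It compares values rather than pointer addresses and stores a copy, so the config does not alias the caller's variable.

```go
package main

import "fmt"

// applyOptionalBool copies input into *dst when input is set and its value
// differs from the current one. It reports whether *dst changed.
func applyOptionalBool(dst **bool, input *bool) bool {
	if input == nil {
		return false // not provided: keep the current setting
	}
	if *dst != nil && **dst == *input {
		return false // same value: nothing to update
	}
	v := *input // copy so later writes by the caller don't leak in
	*dst = &v
	return true
}

func main() {
	var serverVNCAllowed *bool
	enable := true

	fmt.Println(applyOptionalBool(&serverVNCAllowed, &enable)) // first set
	fmt.Println(applyOptionalBool(&serverVNCAllowed, &enable)) // same value again
	fmt.Println(applyOptionalBool(&serverVNCAllowed, nil))     // not provided
}
```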
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")


@@ -53,6 +53,7 @@ type Manager interface {
GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap
GetSelectedClientRoutes() route.HAMap
GetActiveClientRoutes() route.HAMap
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
@@ -485,6 +486,39 @@ func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
}
// GetActiveClientRoutes returns the subset of selected client routes
// that are currently reachable: the route's peer is Connected and is
// the one actively carrying the route (not just an HA sibling).
func (m *DefaultManager) GetActiveClientRoutes() route.HAMap {
m.mux.Lock()
selected := m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
recorder := m.statusRecorder
m.mux.Unlock()
if recorder == nil {
return selected
}
out := make(route.HAMap, len(selected))
for id, routes := range selected {
for _, r := range routes {
st, err := recorder.GetPeer(r.Peer)
if err != nil {
continue
}
if st.ConnStatus != peer.StatusConnected {
continue
}
if _, hasRoute := st.GetRoutes()[r.Network.String()]; !hasRoute {
continue
}
out[id] = routes
break
}
}
return out
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
m.mux.Lock()
@@ -704,7 +738,10 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
}
func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool {
return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR
if len(routes) == 0 {
return false
}
return route.IsV4DefaultRoute(routes[0].Network) || route.IsV6DefaultRoute(routes[0].Network)
}
func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) {
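Note on GetActiveClientRoutes above: an HA group is kept only if at least one of its members has a peer that is connected and that currently carries the member's network. A simplified sketch of that filter, using hypothetical stand-in types (Route, PeerState) in place of route.Route and the status recorder, and omitting the nil-recorder fallback that returns the selection unfiltered:

```go
package main

import "fmt"

// Route and PeerState are hypothetical, simplified stand-ins for
// route.Route and the peer state returned by the status recorder.
type Route struct {
	Peer    string
	Network string
}

type PeerState struct {
	Connected bool
	Routes    map[string]struct{}
}

// activeRoutes keeps an HA group when any member's peer is connected
// and currently lists the member's network among its routes.
func activeRoutes(selected map[string][]Route, peers map[string]PeerState) map[string][]Route {
	out := make(map[string][]Route, len(selected))
	for id, group := range selected {
		for _, r := range group {
			st, ok := peers[r.Peer]
			if !ok || !st.Connected {
				continue
			}
			if _, has := st.Routes[r.Network]; !has {
				continue
			}
			out[id] = group
			break
		}
	}
	return out
}

func main() {
	selected := map[string][]Route{
		"corp|10.0.0.0/8": {
			{Peer: "a", Network: "10.0.0.0/8"},
			{Peer: "b", Network: "10.0.0.0/8"},
		},
		"lab|192.168.0.0/24": {
			{Peer: "c", Network: "192.168.0.0/24"},
		},
	}
	peers := map[string]PeerState{
		"a": {Connected: false},
		"b": {Connected: true, Routes: map[string]struct{}{"10.0.0.0/8": {}}},
		"c": {Connected: true}, // connected, but only an HA sibling: not carrying the route
	}
	fmt.Println(len(activeRoutes(selected, peers))) // 1
}
```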


@@ -19,6 +19,7 @@ type MockManager struct {
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
GetSelectedClientRoutesFunc func() route.HAMap
GetActiveClientRoutesFunc func() route.HAMap
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
StopFunc func(manager *statemanager.Manager)
}
@@ -78,6 +79,14 @@ func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
return nil
}
// GetActiveClientRoutes mock implementation of GetActiveClientRoutes from the Manager interface
func (m *MockManager) GetActiveClientRoutes() route.HAMap {
if m.GetActiveClientRoutesFunc != nil {
return m.GetActiveClientRoutesFunc()
}
return nil
}
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
if m.GetClientRoutesWithNetIDFunc != nil {


@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"slices"
"strings"
"sync"
"github.com/hashicorp/go-multierror"
@@ -12,10 +13,6 @@ import (
"github.com/netbirdio/netbird/route"
)
const (
exitNodeCIDR = "0.0.0.0/0"
)
type RouteSelector struct {
mu sync.RWMutex
deselectedRoutes map[route.NetID]struct{}
@@ -124,13 +121,7 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
isSelected := !deselected
return isSelected
return rs.isSelectedLocked(routeID)
}
// FilterSelected removes unselected routes from the provided map.
@@ -144,23 +135,22 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
filtered := route.HAMap{}
for id, rt := range routes {
netID := id.NetID()
_, deselected := rs.deselectedRoutes[netID]
if !deselected {
if !rs.isDeselectedLocked(id.NetID()) {
filtered[id] = rt
}
}
return filtered
}
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
// transparently apply to the synthesized v6 entry.
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
}
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
@@ -174,7 +164,7 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
filtered := make(route.HAMap, len(routes))
for id, rt := range routes {
netID := id.NetID()
if rs.isDeselected(netID) {
if rs.isDeselectedLocked(netID) {
continue
}
@@ -189,13 +179,48 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
return filtered
}
func (rs *RouteSelector) isDeselected(netID route.NetID) bool {
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
name := string(id)
if !strings.HasSuffix(name, route.V6ExitSuffix) {
return id
}
if _, ok := rs.selectedRoutes[id]; ok {
return id
}
if _, ok := rs.deselectedRoutes[id]; ok {
return id
}
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected || rs.deselectAll
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func (rs *RouteSelector) applyExitNodeFilter(
@@ -204,26 +229,23 @@ func (rs *RouteSelector) applyExitNodeFilter(
rt []*route.Route,
out route.HAMap,
) {
if rs.hasUserSelections() {
// user made explicit selects/deselects
if rs.IsSelected(netID) {
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
// drops the synthesized v6 entry that lacks its own explicit state.
effective := rs.effectiveNetID(netID)
if rs.hasUserSelectionForRouteLocked(effective) {
if rs.isSelectedLocked(effective) {
out[id] = rt
}
return
}
// no explicit selections: only include routes marked !SkipAutoApply (=AutoApply)
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func (rs *RouteSelector) hasUserSelections() bool {
return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {
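Note on effectiveNetID above: a "-v6" exit-node entry with no explicit state of its own resolves to its v4 base, while an explicit select or deselect on the v6 entry wins. A standalone sketch of that resolution, assuming the suffix is "-v6" (as the comments and tests in this change suggest for route.V6ExitSuffix) and using a hypothetical selector type with string IDs instead of route.NetID:

```go
package main

import (
	"fmt"
	"strings"
)

// v6ExitSuffix is an assumption about the value of route.V6ExitSuffix.
const v6ExitSuffix = "-v6"

// selector is a hypothetical stand-in for RouteSelector's two state maps.
type selector struct {
	selected   map[string]struct{}
	deselected map[string]struct{}
}

// effectiveID returns the v4 base for a "-v6" entry without explicit state,
// and the ID itself in every other case.
func (s *selector) effectiveID(id string) string {
	if !strings.HasSuffix(id, v6ExitSuffix) {
		return id
	}
	if _, ok := s.selected[id]; ok {
		return id
	}
	if _, ok := s.deselected[id]; ok {
		return id
	}
	return strings.TrimSuffix(id, v6ExitSuffix)
}

func main() {
	s := &selector{
		selected:   map[string]struct{}{"exit2-v6": {}},
		deselected: map[string]struct{}{"exit1": {}},
	}
	fmt.Println(s.effectiveID("exit1-v6"))        // exit1
	fmt.Println(s.effectiveID("exit2-v6"))        // exit2-v6
	fmt.Println(s.effectiveID("exit1-something")) // exit1-something
}
```

As the source comment warns, this resolution is only safe on exit-node code paths; applied to a non-exit route named "corp-v6" it would pick up unrelated state from "corp".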


@@ -330,6 +330,137 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
assert.Len(t, filtered, 0) // No routes should be selected
}
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
// v4 base, so legacy persisted selections that predate v6 pairing transparently
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
// unrelated v6 with no v4 base touched is unaffected
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
})
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
// via the mirror; the mirror only applies in exit-node code paths.
assert.False(t, rs.IsSelected("corp"))
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
})
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
routes := route.HAMap{
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
})
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
// A route literally named "exit1-something" must not pair-resolve.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
})
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
routes := route.HAMap{
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
})
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-default-route entry named "corp-v6" is not an exit node and
// must not be skipped because its v4 base "corp" is deselected.
corpV6 := &route.Route{NetID: "corp-v6", Network: netip.MustParsePrefix("10.0.0.0/8")}
routes := route.HAMap{
"corp-v6|10.0.0.0/8": {corpV6},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.Contains(t, filtered, route.HAUniqueID("corp-v6|10.0.0.0/8"),
"non-exit *-v6 routes must not inherit unrelated v4 state in FilterSelectedExitNodes")
})
}
// TestRouteSelector_SkipAutoApplyPerRoute verifies that management's
// SkipAutoApply flag governs each untouched route independently, even when
// the user has explicit selections on other routes.
func TestRouteSelector_SkipAutoApplyPerRoute(t *testing.T) {
autoApplied := &route.Route{
NetID: "Auto",
Network: netip.MustParsePrefix("0.0.0.0/0"),
SkipAutoApply: false,
}
skipApply := &route.Route{
NetID: "Skip",
Network: netip.MustParsePrefix("0.0.0.0/0"),
SkipAutoApply: true,
}
routes := route.HAMap{
"Auto|0.0.0.0/0": {autoApplied},
"Skip|0.0.0.0/0": {skipApply},
}
rs := routeselector.NewRouteSelector()
// User makes an unrelated explicit selection elsewhere.
require.NoError(t, rs.DeselectRoutes([]route.NetID{"Unrelated"}, []route.NetID{"Auto", "Skip", "Unrelated"}))
filtered := rs.FilterSelectedExitNodes(routes)
assert.Contains(t, filtered, route.HAUniqueID("Auto|0.0.0.0/0"), "AutoApply route should be included")
assert.NotContains(t, filtered, route.HAUniqueID("Skip|0.0.0.0/0"), "SkipAutoApply route should be excluded without explicit user selection")
}
// TestRouteSelector_V6ExitIsExitNode verifies that ::/0 routes are recognized
// as exit nodes by the selector's filter path.
func TestRouteSelector_V6ExitIsExitNode(t *testing.T) {
v6Exit := &route.Route{
NetID: "V6Only",
Network: netip.MustParsePrefix("::/0"),
SkipAutoApply: true,
}
routes := route.HAMap{
"V6Only|::/0": {v6Exit},
}
rs := routeselector.NewRouteSelector()
filtered := rs.FilterSelectedExitNodes(routes)
assert.Empty(t, filtered, "::/0 should be treated as an exit node and respect SkipAutoApply")
}
func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
initialRoutes := []route.NetID{"route1", "route2", "route3"}
newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}


@@ -188,7 +188,9 @@ func (d *Detector) triggerCallback(event EventType, cb func(event EventType), do
}
doneChan := make(chan struct{})
timeout := time.NewTimer(500 * time.Millisecond)
// macOS forces sleep ~30s after kIOMessageSystemWillSleep, so block long
// enough for teardown to finish while staying under that deadline.
timeout := time.NewTimer(20 * time.Second)
defer timeout.Stop()
go func() {


@@ -74,6 +74,14 @@ func New(filePath string) *Manager {
}
}
// FilePath returns the path of the underlying state file.
func (m *Manager) FilePath() string {
if m == nil {
return ""
}
return m.filePath
}
// Start starts the state manager periodic save routine
func (m *Manager) Start() {
if m == nil {


@@ -162,11 +162,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
hostDNS := []netip.AddrPort{
netip.MustParseAddrPort("9.9.9.9:53"),
netip.MustParseAddrPort("149.112.112.112:53"),
}
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}
// Stop the internal client and free the resources


@@ -64,6 +64,13 @@
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
</RegistryKey>
</Component>
<!-- Drop the HKCU Run\Netbird value written by legacy NSIS installers. -->
<Component Id="NetbirdLegacyHKCUCleanup" Guid="*">
<RegistryValue Root="HKCU" Key="Software\NetBird GmbH\Installer"
Name="LegacyHKCUCleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKCU"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
</Component>
</StandardDirectory>
<StandardDirectory Id="CommonAppDataFolder">
@@ -76,10 +83,28 @@
</Directory>
</StandardDirectory>
<!-- Drop Run, App Paths and Uninstall entries written by legacy NSIS
installers into the 32-bit registry view (HKLM\Software\Wow6432Node). -->
<Component Id="NetbirdLegacyWow6432Cleanup" Directory="NetbirdInstallDir"
Guid="bda5d628-16bd-4086-b2c1-5099d8d51763" Bitness="always32">
<RegistryValue Root="HKLM" Key="Software\NetBird GmbH\Installer"
Name="LegacyWow6432Cleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird-ui" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Uninstall\Netbird" />
</Component>
<ComponentGroup Id="NetbirdFilesComponent">
<ComponentRef Id="NetbirdFiles" />
<ComponentRef Id="NetbirdAumidRegistry" />
<ComponentRef Id="NetbirdAutoStart" />
<ComponentRef Id="NetbirdLegacyHKCUCleanup" />
<ComponentRef Id="NetbirdLegacyWow6432Cleanup" />
</ComponentGroup>
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />

File diff suppressed because it is too large


@@ -119,6 +119,14 @@ service DaemonService {
// ExposeService exposes a local port via the NetBird reverse proxy
rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {}
// RespondApproval delivers the user's accept/deny decision for a
// pending user-approval prompt. The daemon pushes the prompt as a
// SystemEvent with category APPROVAL and metadata key "request_id";
// the UI calls this RPC with the same request_id to unblock whichever
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
// the UI which subsystem the prompt belongs to.
rpc RespondApproval(RespondApprovalRequest) returns (RespondApprovalResponse) {}
}
@@ -205,6 +213,10 @@ message LoginRequest {
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
optional bool disable_ipv6 = 40;
optional bool serverVNCAllowed = 41;
optional bool disableVNCApproval = 42;
}
message LoginResponse {
@@ -314,6 +326,10 @@ message GetConfigResponse {
int32 sshJWTCacheTTL = 26;
bool disable_ipv6 = 27;
bool serverVNCAllowed = 28;
bool disableVNCApproval = 29;
}
// PeerState contains the latest state of a peer
@@ -394,6 +410,25 @@ message SSHServerState {
repeated SSHSessionInfo sessions = 2;
}
// VNCSessionInfo contains information about an active VNC session
message VNCSessionInfo {
string remoteAddress = 1;
string mode = 2;
string username = 3;
// userID is the Noise-verified session identity (hashed user ID from
// the ACL session-key entry), empty when auth is disabled.
string userID = 4;
// initiator is the human-readable display name of the dashboard user
// who minted the SessionPubKey, when known.
string initiator = 5;
}
// VNCServerState contains the latest state of the VNC server
message VNCServerState {
bool enabled = 1;
repeated VNCSessionInfo sessions = 2;
}
// FullStatus contains the full state held by the Status instance
message FullStatus {
ManagementState managementState = 1;
@@ -408,6 +443,7 @@ message FullStatus {
bool lazyConnectionEnabled = 9;
SSHServerState sshServerState = 10;
VNCServerState vncServerState = 11;
}
// Networks
@@ -595,6 +631,7 @@ message SystemEvent {
AUTHENTICATION = 2;
CONNECTIVITY = 3;
SYSTEM = 4;
APPROVAL = 5;
}
string id = 1;
@@ -678,6 +715,10 @@ message SetConfigRequest {
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
optional bool disable_ipv6 = 35;
optional bool serverVNCAllowed = 36;
optional bool disableVNCApproval = 37;
}
message SetConfigResponse{}
@@ -872,3 +913,18 @@ message StartBundleCaptureRequest {
message StartBundleCaptureResponse {}
message StopBundleCaptureRequest {}
message StopBundleCaptureResponse {}
message RespondApprovalRequest {
// request_id matches the SystemEvent metadata key emitted by the daemon
// when a subsystem awaits user approval for an inbound connection.
string request_id = 1;
// accept is true if the user approved the request, false if they
// denied it. A missing or unknown request_id is treated as a no-op.
bool accept = 2;
// view_only signals that the user granted the connection but withheld
// input control. Only meaningful when accept is true; ignored when
// accept is false.
bool view_only = 3;
}
message RespondApprovalResponse {}
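Note on the RespondApproval contract above: the proto comments describe a request_id that is resolved at most once, with unknown IDs treated as a no-op and view_only ignored unless accept is true. The engine-side broker is not part of this excerpt, so the following is only a sketch of one way to satisfy that contract; the broker type, decision type, and method names are all hypothetical:

```go
package main

import (
	"fmt"
	"sync"
)

// decision is a hypothetical carrier for the user's answer.
type decision struct {
	Accept   bool
	ViewOnly bool
}

// broker is a hypothetical pending-prompt registry keyed by request_id.
type broker struct {
	mu      sync.Mutex
	pending map[string]chan decision
}

func newBroker() *broker {
	return &broker{pending: make(map[string]chan decision)}
}

// register records a pending prompt and returns the channel its waiter reads from.
func (b *broker) register(id string) <-chan decision {
	ch := make(chan decision, 1) // buffered so respond never blocks
	b.mu.Lock()
	b.pending[id] = ch
	b.mu.Unlock()
	return ch
}

// respond resolves id at most once. It returns false for an unknown or
// already-resolved id, which the caller can treat as a no-op.
func (b *broker) respond(id string, accept, viewOnly bool) bool {
	b.mu.Lock()
	ch, ok := b.pending[id]
	delete(b.pending, id)
	b.mu.Unlock()
	if !ok {
		return false
	}
	// view_only is only meaningful when the request was accepted.
	ch <- decision{Accept: accept, ViewOnly: accept && viewOnly}
	return true
}

func main() {
	b := newBroker()
	wait := b.register("req-1")

	fmt.Println(b.respond("req-1", true, true))   // true
	fmt.Println(<-wait)                           // {true true}
	fmt.Println(b.respond("req-1", false, false)) // false
}
```

Deleting the entry before sending is what keeps a late or duplicate response from overriding a decision that was already delivered.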


@@ -58,6 +58,7 @@ const (
DaemonService_StopCPUProfile_FullMethodName = "/daemon.DaemonService/StopCPUProfile"
DaemonService_GetInstallerResult_FullMethodName = "/daemon.DaemonService/GetInstallerResult"
DaemonService_ExposeService_FullMethodName = "/daemon.DaemonService/ExposeService"
DaemonService_RespondApproval_FullMethodName = "/daemon.DaemonService/RespondApproval"
)
// DaemonServiceClient is the client API for DaemonService service.
@@ -134,6 +135,13 @@ type DaemonServiceClient interface {
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExposeServiceEvent], error)
// RespondApproval delivers the user's accept/deny decision for a
// pending user-approval prompt. The daemon pushes the prompt as a
// SystemEvent with category APPROVAL and metadata key "request_id";
// the UI calls this RPC with the same request_id to unblock whichever
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
// the UI which subsystem the prompt belongs to.
RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error)
}
type daemonServiceClient struct {
@@ -561,6 +569,16 @@ func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServi
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DaemonService_ExposeServiceClient = grpc.ServerStreamingClient[ExposeServiceEvent]
func (c *daemonServiceClient) RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RespondApprovalResponse)
err := c.cc.Invoke(ctx, DaemonService_RespondApproval_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility.
@@ -635,6 +653,13 @@ type DaemonServiceServer interface {
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error
// RespondApproval delivers the user's accept/deny decision for a
// pending user-approval prompt. The daemon pushes the prompt as a
// SystemEvent with category APPROVAL and metadata key "request_id";
// the UI calls this RPC with the same request_id to unblock whichever
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
// the UI which subsystem the prompt belongs to.
RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -762,6 +787,9 @@ func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *Ins
func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error {
return status.Error(codes.Unimplemented, "method ExposeService not implemented")
}
func (UnimplementedDaemonServiceServer) RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error) {
return nil, status.Error(codes.Unimplemented, "method RespondApproval not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {}
@@ -1464,6 +1492,24 @@ func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStr
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DaemonService_ExposeServiceServer = grpc.ServerStreamingServer[ExposeServiceEvent]
func _DaemonService_RespondApproval_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RespondApprovalRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).RespondApproval(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: DaemonService_RespondApproval_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).RespondApproval(ctx, req.(*RespondApprovalRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1615,6 +1661,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetInstallerResult",
Handler: _DaemonService_GetInstallerResult_Handler,
},
{
MethodName: "RespondApproval",
Handler: _DaemonService_RespondApproval_Handler,
},
},
Streams: []grpc.StreamDesc{
{


@@ -111,7 +111,7 @@ func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.Daemo
return status.Errorf(codes.Internal, "create capture session: %v", err)
}
engine, err := s.claimCapture(sess)
engine, err := s.claimCapture(sess, func() { pw.Close() })
if err != nil {
sess.Stop()
pw.Close()
@@ -190,10 +190,7 @@ func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCap
s.stopBundleCaptureLocked()
s.cleanupBundleCapture()
if s.activeCapture != nil {
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
}
s.evictActiveCaptureLocked()
engine, err := s.getCaptureEngineLocked()
if err != nil {
@@ -304,29 +301,58 @@ func (s *Server) cleanupBundleCapture() {
s.bundleCapture = nil
}
// claimCapture reserves the engine's capture slot for sess. Returns
// FailedPrecondition if another capture is already active.
func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) {
// claimCapture reserves the engine's capture slot for sess. If another
// capture is already running it is evicted: a previous streaming session
// whose gRPC client died and never freed the slot stays stuck otherwise,
// and a bundle capture is just informational state.
func (s *Server) claimCapture(sess *capture.Session, cancel func()) (*internal.Engine, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.activeCapture != nil {
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
}
s.evictActiveCaptureLocked()
engine, err := s.getCaptureEngineLocked()
if err != nil {
return nil, err
}
s.activeCapture = sess
s.activeCaptureCancel = cancel
return engine, nil
}
// evictActiveCaptureLocked tears down whatever capture currently owns
// the engine slot so a fresh claim can succeed. Caller must hold mutex.
func (s *Server) evictActiveCaptureLocked() {
if s.activeCapture == nil {
return
}
if s.bundleCapture != nil && s.bundleCapture.sess == s.activeCapture {
log.Infof("evicting running bundle capture to start a new capture")
s.stopBundleCaptureLocked()
return
}
log.Infof("evicting previous streaming capture to start a new one")
prev := s.activeCapture
cancel := s.activeCaptureCancel
if engine, err := s.getCaptureEngineLocked(); err == nil {
if err := engine.SetCapture(nil); err != nil {
log.Debugf("clear previous capture: %v", err)
}
}
s.activeCapture = nil
s.activeCaptureCancel = nil
prev.Stop()
if cancel != nil {
cancel()
}
}
// releaseCapture clears the active-capture owner if it still matches sess.
func (s *Server) releaseCapture(sess *capture.Session) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.activeCapture == sess {
s.activeCapture = nil
s.activeCaptureCancel = nil
}
}
@@ -341,6 +367,7 @@ func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Eng
log.Debugf("clear capture: %v", err)
}
s.activeCapture = nil
s.activeCaptureCancel = nil
}
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {
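Note on the claim/evict/release changes above: a new claim now evicts the current owner instead of failing with FailedPrecondition, and release only clears the slot if the caller still owns it. A reduced sketch of that ownership pattern, with hypothetical session and slot types and without the engine or bundle-capture handling from the real code:

```go
package main

import (
	"fmt"
	"sync"
)

// session is a hypothetical stand-in for capture.Session.
type session struct{ stopped bool }

func (s *session) Stop() { s.stopped = true }

// slot is a hypothetical single-owner capture slot.
type slot struct {
	mu     sync.Mutex
	active *session
	cancel func()
}

// claim evicts any current owner, then installs sess and its cancel hook.
func (s *slot) claim(sess *session, cancel func()) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.evictLocked()
	s.active = sess
	s.cancel = cancel
}

// evictLocked stops the current owner and runs its cancel hook.
// The caller must hold s.mu.
func (s *slot) evictLocked() {
	if s.active == nil {
		return
	}
	prev, cancel := s.active, s.cancel
	s.active, s.cancel = nil, nil
	prev.Stop()
	if cancel != nil {
		cancel()
	}
}

// release clears the slot only if sess still owns it, so an evicted
// session's deferred cleanup cannot free its successor's claim.
func (s *slot) release(sess *session) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.active == sess {
		s.active = nil
		s.cancel = nil
	}
}

func main() {
	var sl slot
	first, second := &session{}, &session{}
	cancelled := false

	sl.claim(first, func() { cancelled = true })
	sl.claim(second, nil) // evicts first
	sl.release(first)     // stale release: no effect

	fmt.Println(first.stopped, cancelled, sl.active == second) // true true true
}
```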


@@ -93,8 +93,12 @@ type Server struct {
captureEnabled bool
bundleCapture *bundleCapture
// activeCapture is the session currently installed on the engine; guarded by s.mutex.
activeCapture *capture.Session
networksDisabled bool
activeCapture *capture.Session
// activeCaptureCancel tears down the streaming pipe/cancel for the
// active streaming capture so eviction unblocks the StartCapture RPC
// handler. Nil for bundle captures (they own their own context).
activeCaptureCancel func()
networksDisabled bool
sleepHandler *sleephandler.SleepHandler
@@ -376,6 +380,8 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect
config.ServerSSHAllowed = msg.ServerSSHAllowed
config.ServerVNCAllowed = msg.ServerVNCAllowed
config.DisableVNCApproval = msg.DisableVNCApproval
config.NetworkMonitor = msg.NetworkMonitor
config.DisableClientRoutes = msg.DisableClientRoutes
config.DisableServerRoutes = msg.DisableServerRoutes
@@ -1136,6 +1142,7 @@ func (s *Server) Status(
pbFullStatus := fullStatus.ToProto()
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
pbFullStatus.VncServerState = s.getVNCServerState()
statusResponse.FullStatus = pbFullStatus
}
@@ -1175,6 +1182,38 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
return sshServerState
}
// getVNCServerState retrieves the current VNC server state.
func (s *Server) getVNCServerState() *proto.VNCServerState {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
}
enabled, sessions := engine.GetVNCServerStatus()
pbSessions := make([]*proto.VNCSessionInfo, 0, len(sessions))
for _, sess := range sessions {
pbSessions = append(pbSessions, &proto.VNCSessionInfo{
RemoteAddress: sess.RemoteAddress,
Mode: sess.Mode,
Username: sess.Username,
UserID: sess.UserID,
Initiator: sess.Initiator,
})
}
return &proto.VNCServerState{
Enabled: enabled,
Sessions: pbSessions,
}
}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
func (s *Server) GetPeerSSHHostKey(
ctx context.Context,
@@ -1415,6 +1454,27 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
return nil
}
// RespondApproval relays the user's accept/deny decision for a pending
// approval prompt to the engine's broker. Unknown or already-resolved
// request_ids are silently no-op'd so a slow UI cannot deny a prompt the
// user already handled (or that already timed out).
func (s *Server) RespondApproval(_ context.Context, msg *proto.RespondApprovalRequest) (*proto.RespondApprovalResponse, error) {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "engine not running")
}
if !engine.RespondApproval(msg.GetRequestId(), msg.GetAccept(), msg.GetViewOnly()) {
log.Debugf("approval response for unknown request_id %s", msg.GetRequestId())
}
return &proto.RespondApprovalResponse{}, nil
}
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
@@ -1531,6 +1591,8 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
DisableVNCApproval: cfg.DisableVNCApproval != nil && *cfg.DisableVNCApproval,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,


@@ -58,6 +58,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
serverVNCAllowed := true
disableVNCApproval := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
@@ -83,6 +85,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
ServerVNCAllowed: &serverVNCAllowed,
DisableVNCApproval: &disableVNCApproval,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
@@ -127,6 +131,10 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.NotNil(t, cfg.ServerVNCAllowed)
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
require.NotNil(t, cfg.DisableVNCApproval)
require.Equal(t, disableVNCApproval, *cfg.DisableVNCApproval)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
@@ -179,6 +187,8 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"ServerVNCAllowed": true,
"DisableVNCApproval": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
@@ -240,6 +250,8 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"allow-server-vnc": "ServerVNCAllowed",
"disable-vnc-approval": "DisableVNCApproval",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",

@@ -28,7 +28,7 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
- sshauth "github.com/netbirdio/netbird/client/ssh/auth"
+ sshauth "github.com/netbirdio/netbird/shared/sessionauth"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"

@@ -23,7 +23,7 @@ import (
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
- sshauth "github.com/netbirdio/netbird/client/ssh/auth"
+ sshauth "github.com/netbirdio/netbird/shared/sessionauth"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"

@@ -23,7 +23,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
- sshauth "github.com/netbirdio/netbird/client/ssh/auth"
+ sshauth "github.com/netbirdio/netbird/shared/sessionauth"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
@@ -197,6 +197,14 @@ type Config struct {
// HostKeyPEM is the SSH server host key in PEM format
HostKeyPEM []byte
// NetstackNet, when non-nil, makes the SSH server listen via the
// supplied userspace network stack instead of an OS socket.
NetstackNet *netstack.Net
// NetworkValidation, when non-zero, restricts inbound connections to
// peers inside the NetBird overlay defined by this WireGuard address.
NetworkValidation wgaddr.Address
}
// SessionInfo contains information about an active SSH session
@@ -208,12 +216,15 @@ type SessionInfo struct {
PortForwards []string
}
- // New creates an SSH server instance with the provided host key and optional JWT configuration
- // If jwtConfig is nil, JWT authentication is disabled
+ // New creates an SSH server instance from the supplied Config. Fields are
+ // read once at construction; mutating Config afterwards has no effect.
+ // JWT == nil disables JWT authentication.
func New(config *Config) *Server {
s := &Server{
mu: sync.RWMutex{},
hostKeyPEM: config.HostKeyPEM,
netstackNet: config.NetstackNet,
wgAddress: config.NetworkValidation,
sessions: make(map[sessionKey]*sessionState),
pendingAuthJWT: make(map[authKey]string),
remoteForwardListeners: make(map[forwardKey]net.Listener),
@@ -434,20 +445,6 @@ func (s *Server) buildSessionInfo(state *sessionState) SessionInfo {
return info
}
- // SetNetstackNet sets the netstack network for userspace networking
- func (s *Server) SetNetstackNet(net *netstack.Net) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.netstackNet = net
- }
- // SetNetworkValidation configures network-based connection filtering
- func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.wgAddress = addr
- }
// UpdateSSHAuth updates the SSH fine-grained access control configuration
// This should be called when network map updates include new SSH auth configuration
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {

@@ -131,6 +131,19 @@ type SSHServerStateOutput struct {
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
}
type VNCSessionOutput struct {
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
Mode string `json:"mode" yaml:"mode"`
Username string `json:"username,omitempty" yaml:"username,omitempty"`
UserID string `json:"userID,omitempty" yaml:"userID,omitempty"`
Initiator string `json:"initiator,omitempty" yaml:"initiator,omitempty"`
}
type VNCServerStateOutput struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Sessions []VNCSessionOutput `json:"sessions" yaml:"sessions"`
}
type OutputOverview struct {
Peers PeersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
@@ -153,6 +166,7 @@ type OutputOverview struct {
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
ProfileName string `json:"profileName" yaml:"profileName"`
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
}
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
@@ -173,6 +187,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
relayOverview := mapRelays(pbFullStatus.GetRelays())
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
vncServerOverview := mapVNCServer(pbFullStatus.GetVncServerState())
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
overview := OutputOverview{
@@ -197,6 +212,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
ProfileName: opts.ProfileName,
SSHServerState: sshServerOverview,
VNCServerState: vncServerOverview,
}
if opts.Anonymize {
@@ -271,6 +287,26 @@ func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
}
}
func mapVNCServer(state *proto.VNCServerState) VNCServerStateOutput {
if state == nil {
return VNCServerStateOutput{Sessions: []VNCSessionOutput{}}
}
sessions := make([]VNCSessionOutput, 0, len(state.GetSessions()))
for _, sess := range state.GetSessions() {
sessions = append(sessions, VNCSessionOutput{
RemoteAddress: sess.GetRemoteAddress(),
Mode: sess.GetMode(),
Username: sess.GetUsername(),
UserID: sess.GetUserID(),
Initiator: sess.GetInitiator(),
})
}
return VNCServerStateOutput{
Enabled: state.GetEnabled(),
Sessions: sessions,
}
}
func mapPeers(
peers []*proto.PeerState,
statusFilter string,
@@ -533,6 +569,26 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
}
}
vncServerStatus := "Disabled"
if o.VNCServerState.Enabled {
vncSessionCount := len(o.VNCServerState.Sessions)
if vncSessionCount > 0 {
sessionWord := "session"
if vncSessionCount > 1 {
sessionWord = "sessions"
}
vncServerStatus = fmt.Sprintf("Enabled (%d active %s)", vncSessionCount, sessionWord)
} else {
vncServerStatus = "Enabled"
}
if showSSHSessions && vncSessionCount > 0 {
for _, sess := range o.VNCServerState.Sessions {
vncServerStatus += "\n " + formatVNCSessionLine(sess)
}
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
var forwardingRulesString string
@@ -563,6 +619,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
"SSH Server: %s\n"+
"VNC Server: %s\n"+
"Networks: %s\n"+
"%s"+
"Peers count: %s\n",
@@ -581,6 +638,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,
vncServerStatus,
networks,
forwardingRulesString,
peersCountString,
@@ -940,6 +998,26 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *PeerStateDetailOutput) {
}
}
// formatVNCSessionLine renders a single VNC session row for the detailed
// status output. The leading slot identifies the initiator (display name
// when known, hashed UserID otherwise); the post-arrow slot is the OS
// user the session targets and is omitted in attach mode where the
// destination is the current console user (unknown to the daemon).
func formatVNCSessionLine(sess VNCSessionOutput) string {
who := sess.Initiator
if who == "" {
who = sess.UserID
}
prefix := sess.RemoteAddress
if who != "" {
prefix = fmt.Sprintf("%s@%s", who, sess.RemoteAddress)
}
if sess.Username != "" {
return fmt.Sprintf("[%s -> %s] mode=%s", prefix, sess.Username, sess.Mode)
}
return fmt.Sprintf("[%s] mode=%s", prefix, sess.Mode)
}
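To make the row format concrete, here is a standalone sketch that copies `formatVNCSessionLine` and a trimmed `VNCSessionOutput` (struct tags omitted) and prints three cases. The addresses, user names, and the `session` mode string are made-up sample values, not taken from the PR.

```go
package main

import "fmt"

// VNCSessionOutput is a trimmed copy of the struct from the diff (tags omitted).
type VNCSessionOutput struct {
	RemoteAddress string
	Mode          string
	Username      string
	UserID        string
	Initiator     string
}

// formatVNCSessionLine mirrors the function shown in the diff.
func formatVNCSessionLine(sess VNCSessionOutput) string {
	who := sess.Initiator
	if who == "" {
		who = sess.UserID
	}
	prefix := sess.RemoteAddress
	if who != "" {
		prefix = fmt.Sprintf("%s@%s", who, sess.RemoteAddress)
	}
	if sess.Username != "" {
		return fmt.Sprintf("[%s -> %s] mode=%s", prefix, sess.Username, sess.Mode)
	}
	return fmt.Sprintf("[%s] mode=%s", prefix, sess.Mode)
}

func main() {
	// All values below are illustrative assumptions.
	samples := []VNCSessionOutput{
		{RemoteAddress: "100.64.0.7:51234", Mode: "session", Username: "bob", Initiator: "alice"},
		{RemoteAddress: "100.64.0.7:51234", Mode: "attach", UserID: "u-123"},
		{RemoteAddress: "100.64.0.7:51234", Mode: "attach"},
	}
	for _, s := range samples {
		fmt.Println(formatVNCSessionLine(s))
	}
	// [alice@100.64.0.7:51234 -> bob] mode=session
	// [u-123@100.64.0.7:51234] mode=attach
	// [100.64.0.7:51234] mode=attach
}
```

The second case shows the fallback from `Initiator` to `UserID`; the third shows the bare address when neither is known.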
func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
@@ -960,6 +1038,19 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
overview.Relays.Details[i] = detail
}
anonymizeNSServerGroups(a, overview)
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
anonymizeEvents(a, overview)
anonymizeServerSessions(a, overview)
}
func anonymizeNSServerGroups(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
@@ -971,13 +1062,9 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
}
}
}
+ }
- for i, route := range overview.Networks {
- overview.Networks[i] = a.AnonymizeRoute(route)
- }
- overview.FQDN = a.AnonymizeDomain(overview.FQDN)
+ func anonymizeEvents(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, event := range overview.Events {
overview.Events[i].Message = a.AnonymizeString(event.Message)
overview.Events[i].UserMessage = a.AnonymizeString(event.UserMessage)
@@ -986,13 +1073,24 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
event.Metadata[k] = a.AnonymizeString(v)
}
}
}
func anonymizeRemoteAddress(a *anonymize.Anonymizer, addr string) string {
if host, port, err := net.SplitHostPort(addr); err == nil {
return fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
return a.AnonymizeIPString(addr)
}
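The helper above can be exercised in isolation. In this sketch `maskIP` is a hypothetical stand-in for `Anonymizer.AnonymizeIPString` (the real anonymizer is not reproduced here), and the helper takes it as a function parameter so the block stays self-contained.

```go
package main

import (
	"fmt"
	"net"
)

// maskIP is a hypothetical placeholder for the real anonymizer; it maps
// every input to one fixed documentation address.
func maskIP(ip string) string {
	return "198.51.100.1"
}

// anonymizeRemoteAddress follows the logic in the diff: anonymize only the
// host part of "host:port", or the whole string when there is no port.
func anonymizeRemoteAddress(anonymizeIP func(string) string, addr string) string {
	if host, port, err := net.SplitHostPort(addr); err == nil {
		return fmt.Sprintf("%s:%s", anonymizeIP(host), port)
	}
	return anonymizeIP(addr)
}

func main() {
	fmt.Println(anonymizeRemoteAddress(maskIP, "100.64.0.7:5900")) // 198.51.100.1:5900
	fmt.Println(anonymizeRemoteAddress(maskIP, "100.64.0.7"))      // 198.51.100.1
}
```

One thing to keep in mind: `net.SplitHostPort` strips the brackets from input such as `[fd00::1]:5900`, so if the anonymizer returns an IPv6 literal the `%s:%s` join will not re-add them. `net.JoinHostPort` would, should that case matter.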
func anonymizeServerSessions(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, session := range overview.SSHServerState.Sessions {
- if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil {
- overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
- } else {
- overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress)
- }
+ overview.SSHServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, session.RemoteAddress)
overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
}
for i, sess := range overview.VNCServerState.Sessions {
overview.VNCServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, sess.RemoteAddress)
overview.VNCServerState.Sessions[i].Username = a.AnonymizeString(sess.Username)
overview.VNCServerState.Sessions[i].UserID = a.AnonymizeString(sess.UserID)
overview.VNCServerState.Sessions[i].Initiator = a.AnonymizeString(sess.Initiator)
}
}

@@ -240,6 +240,10 @@ var overview = OutputOverview{
Enabled: false,
Sessions: []SSHSessionOutput{},
},
VNCServerState: VNCServerStateOutput{
Enabled: false,
Sessions: []VNCSessionOutput{},
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
@@ -404,6 +408,10 @@ func TestParsingToJSON(t *testing.T) {
"sshServer":{
"enabled":false,
"sessions":[]
},
"vncServer":{
"enabled":false,
"sessions":[]
}
}`
// @formatter:on
@@ -513,6 +521,9 @@ profileName: ""
sshServer:
enabled: false
sessions: []
vncServer:
enabled: false
sessions: []
`
assert.Equal(t, expectedYAML, yaml)
@@ -582,6 +593,7 @@ Interface type: Kernel
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
@@ -607,6 +619,7 @@ Interface type: Kernel
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`

@@ -62,6 +62,7 @@ type Info struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
DisableClientRoutes bool
DisableServerRoutes bool
@@ -83,6 +84,7 @@ type Info struct {
func (i *Info) SetFlags(
rosenpassEnabled, rosenpassPermissive bool,
serverSSHAllowed *bool,
serverVNCAllowed *bool,
disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool,
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
@@ -93,6 +95,9 @@ func (i *Info) SetFlags(
if serverSSHAllowed != nil {
i.ServerSSHAllowed = *serverSSHAllowed
}
if serverVNCAllowed != nil {
i.ServerVNCAllowed = *serverVNCAllowed
}
i.DisableClientRoutes = disableClientRoutes
i.DisableServerRoutes = disableServerRoutes

client/ui/approval.go Normal file

@@ -0,0 +1,206 @@
//go:build !(linux && 386)
package main
import (
"context"
"fmt"
"strings"
"time"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/widget"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/approval"
"github.com/netbirdio/netbird/client/proto"
)
// handleApprovalEvent forks a netbird-ui child process to render the
// dialog on its own fyne main loop. Top-level windows opened from a
// background goroutine of the tray process don't render reliably on
// Linux/GTK, so the rest of the UI (settings, login URL, update) uses
// the same fork pattern.
func (s *serviceClient) handleApprovalEvent(ev *proto.SystemEvent) {
if ev == nil || ev.Category != proto.SystemEvent_APPROVAL {
return
}
requestID := ev.Metadata["request_id"]
if requestID == "" {
log.Warnf("approval event missing request_id: %v", ev.Metadata)
return
}
args := []string{
"--approval-request-id=" + requestID,
"--approval-kind=" + ev.Metadata["kind"],
"--approval-initiator=" + ev.Metadata["initiator"],
"--approval-peer-name=" + ev.Metadata["peer_name"],
"--approval-source-ip=" + ev.Metadata["source_ip"],
"--approval-username=" + ev.Metadata["username"],
"--approval-expires-at=" + ev.Metadata["expires_at"],
"--approval-key-fingerprint=" + ev.Metadata["peer_pubkey"],
"--approval-subject=" + ev.UserMessage,
}
go s.eventHandler.runSelfCommand(s.ctx, "approval", args...)
}
// showApprovalUI runs the dialog on the forked process's fyne main loop
// and forwards the user's decision to the daemon via RespondApproval.
func (s *serviceClient) showApprovalUI(req approvalRequest) {
w := s.app.NewWindow(approvalTitle(req.kind))
w.Resize(fyne.NewSize(480, 260))
w.CenterOnScreen()
w.RequestFocus()
var rows []string
if req.initiator != "" {
// The display name comes from the management dashboard and is
// not cryptographically asserted by the connecting client. The
// key fingerprint that follows IS: it's the Noise_IK static
// public key the client just proved possession of. Show both
// so the user can sanity-check that "Alice" is really the
// Alice they trust.
rows = append(rows, "From user: "+req.initiator)
}
if fp := approval.ShortKeyFingerprint(req.keyFingerprint); fp != "" {
rows = append(rows, "Key fp: "+fp)
}
if req.peerName != "" {
rows = append(rows, "Via peer: "+req.peerName)
}
if req.sourceIP != "" && req.sourceIP != req.peerName {
rows = append(rows, "Source IP: "+req.sourceIP)
}
if req.username != "" {
rows = append(rows, "OS user: "+req.username)
}
if len(rows) == 0 {
rows = []string{"Remote: " + req.displayPeer()}
}
body := strings.Join(rows, "\n")
bodyLabel := widget.NewLabel(body)
bodyLabel.Wrapping = fyne.TextWrapWord
countdown := widget.NewLabel("")
deadline := req.deadline()
updateCountdown := func() {
remaining := time.Until(deadline).Round(time.Second)
if remaining < 0 {
remaining = 0
}
countdown.SetText(fmt.Sprintf("Auto-deny in %s", remaining))
}
updateCountdown()
type outcome struct {
accept bool
viewOnly bool
}
decided := make(chan outcome, 1)
decide := func(o outcome) {
select {
case decided <- o:
default:
}
}
allow := widget.NewButton("Allow", func() { decide(outcome{accept: true}) })
allow.Importance = widget.HighImportance
allowView := widget.NewButton("Allow (view only)", func() { decide(outcome{accept: true, viewOnly: true}) })
deny := widget.NewButton("Deny", func() { decide(outcome{accept: false}) })
header := widget.NewLabelWithStyle(req.subject, fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
buttonRow := container.NewGridWithColumns(3, allow, allowView, deny)
info := container.NewVBox(header, widget.NewSeparator(), bodyLabel, widget.NewSeparator(), countdown)
w.SetContent(container.NewPadded(container.NewBorder(nil, buttonRow, nil, nil, info)))
w.SetCloseIntercept(func() { decide(outcome{}) })
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for range ticker.C {
if time.Until(deadline) <= 0 {
decide(outcome{})
return
}
fyne.Do(updateCountdown)
}
}()
go func() {
o := <-decided
s.sendApprovalResponse(req.requestID, o.accept, o.viewOnly)
fyne.Do(func() {
w.Close()
s.app.Quit()
})
}()
w.Show()
}
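The dialog relies on a "first decision wins" pattern: a channel with capacity 1 plus a non-blocking send, so that the button callbacks, the close intercept, and the countdown goroutine can all call `decide` without blocking. A minimal sketch of that pattern in isolation follows; `newDecider` is a hypothetical wrapper introduced only for this example.

```go
package main

import "fmt"

// outcome matches the local type declared inside showApprovalUI.
type outcome struct {
	accept   bool
	viewOnly bool
}

// newDecider is a hypothetical helper that packages the channel and the
// non-blocking send used in the dialog.
func newDecider() (<-chan outcome, func(outcome)) {
	decided := make(chan outcome, 1)
	decide := func(o outcome) {
		select {
		case decided <- o:
			// First decision is buffered for the single reader.
		default:
			// A decision is already buffered; drop this one without blocking.
		}
	}
	return decided, decide
}

func main() {
	decided, decide := newDecider()
	decide(outcome{accept: true, viewOnly: true}) // e.g. user clicks "Allow (view only)"
	decide(outcome{})                              // e.g. the timeout fires afterwards; ignored
	fmt.Printf("%+v\n", <-decided)                 // {accept:true viewOnly:true}
}
```

Because the zero value of `outcome` means "deny", both the window-close path and the timeout path can pass `outcome{}` and still fail closed.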
func (s *serviceClient) sendApprovalResponse(requestID string, accept, viewOnly bool) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Warnf("approval response: get daemon client: %v", err)
return
}
ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout)
defer cancel()
if _, err := conn.RespondApproval(ctx, &proto.RespondApprovalRequest{
RequestId: requestID,
Accept: accept,
ViewOnly: viewOnly,
}); err != nil {
log.Warnf("approval response: %v", err)
}
}
// approvalRequest is the parsed --approval-* CLI args that the forked
// dialog process consumes.
type approvalRequest struct {
requestID string
kind string
initiator string
peerName string
sourceIP string
username string
subject string
expiresAt string
keyFingerprint string
}
func (r approvalRequest) displayPeer() string {
switch {
case r.initiator != "":
return r.initiator
case r.peerName != "":
return r.peerName
case r.sourceIP != "":
return r.sourceIP
default:
return "unknown peer"
}
}
// deadline returns the wall-clock auto-deny moment. Falls back to a short
// local window when the daemon's expires_at is missing or unparsable, so a
// malformed value never leaves the dialog open indefinitely.
func (r approvalRequest) deadline() time.Time {
if t, err := time.Parse(time.RFC3339, r.expiresAt); err == nil {
return t
}
return time.Now().Add(13 * time.Second)
}
func approvalTitle(kind string) string {
switch kind {
case "vnc":
return "Allow VNC Connection?"
case "ssh":
return "Allow SSH Connection?"
default:
return "Allow Incoming Connection?"
}
}

@@ -97,13 +97,25 @@ func main() {
showQuickActions: flags.showQuickActions,
showUpdate: flags.showUpdate,
showUpdateVersion: flags.showUpdateVersion,
showApproval: flags.showApproval,
approvalRequest: approvalRequest{
requestID: flags.approvalRequestID,
kind: flags.approvalKind,
initiator: flags.approvalInitiator,
peerName: flags.approvalPeerName,
sourceIP: flags.approvalSourceIP,
username: flags.approvalUsername,
subject: flags.approvalSubject,
expiresAt: flags.approvalExpiresAt,
keyFingerprint: flags.approvalKeyFingerprint,
},
})
// Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set.
- if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate {
+ if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate || flags.showApproval {
a.Run()
return
}
@@ -140,6 +152,17 @@ type cliFlags struct {
saveLogsInFile bool
showUpdate bool
showUpdateVersion string
showApproval bool
approvalRequestID string
approvalKind string
approvalInitiator string
approvalPeerName string
approvalSourceIP string
approvalUsername string
approvalSubject string
approvalExpiresAt string
approvalKeyFingerprint string
}
// parseFlags reads and returns all needed command-line flags.
@@ -161,6 +184,16 @@ func parseFlags() *cliFlags {
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window")
flag.StringVar(&flags.showUpdateVersion, "update-version", "", "version to update to")
flag.BoolVar(&flags.showApproval, "approval", false, "show inbound-connection approval prompt window")
flag.StringVar(&flags.approvalRequestID, "approval-request-id", "", "approval prompt: daemon-issued request id")
flag.StringVar(&flags.approvalKind, "approval-kind", "", "approval prompt: subsystem kind (vnc, ssh, ...)")
flag.StringVar(&flags.approvalInitiator, "approval-initiator", "", "approval prompt: display name of the user who initiated the connection")
flag.StringVar(&flags.approvalPeerName, "approval-peer-name", "", "approval prompt: remote peer FQDN")
flag.StringVar(&flags.approvalSourceIP, "approval-source-ip", "", "approval prompt: remote source IP")
flag.StringVar(&flags.approvalUsername, "approval-username", "", "approval prompt: requested OS username")
flag.StringVar(&flags.approvalSubject, "approval-subject", "", "approval prompt: human-readable subject line")
flag.StringVar(&flags.approvalExpiresAt, "approval-expires-at", "", "approval prompt: RFC3339 deadline at which the daemon auto-denies")
flag.StringVar(&flags.approvalKeyFingerprint, "approval-key-fingerprint", "", "approval prompt: hex-encoded Noise static pubkey of the connecting client")
flag.Parse()
return &flags
}
@@ -249,6 +282,7 @@ type serviceClient struct {
mQuit *systray.MenuItem
mNetworks *systray.MenuItem
mAllowSSH *systray.MenuItem
mAllowVNC *systray.MenuItem
mAutoConnect *systray.MenuItem
mEnableRosenpass *systray.MenuItem
mLazyConnEnabled *systray.MenuItem
@@ -287,6 +321,8 @@ type serviceClient struct {
sEnableSSHRemotePortForward *widget.Check
sDisableSSHAuth *widget.Check
iSSHJWTCacheTTL *widget.Entry
sServerVNCAllowed *widget.Check
sDisableVNCApproval *widget.Check
// observable settings over corresponding iMngURL and iPreSharedKey values.
managementURL string
@@ -308,6 +344,8 @@ type serviceClient struct {
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
serverVNCAllowed bool
disableVNCApproval bool
connected bool
daemonVersion string
@@ -355,6 +393,8 @@ type newServiceClientArgs struct {
showQuickActions bool
showUpdate bool
showUpdateVersion string
showApproval bool
approvalRequest approvalRequest
}
// newServiceClient instance constructor
@@ -395,6 +435,8 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
s.showQuickActionsUI()
case args.showUpdate:
s.showUpdateProgress(ctx, args.showUpdateVersion)
case args.showApproval:
s.showApprovalUI(args.approvalRequest)
}
return s
@@ -478,6 +520,8 @@ func (s *serviceClient) showSettingsUI() {
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
s.iSSHJWTCacheTTL = widget.NewEntry()
s.sServerVNCAllowed = widget.NewCheck("Allow embedded VNC server on this peer", nil)
s.sDisableVNCApproval = widget.NewCheck("Skip per-connection approval prompt for VNC", nil)
s.wSettings.SetContent(s.getSettingsForm())
s.wSettings.Resize(fyne.NewSize(600, 400))
@@ -590,7 +634,8 @@ func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool
s.disableServerRoutes != s.sDisableServerRoutes.Checked ||
s.disableIPv6 != s.sDisableIPv6.Checked ||
s.blockLANAccess != s.sBlockLANAccess.Checked ||
- s.hasSSHChanges()
+ s.hasSSHChanges() ||
+ s.hasVNCChanges()
}
func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) error {
@@ -649,6 +694,8 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
req.EnableSSHLocalPortForwarding = &s.sEnableSSHLocalPortForward.Checked
req.EnableSSHRemotePortForwarding = &s.sEnableSSHRemotePortForward.Checked
req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
req.ServerVNCAllowed = &s.sServerVNCAllowed.Checked
req.DisableVNCApproval = &s.sDisableVNCApproval.Checked
sshJWTCacheTTLText := strings.TrimSpace(s.iSSHJWTCacheTTL.Text)
if sshJWTCacheTTLText != "" {
@@ -709,10 +756,12 @@ func (s *serviceClient) getSettingsForm() fyne.CanvasObject {
connectionForm := s.getConnectionForm()
networkForm := s.getNetworkForm()
sshForm := s.getSSHForm()
vncForm := s.getVNCForm()
tabs := container.NewAppTabs(
container.NewTabItem("Connection", connectionForm),
container.NewTabItem("Network", networkForm),
container.NewTabItem("SSH", sshForm),
container.NewTabItem("VNC", vncForm),
)
saveButton := widget.NewButtonWithIcon("Save", theme.ConfirmIcon(), s.saveSettings)
saveButton.Importance = widget.HighImportance
@@ -753,6 +802,15 @@ func (s *serviceClient) getSSHForm() *widget.Form {
}
}
func (s *serviceClient) getVNCForm() *widget.Form {
return &widget.Form{
Items: []*widget.FormItem{
{Text: "Allow VNC Server", Widget: s.sServerVNCAllowed},
{Text: "Disable Connection Approval Prompt", Widget: s.sDisableVNCApproval},
},
}
}
func (s *serviceClient) hasSSHChanges() bool {
currentSSHJWTCacheTTL := s.sshJWTCacheTTL
if text := strings.TrimSpace(s.iSSHJWTCacheTTL.Text); text != "" {
@@ -771,6 +829,11 @@ func (s *serviceClient) hasSSHChanges() bool {
s.sshJWTCacheTTL != currentSSHJWTCacheTTL
}
func (s *serviceClient) hasVNCChanges() bool {
return s.serverVNCAllowed != s.sServerVNCAllowed.Checked ||
s.disableVNCApproval != s.sDisableVNCApproval.Checked
}
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
@@ -1045,6 +1108,7 @@ func (s *serviceClient) onTrayReady() {
s.mSettings = systray.AddMenuItem("Settings", disabledMenuDescr)
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
s.mAllowVNC = s.mSettings.AddSubMenuItemCheckbox("Allow VNC", allowVNCMenuDescr, false)
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false)
@@ -1118,6 +1182,7 @@ func (s *serviceClient) onTrayReady() {
s.eventManager = event.NewManager(s.notifier, s.addr)
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
s.eventManager.AddHandler(s.handleApprovalEvent)
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
if event.Category == proto.SystemEvent_SYSTEM {
s.updateExitNodes()
@@ -1353,6 +1418,12 @@ func (s *serviceClient) getSrvConfig() {
if cfg.SSHJWTCacheTTL != nil {
s.sshJWTCacheTTL = *cfg.SSHJWTCacheTTL
}
if cfg.ServerVNCAllowed != nil {
s.serverVNCAllowed = *cfg.ServerVNCAllowed
}
if cfg.DisableVNCApproval != nil {
s.disableVNCApproval = *cfg.DisableVNCApproval
}
if s.showAdvancedSettings {
s.iMngURL.SetText(s.managementURL)
@@ -1393,6 +1464,12 @@ func (s *serviceClient) getSrvConfig() {
if cfg.SSHJWTCacheTTL != nil {
s.iSSHJWTCacheTTL.SetText(strconv.Itoa(*cfg.SSHJWTCacheTTL))
}
if cfg.ServerVNCAllowed != nil {
s.sServerVNCAllowed.SetChecked(*cfg.ServerVNCAllowed)
}
if cfg.DisableVNCApproval != nil {
s.sDisableVNCApproval.SetChecked(*cfg.DisableVNCApproval)
}
}
if s.mNotifications == nil {
@@ -1452,6 +1529,8 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
config.DisableAutoConnect = cfg.DisableAutoConnect
config.ServerSSHAllowed = &cfg.ServerSSHAllowed
config.ServerVNCAllowed = &cfg.ServerVNCAllowed
config.DisableVNCApproval = &cfg.DisableVNCApproval
config.RosenpassEnabled = cfg.RosenpassEnabled
config.RosenpassPermissive = cfg.RosenpassPermissive
config.DisableNotifications = &cfg.DisableNotifications
@@ -1547,6 +1626,12 @@ func (s *serviceClient) loadSettings() {
s.mAllowSSH.Uncheck()
}
if cfg.ServerVNCAllowed {
s.mAllowVNC.Check()
} else {
s.mAllowVNC.Uncheck()
}
if cfg.DisableAutoConnect {
s.mAutoConnect.Uncheck()
} else {
@@ -1586,6 +1671,7 @@ func (s *serviceClient) loadSettings() {
func (s *serviceClient) updateConfig() error {
disableAutoStart := !s.mAutoConnect.Checked()
sshAllowed := s.mAllowSSH.Checked()
vncAllowed := s.mAllowVNC.Checked()
rosenpassEnabled := s.mEnableRosenpass.Checked()
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
blockInbound := s.mBlockInbound.Checked()
@@ -1614,6 +1700,7 @@ func (s *serviceClient) updateConfig() error {
Username: currUser.Username,
DisableAutoConnect: &disableAutoStart,
ServerSSHAllowed: &sshAllowed,
ServerVNCAllowed: &vncAllowed,
RosenpassEnabled: &rosenpassEnabled,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,

@@ -2,6 +2,7 @@ package main
const (
allowSSHMenuDescr = "Allow SSH connections"
allowVNCMenuDescr = "Allow embedded VNC server"
autoConnectMenuDescr = "Connect automatically when the service starts"
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
lazyConnMenuDescr = "[Experimental] Enable lazy connections"

@@ -112,7 +112,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) {
handlers := slices.Clone(e.handlers)
e.mu.Unlock()
- if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) && !isV6DefaultRoutePartner(event) {
+ if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) && !isV6DefaultRoutePartner(event) && event.Category != proto.SystemEvent_APPROVAL {
title := e.getEventTitle(event)
body := event.UserMessage
id := event.Metadata["id"]

@@ -39,6 +39,8 @@ func (h *eventHandler) listen(ctx context.Context) {
h.handleDisconnectClick()
case <-h.client.mAllowSSH.ClickedCh:
h.handleAllowSSHClick()
case <-h.client.mAllowVNC.ClickedCh:
h.handleAllowVNCClick()
case <-h.client.mAutoConnect.ClickedCh:
h.handleAutoConnectClick()
case <-h.client.mEnableRosenpass.ClickedCh:
@@ -134,6 +136,15 @@ func (h *eventHandler) handleAllowSSHClick() {
}
func (h *eventHandler) handleAllowVNCClick() {
h.toggleCheckbox(h.client.mAllowVNC)
if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mAllowVNC) // revert checkbox state on error
log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update VNC settings")
}
}
func (h *eventHandler) handleAutoConnectClick() {
h.toggleCheckbox(h.client.mAutoConnect)
if err := h.updateConfigWithErr(); err != nil {

@@ -193,7 +193,15 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network {
}
func isDefaultRoute(routeRange string) bool {
- return routeRange == "0.0.0.0/0" || routeRange == "::/0"
+ // routeRange is the merged display string from the daemon, e.g. "0.0.0.0/0",
+ // "::/0", or "0.0.0.0/0, ::/0" when a v4 exit node has a paired v6 entry.
+ for _, part := range strings.Split(routeRange, ",") {
+ switch strings.TrimSpace(part) {
+ case "0.0.0.0/0", "::/0":
+ return true
+ }
+ }
+ return false
}
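The comma-split check can be tried on a few display strings. This sketch copies the new `isDefaultRoute` body; the non-default ranges in the list are made-up sample inputs.

```go
package main

import (
	"fmt"
	"strings"
)

// isDefaultRoute mirrors the updated function: it reports true when any
// comma-separated part of routeRange is an IPv4 or IPv6 default route.
func isDefaultRoute(routeRange string) bool {
	for _, part := range strings.Split(routeRange, ",") {
		switch strings.TrimSpace(part) {
		case "0.0.0.0/0", "::/0":
			return true
		}
	}
	return false
}

func main() {
	inputs := []string{
		"0.0.0.0/0",            // true
		"::/0",                 // true
		"0.0.0.0/0, ::/0",      // true: merged v4+v6 exit node
		"10.0.0.0/8",           // false
		"10.0.0.0/8, fd00::/8", // false
	}
	for _, in := range inputs {
		fmt.Println(in, "=>", isDefaultRoute(in))
	}
}
```

The previous exact-match comparison would have returned false for the merged `"0.0.0.0/0, ::/0"` string, which is the case the change addresses.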
func getExitNodeNetworks(routes []*proto.Network) []*proto.Network {

client/vnc/ports.go Normal file

@@ -0,0 +1,31 @@
// Package vnc holds shared constants for the NetBird embedded VNC stack
// so non-server consumers (CLI capture, debug tooling) can refer to the
// well-known ports without depending on internal engine packages.
package vnc
// External and internal listen ports for the embedded VNC server.
// ExternalPort is what dashboard / browser clients see; the daemon
// DNATs it to InternalPort, where the in-process VNC server actually
// listens. Both flow over the WireGuard interface. AgentLegacyPort is
// the TCP port the per-session agent used before it switched to Unix
// sockets; kept here so packet captures from older builds still get
// tagged, and so any future on-wire agent variant has a reserved port.
const (
ExternalPort uint16 = 5900
InternalPort uint16 = 25900
AgentLegacyPort uint16 = 15900
)
// WellKnownPorts is the unordered set of ports a packet capture should
// treat as carrying NetBird VNC traffic.
var WellKnownPorts = [...]uint16{ExternalPort, InternalPort, AgentLegacyPort}
// IsWellKnownPort reports whether port matches any of WellKnownPorts.
func IsWellKnownPort(port uint16) bool {
for _, p := range WellKnownPorts {
if port == p {
return true
}
}
return false
}
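Review note: a minimal sketch of how a capture tool might consume these constants. The `flow` type and `isVNCFlow` helper are hypothetical and not part of this PR; the constants and `IsWellKnownPort` are reproduced from the file above so the example runs standalone.

```go
package main

import "fmt"

// Constants and helper reproduced from client/vnc/ports.go.
const (
	ExternalPort    uint16 = 5900
	InternalPort    uint16 = 25900
	AgentLegacyPort uint16 = 15900
)

var WellKnownPorts = [...]uint16{ExternalPort, InternalPort, AgentLegacyPort}

func IsWellKnownPort(port uint16) bool {
	for _, p := range WellKnownPorts {
		if port == p {
			return true
		}
	}
	return false
}

// flow is a hypothetical capture record, assumed here for illustration.
type flow struct {
	SrcPort, DstPort uint16
}

// isVNCFlow tags a flow when either endpoint uses a well-known VNC port.
func isVNCFlow(f flow) bool {
	return IsWellKnownPort(f.SrcPort) || IsWellKnownPort(f.DstPort)
}

func main() {
	flows := []flow{{40000, 5900}, {25900, 40001}, {40002, 443}}
	for _, f := range flows {
		fmt.Printf("%d -> %d vnc=%v\n", f.SrcPort, f.DstPort, isVNCFlow(f))
	}
}
```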


@@ -0,0 +1,415 @@
//go:build darwin && !ios
package server
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"os"
"os/exec"
"strconv"
"sync"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// darwinAgentManager spawns a per-user VNC agent on demand and keeps it
// alive across multiple client connections within the same console-user
// session. A new agent is spawned the first time a client connects, or
// whenever the console user changes underneath us.
//
// Lifecycle is lazy by design: a daemon that never receives a VNC
// connection never spawns anything. The trade-off versus an eager spawn
// (the Windows model) is that the first VNC client pays the launchctl
// asuser + listen-readiness wait, ~hundreds of milliseconds in practice.
// That cost only repeats on user switch.
type darwinAgentManager struct {
mu sync.Mutex
authToken string
socketPath string
uid uint32
running bool
}
func newDarwinAgentManager(ctx context.Context) *darwinAgentManager {
m := &darwinAgentManager{}
go m.watchConsoleUser(ctx)
return m
}
// agentSocketName is the file name inside the per-uid socket directory
// the agent binds. The directory itself is created and chowned by the
// daemon (see prepareAgentSocketDir) so a non-root local user cannot
// pre-create or symlink the path before the agent listens.
const agentSocketName = "agent.sock"
// watchConsoleUser kills the cached agent whenever the console user
// changes (logout, fast user switch, login window). Without it the daemon
// keeps proxying to an agent whose TCC grant and WindowServer access
// belong to a user who is no longer at the screen, so the new user only
// ever sees the locked-screen wallpaper. Killing the agent breaks the
// Unix-socket connection that the daemon proxies into, the client
// disconnects, and the next reconnect runs Resolve() against the new
// console uid.
func (m *darwinAgentManager) watchConsoleUser(ctx context.Context) {
t := time.NewTicker(2 * time.Second)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
uid, err := consoleUserID()
m.mu.Lock()
if !m.running {
m.mu.Unlock()
continue
}
if err != nil || uid != m.uid {
prev := m.uid
m.killLocked()
m.mu.Unlock()
if err != nil {
log.Infof("console user gone (was uid=%d): %v; agent stopped", prev, err)
} else {
log.Infof("console user changed %d -> %d; agent stopped, will respawn on next connect", prev, uid)
}
continue
}
m.mu.Unlock()
}
}
}
// Resolve spawns or respawns the per-user agent process as needed and
// returns its Unix-socket path, shared token, and the uid the agent was
// spawned under (so the daemon can validate peer credentials before
// dispatching the token). Each call is serialized so concurrent VNC
// clients share the same agent.
func (m *darwinAgentManager) Resolve(ctx context.Context) (string, string, uint32, error) {
consoleUID, err := consoleUserID()
if err != nil {
return "", "", 0, fmt.Errorf("no console user: %w", err)
}
m.mu.Lock()
defer m.mu.Unlock()
if m.running && m.uid == consoleUID && vncAgentRunning() {
return m.socketPath, m.authToken, m.uid, nil
}
m.killLocked()
// Reap stray agents so the new token is the only accepted one.
killAllVNCAgents()
socketDir, err := prepareAgentSocketDir(consoleUID)
if err != nil {
return "", "", 0, fmt.Errorf("prepare agent socket dir: %w", err)
}
socketPath := socketDir + "/" + agentSocketName
if err := os.Remove(socketPath); err != nil && !errors.Is(err, os.ErrNotExist) {
log.Debugf("clear stale agent socket %s: %v", socketPath, err)
}
token, err := generateAuthToken()
if err != nil {
return "", "", 0, fmt.Errorf("generate agent auth token: %w", err)
}
if err := spawnAgentForUser(consoleUID, socketPath, token); err != nil {
return "", "", 0, err
}
if err := waitForAgent(ctx, socketPath, 5*time.Second); err != nil {
killAllVNCAgents()
return "", "", 0, fmt.Errorf("agent did not start listening: %w", err)
}
m.authToken = token
m.socketPath = socketPath
m.uid = consoleUID
m.running = true
log.Infof("spawned VNC agent for console uid=%d on %s", consoleUID, socketPath)
return socketPath, token, consoleUID, nil
}
// agentSocketParentDir is the root the daemon creates (as root, mode 0755)
// to hold per-uid agent-socket subdirectories. Keeping it under
// /var/run (rather than /tmp) means a non-root local user
// cannot squat the socket path: only root can create the parent, and
// only the target user (plus root) can write inside the per-uid subdir.
const agentSocketParentDir = "/var/run/netbird-vnc"
// prepareAgentSocketDir creates (and tightens permissions on) a per-uid
// subdirectory the agent will bind its socket inside, returning the
// directory path. The subdirectory is owned by uid with mode 0700, so
// the only writers are the target user and root. The parent is created
// root-owned with mode 0755 if it doesn't already exist. Symlinks at
// the per-uid level are refused (replaced with a fresh directory) to
// avoid a low-priv user redirecting our chown.
func prepareAgentSocketDir(uid uint32) (string, error) {
if err := os.MkdirAll(agentSocketParentDir, 0o755); err != nil {
return "", fmt.Errorf("mkdir %s: %w", agentSocketParentDir, err)
}
// Refuse to use the parent if it's a symlink or not owned by root.
pInfo, err := os.Lstat(agentSocketParentDir)
if err != nil {
return "", fmt.Errorf("lstat %s: %w", agentSocketParentDir, err)
}
if pInfo.Mode()&os.ModeSymlink != 0 {
return "", fmt.Errorf("%s is a symlink", agentSocketParentDir)
}
if st, ok := pInfo.Sys().(*syscall.Stat_t); ok && st.Uid != 0 {
return "", fmt.Errorf("%s not owned by root (uid=%d)", agentSocketParentDir, st.Uid)
}
subdir := fmt.Sprintf("%s/%d", agentSocketParentDir, uid)
// If a leftover entry exists, refuse it unless it's a real dir owned
// by the right uid with strict perms: otherwise remove and recreate
// from scratch under our control. Using os.Lstat (not Stat) so a
// symlink is detected and torn down.
if info, err := os.Lstat(subdir); err == nil {
bad := false
if info.Mode()&os.ModeSymlink != 0 {
bad = true
} else if !info.IsDir() {
bad = true
} else if st, ok := info.Sys().(*syscall.Stat_t); !ok || st.Uid != uid || info.Mode().Perm() != 0o700 {
bad = true
}
if bad {
if err := os.RemoveAll(subdir); err != nil {
return "", fmt.Errorf("remove stale %s: %w", subdir, err)
}
}
}
if err := os.Mkdir(subdir, 0o700); err != nil && !errors.Is(err, os.ErrExist) {
return "", fmt.Errorf("mkdir %s: %w", subdir, err)
}
if err := os.Chmod(subdir, 0o700); err != nil {
return "", fmt.Errorf("chmod %s: %w", subdir, err)
}
if err := os.Chown(subdir, int(uid), -1); err != nil {
return "", fmt.Errorf("chown %s -> uid %d: %w", subdir, uid, err)
}
return subdir, nil
}
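Review note: the leftover-entry check above depends on `os.Lstat` not following symlinks. The sketch below isolates that check as a hypothetical helper, `entryIsTrusted`, and runs it against a real directory and a symlink pointing at it. It is Unix-only because of `syscall.Stat_t`, uses a temporary directory instead of `/var/run/netbird-vnc`, and is an illustration of the idea rather than the code in this PR.

```go
package main

import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"syscall"
)

// entryIsTrusted (hypothetical helper) reports whether path is a real
// directory, not a symlink, owned by uid with mode exactly 0700.
// A missing entry is reported as untrusted with a nil error.
func entryIsTrusted(path string, uid uint32) (bool, error) {
	info, err := os.Lstat(path) // Lstat inspects the link itself
	if errors.Is(err, os.ErrNotExist) {
		return false, nil
	}
	if err != nil {
		return false, err
	}
	if info.Mode()&os.ModeSymlink != 0 || !info.IsDir() {
		return false, nil
	}
	st, ok := info.Sys().(*syscall.Stat_t)
	if !ok || st.Uid != uid {
		return false, nil
	}
	return info.Mode().Perm() == 0o700, nil
}

// demo creates a 0700 directory plus a symlink to it and checks both.
// It returns (real directory trusted, symlink trusted).
func demo() (bool, bool, error) {
	root, err := os.MkdirTemp("", "sockdir-demo")
	if err != nil {
		return false, false, err
	}
	defer os.RemoveAll(root)

	realDir := filepath.Join(root, "real")
	if err := os.Mkdir(realDir, 0o700); err != nil {
		return false, false, err
	}
	if err := os.Chmod(realDir, 0o700); err != nil {
		return false, false, err
	}
	link := filepath.Join(root, "link")
	if err := os.Symlink(realDir, link); err != nil {
		return false, false, err
	}

	uid := uint32(os.Getuid())
	realOK, err := entryIsTrusted(realDir, uid)
	if err != nil {
		return false, false, err
	}
	linkOK, err := entryIsTrusted(link, uid)
	if err != nil {
		return false, false, err
	}
	return realOK, linkOK, nil
}

func main() {
	realOK, linkOK, err := demo()
	fmt.Println("real dir trusted:", realOK, "symlink trusted:", linkOK, "err:", err)
}
```

With `os.Stat` in place of `os.Lstat`, the symlink would be followed and would look like the directory it points to, which is the case the comment in `prepareAgentSocketDir` guards against.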
// stop terminates the spawned agent, if any. Intended for daemon shutdown.
func (m *darwinAgentManager) stop() {
m.mu.Lock()
defer m.mu.Unlock()
m.killLocked()
}
func (m *darwinAgentManager) killLocked() {
if !m.running {
return
}
killAllVNCAgents()
if m.socketPath != "" {
if err := os.Remove(m.socketPath); err != nil && !errors.Is(err, os.ErrNotExist) {
log.Debugf("remove agent socket %s: %v", m.socketPath, err)
}
}
m.running = false
m.authToken = ""
m.socketPath = ""
m.uid = 0
}
// consoleUserID returns the uid of the user currently sitting at the
// console (the one whose Aqua session is active). Returns
// errNoConsoleUser when nobody is logged in: at the login window
// /dev/console is owned by root.
func consoleUserID() (uint32, error) {
info, err := os.Stat("/dev/console")
if err != nil {
return 0, fmt.Errorf("stat /dev/console: %w", err)
}
st, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return 0, fmt.Errorf("/dev/console stat has unexpected type")
}
if st.Uid == 0 {
return 0, errNoConsoleUser
}
return st.Uid, nil
}
// spawnAgentForUser uses launchctl asuser to start a netbird vnc-agent
// process inside the target user's launchd bootstrap namespace. That is
// the only spawn mode on macOS that gives the child access to the user's
// WindowServer. The agent's stderr is relogged into the daemon log so
// startup failures are not silently lost when the readiness check times
// out.
func spawnAgentForUser(uid uint32, socketPath, token string) error {
exe, err := os.Executable()
if err != nil {
return fmt.Errorf("resolve own executable: %w", err)
}
cmd := exec.Command(
"/bin/launchctl", "asuser", strconv.FormatUint(uint64(uid), 10),
exe, vncAgentSubcommand,
"--socket", socketPath,
// Drop privs inside the agent: launchctl asuser preserves the
// daemon's uid (root), so without this the capture/input/
// encoder paths would run as root for the lifetime of the
// session. validateAgentPeer on the daemon side also relies on
// the agent's effective uid matching consoleUID.
"--target-uid", strconv.FormatUint(uint64(uid), 10),
)
cmd.Env = append(os.Environ(), agentTokenEnvVar+"="+token)
stderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("agent stderr pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("launchctl asuser: %w", err)
}
go func() {
defer stderr.Close()
relogAgentStream(stderr)
}()
go func() { _ = cmd.Wait() }()
return nil
}
// waitForAgent dials the agent's Unix socket until it answers. Used to
// gate proxy attempts until the spawned process has finished its Start.
func waitForAgent(ctx context.Context, socketPath string, wait time.Duration) error {
var d net.Dialer
deadline := time.Now().Add(wait)
for time.Now().Before(deadline) {
if ctx.Err() != nil {
return ctx.Err()
}
dialCtx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
c, err := d.DialContext(dialCtx, "unix", socketPath)
cancel()
if err == nil {
_ = c.Close()
return nil
}
time.Sleep(100 * time.Millisecond)
}
return fmt.Errorf("timeout dialing %s", socketPath)
}
// vncAgentRunning reports whether any vnc-agent process exists on the
// system. There is at most one agent per machine, so any match is "the"
// agent.
func vncAgentRunning() bool {
pids, err := vncAgentPIDs()
if err != nil {
log.Debugf("scan for vnc-agent: %v", err)
return false
}
return len(pids) > 0
}
// killAllVNCAgents sends SIGTERM to every vnc-agent process spawned from
// this binary (as matched by vncAgentPIDs), waits up to 2 s for them to
// exit, and escalates to SIGKILL for any that remain. We enumerate
// kern.proc.all rather than kern.proc.uid because launchctl asuser
// preserves the caller's uid (root) on the spawned child, so a uid-scoped
// filter would never match.
func killAllVNCAgents() {
pids, err := vncAgentPIDs()
if err != nil {
log.Debugf("scan for vnc-agent: %v", err)
return
}
for _, pid := range pids {
_ = syscall.Kill(pid, syscall.SIGTERM)
}
if len(pids) == 0 {
return
}
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
remaining, _ := vncAgentPIDs()
if len(remaining) == 0 {
return
}
time.Sleep(100 * time.Millisecond)
}
leftover, _ := vncAgentPIDs()
for _, pid := range leftover {
_ = syscall.Kill(pid, syscall.SIGKILL)
}
}
// vncAgentPIDs returns the pids of vnc-agent subprocesses spawned from
// this binary. Matches exactly on argv[0] == our own executable path
// AND argv[1] == "vnc-agent" so unrelated processes that happen to have
// the same name elsewhere in argv are not targeted. Skips pid 0 and 1
// defensively.
func vncAgentPIDs() ([]int, error) {
procs, err := unix.SysctlKinfoProcSlice("kern.proc.all")
if err != nil {
return nil, fmt.Errorf("sysctl kern.proc.all: %w", err)
}
ownExe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("resolve own executable: %w", err)
}
var out []int
for i := range procs {
pid := int(procs[i].Proc.P_pid)
if pid <= 1 {
continue
}
argv, err := procArgv(pid)
if err != nil || !argvIsVNCAgent(argv, ownExe) {
continue
}
out = append(out, pid)
}
return out, nil
}
// procArgv reads the kernel's stored argv for pid via the kern.procargs2
// sysctl. Format: 4-byte argc (decoded little-endian), the executable
// path (NUL-terminated), zero padding, then argv[0..argc) each
// NUL-terminated, then envp. We only need argv so we stop after argc
// entries.
func procArgv(pid int) ([]string, error) {
raw, err := unix.SysctlRaw("kern.procargs2", pid)
if err != nil {
return nil, err
}
if len(raw) < 4 {
return nil, fmt.Errorf("procargs2 truncated")
}
argc := int(raw[0]) | int(raw[1])<<8 | int(raw[2])<<16 | int(raw[3])<<24
body := raw[4:]
// Skip the executable path (NUL-terminated) and any zero padding that
// follows before argv[0].
end := bytes.IndexByte(body, 0)
if end < 0 {
return nil, fmt.Errorf("procargs2 path unterminated")
}
body = body[end+1:]
for len(body) > 0 && body[0] == 0 {
body = body[1:]
}
args := make([]string, 0, argc)
for i := 0; i < argc; i++ {
end := bytes.IndexByte(body, 0)
if end < 0 {
break
}
args = append(args, string(body[:end]))
body = body[end+1:]
}
return args, nil
}
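Review note: the buffer walk in `procArgv` can be checked without a Darwin host by feeding the same logic a synthetic buffer. In the sketch below, `parseProcArgs2` is a hypothetical pure-function variant of the parsing half of `procArgv`, and `buildSample` fabricates a buffer in the layout the code assumes (argc, executable path, padding, argv, envp). The sample path and environment entry are made up.

```go
package main

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
)

// parseProcArgs2 (hypothetical) extracts argv from a kern.procargs2-style
// buffer: 4-byte argc, executable path, NUL padding, then the arguments.
func parseProcArgs2(raw []byte) ([]string, error) {
	if len(raw) < 4 {
		return nil, errors.New("procargs2 truncated")
	}
	argc := int(binary.LittleEndian.Uint32(raw[:4]))
	body := raw[4:]

	// Skip the executable path and the zero padding that follows it.
	end := bytes.IndexByte(body, 0)
	if end < 0 {
		return nil, errors.New("procargs2 path unterminated")
	}
	body = bytes.TrimLeft(body[end+1:], "\x00")

	var args []string
	for i := 0; i < argc; i++ {
		end := bytes.IndexByte(body, 0)
		if end < 0 {
			break // buffer ended before argc entries were read
		}
		args = append(args, string(body[:end]))
		body = body[end+1:]
	}
	return args, nil
}

// buildSample fabricates a buffer with argc=2 followed by one env entry.
func buildSample() []byte {
	var buf bytes.Buffer
	_ = binary.Write(&buf, binary.LittleEndian, uint32(2))
	buf.WriteString("/usr/local/bin/netbird\x00")
	buf.Write([]byte{0, 0, 0}) // alignment padding
	buf.WriteString("/usr/local/bin/netbird\x00vnc-agent\x00")
	buf.WriteString("HOME=/Users/demo\x00")
	return buf.Bytes()
}

func main() {
	args, err := parseProcArgs2(buildSample())
	fmt.Printf("argv=%q err=%v\n", args, err)
}
```

The environment entry after argv is ignored because the loop stops after argc strings, which matches the comment on `procArgv`.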
// argvIsVNCAgent reports whether argv belongs to a vnc-agent subprocess
// spawned from our binary. Requires argv[0] to match ownExe exactly and
// argv[1] to be the vnc-agent subcommand. Matches the spawn shape in
// spawnAgentForUser and rejects anything else.
func argvIsVNCAgent(argv []string, ownExe string) bool {
if len(argv) < 2 || ownExe == "" {
return false
}
return argv[0] == ownExe && argv[1] == vncAgentSubcommand
}


@@ -0,0 +1,305 @@
//go:build darwin || windows
package server
import (
"bufio"
"bytes"
"context"
crand "crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"time"
log "github.com/sirupsen/logrus"
)
// errNoConsoleUser is the sentinel returned by sessionAgent.Resolve when
// the platform has no interactive user to attach a capture agent to (the
// macOS loginwindow state). Mapped to a distinct RFB reject code so the
// browser can show a meaningful message.
var errNoConsoleUser = errors.New("no user logged into console")
// sessionAgent abstracts the per-platform manager that spawns and tracks
// the user-session VNC agent. Resolve returns the agent's Unix-socket
// path, the shared per-spawn token, and the uid the agent was spawned
// under (used to validate peer credentials before the daemon hands the
// token to whoever is on the other end of the socket). Resolve may spawn
// the agent lazily.
type sessionAgent interface {
Resolve(ctx context.Context) (socketPath, token string, peerUID uint32, err error)
}
// prefixConn replays already-consumed header bytes ahead of the proxy
// stream by swapping in a different Reader on the same underlying Conn.
type prefixConn struct {
io.Reader
net.Conn
}
func (p *prefixConn) Read(b []byte) (int, error) { return p.Reader.Read(b) }
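Review note: the replay trick that `prefixConn` supports can be shown with plain readers. The sketch below is an illustration under simplifying assumptions: a `strings.Reader` stands in for the network connection, and the header is a fixed number of bytes. `replayDemo` is a hypothetical helper, not a function from this PR.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

// replayDemo reads headerLen bytes through a TeeReader, which records them,
// then rebuilds the full stream by chaining the recorded bytes in front of
// whatever is still unread on the source.
func replayDemo(input string, headerLen int) (string, string, error) {
	src := strings.NewReader(input) // stand-in for the client connection

	var consumed bytes.Buffer
	tee := io.TeeReader(src, &consumed)

	header := make([]byte, headerLen)
	if _, err := io.ReadFull(tee, header); err != nil {
		return "", "", fmt.Errorf("read header: %w", err)
	}

	// The downstream consumer sees the header again, then the remainder.
	full, err := io.ReadAll(io.MultiReader(&consumed, src))
	if err != nil {
		return "", "", fmt.Errorf("read replayed stream: %w", err)
	}
	return string(header), string(full), nil
}

func main() {
	header, full, err := replayDemo("HDR!payload", 4)
	fmt.Printf("header=%q full=%q err=%v\n", header, full, err)
}
```

In `handleServiceConnection` the same pairing appears as `io.TeeReader(conn, &headerBuf)` for the handshake and `io.MultiReader(&headerBuf, conn)` for the stream handed to the agent.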
// handleServiceConnection runs the connection-header handshake (source
// check, Noise_IK auth) on conn, resolves the right per-session agent
// via sa, and proxies to it. Every accepted connection emits exactly one
// outcome line on the daemon log.
func (s *Server) handleServiceConnection(conn net.Conn, sa sessionAgent) {
start := time.Now()
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
if !s.isAllowedSource(conn.RemoteAddr()) {
connLog.Info("VNC connection rejected: source not allowed")
_ = conn.Close()
return
}
var headerBuf bytes.Buffer
tee := io.TeeReader(conn, &headerBuf)
teeConn := &prefixConn{Reader: tee, Conn: conn}
header, err := s.readConnectionHeader(teeConn)
if err != nil {
connLog.Infof("VNC connection rejected: header read failed: %v", err)
_ = conn.Close()
return
}
authedLog, sessionUserID, ok := s.authorizeSession(conn, header, connLog)
if !ok {
authedLog.Info("VNC connection rejected: auth failed")
return
}
if err := s.registerConnAuth(conn, header); err != nil {
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
authedLog.Warnf("VNC connection rejected: %v", err)
return
}
decision, err := s.gateApproval(conn, header)
if err != nil {
authedLog.Infof("VNC connection rejected: %v", err)
return
}
if decision.ViewOnly {
authedLog.Info("VNC connection approved by user (view-only)")
} else if s.requireApproval {
authedLog.Info("VNC connection approved by user")
}
socketPath, token, peerUID, err := sa.Resolve(s.ctx)
if err != nil {
code := RejectCodeCapturerError
if errors.Is(err, errNoConsoleUser) {
code = RejectCodeNoConsoleUser
}
rejectConnection(conn, codeMessage(code, err.Error()))
authedLog.Warnf("VNC connection rejected: agent unavailable: %v", err)
return
}
var initiator string
if s.authorizer != nil {
initiator = s.authorizer.LookupSessionDisplayName(header.clientStatic)
}
sessionID := s.addSession(ActiveSessionInfo{
RemoteAddress: conn.RemoteAddr().String(),
Mode: modeString(header.mode),
Username: header.username,
UserID: sessionUserID,
Initiator: initiator,
}, conn)
defer s.removeSession(sessionID)
replayConn := &prefixConn{
Reader: io.MultiReader(&headerBuf, conn),
Conn: conn,
}
if err := proxyToAgent(s.ctx, replayConn, socketPath, token, peerUID, decision.ViewOnly, authedLog); err != nil {
rejectConnection(conn, codeMessage(RejectCodeCapturerError, err.Error()))
authedLog.Warnf("VNC connection rejected: agent unreachable: %v", err)
return
}
authedLog.Infof("VNC connection closed (%dms)", time.Since(start).Milliseconds())
}
const (
// agentTokenLen is the size of the random per-spawn token in bytes.
agentTokenLen = 32
// agentTokenEnvVar names the environment variable the daemon uses to
// hand the per-spawn token to the agent child. Out-of-band channels
// like this keep the secret out of the command line, where listings
// such as `ps` or Windows tasklist would expose it.
agentTokenEnvVar = "NB_VNC_AGENT_TOKEN" // #nosec G101 -- env var name, not a credential
// vncAgentSubcommand is the CLI subcommand the daemon invokes to start
// the per-session agent process. Must match cmd.vncAgentCmd.Use in
// client/cmd/vnc_agent.go.
vncAgentSubcommand = "vnc-agent"
)
// generateAuthToken returns a fresh hex-encoded random token for one
// daemon→agent session. The daemon hands this to the spawned agent
// out-of-band (via the agentTokenEnvVar environment variable) and the
// agent verifies it on every connection it accepts.
func generateAuthToken() (string, error) {
b := make([]byte, agentTokenLen)
if _, err := crand.Read(b); err != nil {
return "", fmt.Errorf("read random: %w", err)
}
return hex.EncodeToString(b), nil
}
// proxyToAgent dials the per-session agent's Unix socket, validates the
// peer's kernel-asserted uid (so the daemon never hands its per-spawn
// token to an impostor that won the listen race), writes the raw token
// bytes plus a single view-only flag byte, then copies bytes both ways
// until either side closes. The token + flag prefix must precede any RFB
// byte so the agent's verifyAgentToken can run first. Returns nil once a
// stream is established; the caller is responsible for sending an
// RFB-level rejection on error so the client sees a reason instead of a
// bare timeout. authedLog receives one audit line per dispatched
// preamble so an operator can correlate daemon→agent traffic with the
// remote session that triggered it.
func proxyToAgent(ctx context.Context, client net.Conn, socketPath, authToken string, peerUID uint32, viewOnly bool, authedLog *log.Entry) error {
tokenBytes, err := hex.DecodeString(authToken)
if err != nil {
return fmt.Errorf("invalid auth token: %w", err)
}
if len(tokenBytes) != agentTokenLen {
return fmt.Errorf("invalid auth token: len=%d, want %d", len(tokenBytes), agentTokenLen)
}
agentConn, err := dialAgentWithRetry(ctx, socketPath)
if err != nil {
return fmt.Errorf("dial agent at %s: %w", socketPath, err)
}
if err := validateAgentPeer(agentConn, peerUID); err != nil {
_ = agentConn.Close()
return fmt.Errorf("agent peer validation failed: %w", err)
}
preamble := make([]byte, len(tokenBytes)+1)
copy(preamble, tokenBytes)
if viewOnly {
preamble[len(tokenBytes)] = 1
}
if _, err := agentConn.Write(preamble); err != nil {
_ = agentConn.Close()
return fmt.Errorf("send auth preamble to agent: %w", err)
}
// Audit: one line per successfully-dispatched daemon→agent preamble.
// Token printed as its first 8 hex chars (enough to correlate, not
// enough to use). Kept at Info so the default deployment captures it.
tokenFp := authToken
if len(tokenFp) > 8 {
tokenFp = tokenFp[:8]
}
if authedLog != nil {
authedLog.Infof("VNC IPC: dispatched preamble to agent socket=%s peer_uid=%d view_only=%v token_fp=%s", socketPath, peerUID, viewOnly, tokenFp)
}
defer client.Close()
defer agentConn.Close()
log.Debugf("proxy connected to agent, starting bidirectional copy")
done := make(chan struct{}, 2)
cp := func(label string, dst, src net.Conn) {
n, err := io.Copy(dst, src)
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
done <- struct{}{}
}
go cp("client->agent", agentConn, client)
go cp("agent->client", client, agentConn)
<-done
return nil
}
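Review note: the agent side of this preamble (`verifyAgentToken`) is not part of the hunks shown here. The sketch below is a guess at what the receiving end could look like, assuming the wire format described above: 32 raw token bytes followed by one view-only flag byte. `readPreamble`, `buildPreamble`, `roundTrip` and `demoToken` are hypothetical names, and the constant-time comparison is this sketch's choice, not something the diff confirms.

```go
package main

import (
	"bytes"
	"crypto/subtle"
	"errors"
	"fmt"
	"io"
)

const agentTokenLen = 32 // matches the daemon-side constant

// buildPreamble mirrors the daemon side: token bytes plus one flag byte.
func buildPreamble(token []byte, viewOnly bool) []byte {
	p := make([]byte, len(token)+1)
	copy(p, token)
	if viewOnly {
		p[len(token)] = 1
	}
	return p
}

// readPreamble (hypothetical agent side) reads exactly token+flag bytes
// before any RFB byte and compares the token in constant time.
func readPreamble(r io.Reader, want []byte) (bool, error) {
	buf := make([]byte, agentTokenLen+1)
	if _, err := io.ReadFull(r, buf); err != nil {
		return false, fmt.Errorf("read preamble: %w", err)
	}
	if subtle.ConstantTimeCompare(buf[:agentTokenLen], want) != 1 {
		return false, errors.New("agent token mismatch")
	}
	return buf[agentTokenLen] == 1, nil
}

// roundTrip sends a preamble built from sent and verifies it against want.
func roundTrip(sent []byte, viewOnly bool, want []byte) (bool, error) {
	return readPreamble(bytes.NewReader(buildPreamble(sent, viewOnly)), want)
}

// demoToken returns a fixed token for the example; real tokens are random.
func demoToken() []byte {
	return bytes.Repeat([]byte{0xAB}, agentTokenLen)
}

func main() {
	tok := demoToken()
	viewOnly, err := roundTrip(tok, true, tok)
	fmt.Println("view-only:", viewOnly, "err:", err)
	_, err = roundTrip(make([]byte, agentTokenLen), false, tok)
	fmt.Println("wrong token err:", err)
}
```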
// relogAgentStream reads log lines from the agent's stderr and re-emits
// them through the daemon's logrus, so the merged log keeps a single
// format. JSON lines (the agent's normal output) are parsed and dispatched
// by level; plain-text lines (cobra errors, panic traces) are forwarded
// verbatim so early-startup failures stay visible.
func relogAgentStream(r io.Reader) {
entry := log.WithField("component", "vnc-agent")
scanner := bufio.NewScanner(r)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Bytes()
if len(line) == 0 {
continue
}
if line[0] != '{' {
entry.Warn(string(line))
continue
}
var m map[string]any
if err := json.Unmarshal(line, &m); err != nil {
entry.Warn(string(line))
continue
}
msg, _ := m["msg"].(string)
if msg == "" {
continue
}
fields := make(log.Fields)
for k, v := range m {
switch k {
case "msg", "level", "time", "func":
continue
case "caller":
fields["source"] = v
default:
fields[k] = v
}
}
e := entry.WithFields(fields)
switch m["level"] {
case "error":
e.Error(msg)
case "warning":
e.Warn(msg)
case "debug":
e.Debug(msg)
case "trace":
e.Trace(msg)
default:
e.Info(msg)
}
}
}
// dialAgentWithRetry retries the Unix-socket connect for up to ~10 s (50
// attempts, 200 ms apart) so the daemon does not race the agent's first
// listen. Returns the live conn or the final error. Aborts early when ctx
// is cancelled so a Stop() during service-mode startup doesn't leave a
// goroutine sleeping for 10 s.
func dialAgentWithRetry(ctx context.Context, addr string) (net.Conn, error) {
var d net.Dialer
var lastErr error
for range 50 {
if err := ctx.Err(); err != nil {
if lastErr == nil {
lastErr = err
}
return nil, lastErr
}
dialCtx, cancel := context.WithTimeout(ctx, time.Second)
c, err := d.DialContext(dialCtx, "unix", addr)
cancel()
if err == nil {
return c, nil
}
lastErr = err
select {
case <-ctx.Done():
if errors.Is(lastErr, context.Canceled) || errors.Is(lastErr, context.DeadlineExceeded) {
lastErr = ctx.Err()
}
return nil, lastErr
case <-time.After(200 * time.Millisecond):
}
}
return nil, lastErr
}
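Review note: the retry shape used by `dialAgentWithRetry` can be tested without sockets by abstracting the dial into a callback. The sketch below is a simplified, hypothetical `retry` helper with the same two exits (success, or context cancellation) and a fake attempt function. It does not reproduce the per-attempt dial timeout or the error rewriting on cancellation.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// retry (hypothetical) calls try up to attempts times, pausing between
// failures, and stops early when ctx is cancelled.
func retry(ctx context.Context, attempts int, pause time.Duration, try func(context.Context) error) error {
	var lastErr error
	for i := 0; i < attempts; i++ {
		if err := ctx.Err(); err != nil {
			if lastErr == nil {
				lastErr = err
			}
			return lastErr
		}
		if lastErr = try(ctx); lastErr == nil {
			return nil
		}
		select {
		case <-ctx.Done():
			return lastErr
		case <-time.After(pause):
		}
	}
	return lastErr
}

// demoRetry simulates an agent that starts listening on the third attempt.
func demoRetry() (int, error) {
	calls := 0
	err := retry(context.Background(), 50, time.Millisecond, func(context.Context) error {
		calls++
		if calls < 3 {
			return errors.New("agent not listening yet")
		}
		return nil
	})
	return calls, err
}

// demoCancelled checks that a cancelled context prevents any attempt.
func demoCancelled() bool {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	calls := 0
	err := retry(ctx, 50, time.Millisecond, func(context.Context) error {
		calls++
		return nil
	})
	return calls == 0 && errors.Is(err, context.Canceled)
}

func main() {
	calls, err := demoRetry()
	fmt.Println("calls:", calls, "err:", err)
	fmt.Println("cancelled before first attempt:", demoCancelled())
}
```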


@@ -0,0 +1,46 @@
//go:build darwin && !ios
package server
import (
"fmt"
"net"
"golang.org/x/sys/unix"
)
// validateAgentPeer enforces that the peer behind the just-connected Unix
// socket is the agent we expect it to be: a process running under
// expectedUID, with the right effective uid stamped by the kernel on the
// socket. Refuses (with a non-nil error) if anything else is listening on
// the path (an unrelated local process that won the listen race or
// squatted the path before us). Defends against the daemon shipping its
// per-spawn auth token to a process that isn't the spawned agent.
func validateAgentPeer(conn net.Conn, expectedUID uint32) error {
uconn, ok := conn.(*net.UnixConn)
if !ok {
return fmt.Errorf("peer cred: expected *net.UnixConn, got %T", conn)
}
raw, err := uconn.SyscallConn()
if err != nil {
return fmt.Errorf("peer cred: syscall conn: %w", err)
}
var cred *unix.Xucred
var inner error
ctlErr := raw.Control(func(fd uintptr) {
cred, inner = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
})
if ctlErr != nil {
return fmt.Errorf("peer cred: control: %w", ctlErr)
}
if inner != nil {
return fmt.Errorf("peer cred: getsockopt LOCAL_PEERCRED: %w", inner)
}
if cred == nil {
return fmt.Errorf("peer cred: nil xucred")
}
if cred.Uid != expectedUID {
return fmt.Errorf("peer cred: agent uid %d does not match expected %d", cred.Uid, expectedUID)
}
return nil
}


@@ -0,0 +1,115 @@
//go:build darwin && !ios
package server
import (
"net"
"os"
"path/filepath"
"strings"
"sync"
"testing"
)
// TestValidateAgentPeerAcceptsOwnUID confirms the happy path: a Unix
// socket whose peer is the current process must validate when the
// expected uid matches the process's own. Both sides of a unix-socket
// pair share the same kernel cred, so this exercises the real getsockopt
// LOCAL_PEERCRED path.
func TestValidateAgentPeerAcceptsOwnUID(t *testing.T) {
dir := t.TempDir()
sockPath := filepath.Join(dir, "test.sock")
ln, err := net.Listen("unix", sockPath)
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := ln.Accept()
if err == nil {
_ = c.Close()
}
}()
c, err := net.Dial("unix", sockPath)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer c.Close()
if err := validateAgentPeer(c, uint32(os.Getuid())); err != nil {
t.Fatalf("validateAgentPeer rejected own uid: %v", err)
}
wg.Wait()
}
// TestValidateAgentPeerRejectsWrongUID ensures the validator fails when
// the expected uid differs from the kernel-reported peer uid. This is
// the path that catches a hostile process that won the listen race.
func TestValidateAgentPeerRejectsWrongUID(t *testing.T) {
dir := t.TempDir()
sockPath := filepath.Join(dir, "test.sock")
ln, err := net.Listen("unix", sockPath)
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := ln.Accept()
if err == nil {
_ = c.Close()
}
}()
c, err := net.Dial("unix", sockPath)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer c.Close()
// Pick a uid the test process certainly isn't running as.
wrongUID := uint32(os.Getuid()) + 1
err = validateAgentPeer(c, wrongUID)
if err == nil {
t.Fatal("expected mismatch error, got nil")
}
if !strings.Contains(err.Error(), "does not match expected") {
t.Fatalf("error should mention uid mismatch, got: %v", err)
}
wg.Wait()
}
// TestValidateAgentPeerRejectsNonUnix protects against being handed a
// non-Unix-socket connection (the validator can't enforce anything on
// e.g. a *net.TCPConn so it must refuse rather than silently pass).
func TestValidateAgentPeerRejectsNonUnix(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp: %v", err)
}
defer ln.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := ln.Accept()
if err == nil {
_ = c.Close()
}
}()
c, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial tcp: %v", err)
}
defer c.Close()
if err := validateAgentPeer(c, 0); err == nil {
t.Fatal("expected refusal on non-unix conn, got nil")
}
wg.Wait()
}


@@ -0,0 +1,19 @@
//go:build windows
package server
import (
"net"
)
// validateAgentPeer is a best-effort no-op on Windows: AF_UNIX sockets on
// Windows do not expose SO_PEERCRED equivalents, and both the daemon and
// the spawned agent run as SYSTEM in distinct sessions. The remaining
// trust comes from the location of the socket file (under
// C:\Windows\Temp, writable only by SYSTEM/Administrators) and from the
// per-spawn auth token preamble that follows this call. Documented as a
// known gap; a future hardening pass could interrogate the connected
// socket peer's PID via process-token APIs.
func validateAgentPeer(_ net.Conn, _ uint32) error {
return nil
}


@@ -0,0 +1,628 @@
//go:build windows
package server
import (
"context"
"encoding/binary"
"errors"
"fmt"
"os"
"runtime"
"sync"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
stillActive = 259
tokenPrimary = 1
securityImpersonation = 2
tokenSessionID = 12
createUnicodeEnvironment = 0x00000400
createNoWindow = 0x08000000
createSuspended = 0x00000004
createBreakawayFromJob = 0x01000000
)
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
userenv = windows.NewLazySystemDLL("userenv.dll")
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
procCreateJobObjectW = kernel32.NewProc("CreateJobObjectW")
procSetInformationJobObject = kernel32.NewProc("SetInformationJobObject")
procAssignProcessToJobObject = kernel32.NewProc("AssignProcessToJobObject")
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
procWTSQuerySessionInformation = wtsapi32.NewProc("WTSQuerySessionInformationW")
)
// GetCurrentSessionID returns the session ID of the current process.
func GetCurrentSessionID() uint32 {
var token windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.TOKEN_QUERY, &token); err != nil {
return 0
}
defer token.Close()
var id uint32
var ret uint32
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
(*byte)(unsafe.Pointer(&id)), 4, &ret)
return id
}
func getConsoleSessionID() uint32 {
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
return uint32(r)
}
const (
wtsActive = 0
wtsConnected = 1
wtsDisconnected = 4
)
// getActiveSessionID returns the session ID of the best session to attach to.
// On a Windows Server with no console display attached, session 1 still
// reports WTSActive (login screen "owns" the console), so a naive
// first-active-wins pick lands on a session with no actual rendering.
// Preference order:
// 1. Active session with a user logged in (RDP user in session ≥2)
// 2. Active session without a user (console at login screen)
// 3. Console session ID
func getActiveSessionID() uint32 {
var sessionInfo uintptr
var count uint32
r, _, _ := procWTSEnumerateSessionsW.Call(
0, // WTS_CURRENT_SERVER_HANDLE
0, // reserved
1, // version
uintptr(unsafe.Pointer(&sessionInfo)),
uintptr(unsafe.Pointer(&count)),
)
if r == 0 || count == 0 {
return getConsoleSessionID()
}
defer func() { _, _, _ = procWTSFreeMemory.Call(sessionInfo) }()
type wtsSession struct {
SessionID uint32
Station *uint16
State uint32
}
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
var withUser uint32
var withUserFound bool
var anyActive uint32
var anyActiveFound bool
for _, s := range sessions {
if s.SessionID == 0 {
continue
}
if s.State != wtsActive {
continue
}
if !anyActiveFound {
anyActive = s.SessionID
anyActiveFound = true
}
if !withUserFound && wtsSessionHasUser(s.SessionID) {
withUser = s.SessionID
withUserFound = true
}
}
if withUserFound {
return withUser
}
if anyActiveFound {
return anyActive
}
return getConsoleSessionID()
}
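Review note: the preference order documented on `getActiveSessionID` can be expressed as a pure function, which makes it testable off Windows. In the sketch below, the `session` struct and `pickSession` are hypothetical stand-ins for the WTS enumeration results; `HasUser` plays the role of `wtsSessionHasUser`.

```go
package main

import "fmt"

// session is a hypothetical, platform-neutral view of one WTS session.
type session struct {
	ID      uint32
	Active  bool // State == WTSActive
	HasUser bool // non-empty user name
}

// pickSession applies the documented preference order:
//  1. first active session with a logged-in user
//  2. first active session without a user
//  3. the console session ID
//
// Session 0 (services) is always skipped.
func pickSession(sessions []session, consoleID uint32) uint32 {
	var anyActive uint32
	var anyActiveFound bool
	for _, s := range sessions {
		if s.ID == 0 || !s.Active {
			continue
		}
		if s.HasUser {
			return s.ID
		}
		if !anyActiveFound {
			anyActive, anyActiveFound = s.ID, true
		}
	}
	if anyActiveFound {
		return anyActive
	}
	return consoleID
}

func main() {
	headless := []session{
		{ID: 0, Active: true},
		{ID: 1, Active: true},                // console at the login screen
		{ID: 2, Active: true, HasUser: true}, // RDP user
	}
	fmt.Println("picked session:", pickSession(headless, 1))
}
```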
const wtsUserName = 5
// wtsSessionHasUser returns true if the session has a non-empty user name,
// i.e. someone is logged in (vs. the login/Welcome screen). The console
// session at the login screen has WTSUserName == "".
func wtsSessionHasUser(sessionID uint32) bool {
var buf uintptr
var bytesReturned uint32
r, _, _ := procWTSQuerySessionInformation.Call(
0, // WTS_CURRENT_SERVER_HANDLE
uintptr(sessionID),
uintptr(wtsUserName),
uintptr(unsafe.Pointer(&buf)),
uintptr(unsafe.Pointer(&bytesReturned)),
)
if r == 0 || buf == 0 {
return false
}
defer func() { _, _, _ = procWTSFreeMemory.Call(buf) }()
// First UTF-16 code unit non-zero ⇒ non-empty username.
return *(*uint16)(unsafe.Pointer(buf)) != 0
}
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
// session ID so the spawned process runs in the target session. Using a SYSTEM
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
var cur windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.MAXIMUM_ALLOWED, &cur); err != nil {
return 0, fmt.Errorf("OpenProcessToken: %w", err)
}
defer cur.Close()
var dup windows.Token
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
securityImpersonation, tokenPrimary, &dup); err != nil {
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
}
sid := sessionID
r, _, err := procSetTokenInformation.Call(
uintptr(dup),
uintptr(tokenSessionID),
uintptr(unsafe.Pointer(&sid)),
unsafe.Sizeof(sid),
)
if r == 0 {
dup.Close()
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
}
return dup, nil
}
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
// The block is a sequence of null-terminated UTF-16 strings, terminated by
// an extra null. Returns the new []uint16 backing slice; the caller must
// hold the returned slice alive until CreateProcessAsUser completes.
func injectEnvVar(envBlock uintptr, key, value string) []uint16 {
entry := key + "=" + value
// Walk the existing block to find its total length.
ptr := (*uint16)(unsafe.Pointer(envBlock))
var totalChars int
for {
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
if ch == 0 {
// Check for double-null terminator.
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
totalChars++
if next == 0 {
// End of block (don't count the final null yet, we'll rebuild).
break
}
} else {
totalChars++
}
}
entryUTF16, _ := windows.UTF16FromString(entry)
// New block: existing entries + new entry (null-terminated) + final null.
newLen := totalChars + len(entryUTF16) + 1
newBlock := make([]uint16, newLen)
// Copy existing entries (up to but not including the final null).
for i := range totalChars {
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
}
copy(newBlock[totalChars:], entryUTF16)
newBlock[newLen-1] = 0 // final null terminator
return newBlock
}
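// The block layout described above (NUL-terminated UTF-16 strings followed
// by one extra NUL) can be shown on plain slices. This is a minimal sketch
// under that assumption; buildEnvBlock, parseEnvBlock, and appendEnvEntry
// are hypothetical helpers and EXAMPLE_TOKEN is a made-up variable name.

```go
package main

import (
	"fmt"
	"unicode/utf16"
)

// buildEnvBlock encodes each entry as a NUL-terminated UTF-16 string and
// appends the extra NUL that terminates the whole block.
func buildEnvBlock(entries []string) []uint16 {
	var block []uint16
	for _, e := range entries {
		block = append(block, utf16.Encode([]rune(e))...)
		block = append(block, 0)
	}
	return append(block, 0)
}

// parseEnvBlock splits a block back into entries, stopping at the first
// empty string, which marks the end of the block.
func parseEnvBlock(block []uint16) []string {
	var entries []string
	start := 0
	for i, ch := range block {
		if ch != 0 {
			continue
		}
		if i == start {
			break
		}
		entries = append(entries, string(utf16.Decode(block[start:i])))
		start = i + 1
	}
	return entries
}

// appendEnvEntry returns a new block with key=value added as the last entry.
func appendEnvEntry(block []uint16, key, value string) []uint16 {
	return buildEnvBlock(append(parseEnvBlock(block), key+"="+value))
}

func main() {
	block := buildEnvBlock([]string{"A=1", "B=2"})
	block = appendEnvEntry(block, "EXAMPLE_TOKEN", "secret")
	fmt.Println(parseEnvBlock(block))
}
```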
func spawnAgentInSession(sessionID uint32, socketPath, authToken string, jobHandle windows.Handle) (windows.Handle, error) {
token, err := getSystemTokenForSession(sessionID)
if err != nil {
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
}
defer token.Close()
var envBlock uintptr
r, _, e := procCreateEnvironmentBlock.Call(
uintptr(unsafe.Pointer(&envBlock)),
uintptr(token),
0,
)
if r == 0 {
// Without an environment block we cannot inject NB_VNC_AGENT_TOKEN;
// the agent would start unauthenticated. Abort instead of launching.
return 0, fmt.Errorf("CreateEnvironmentBlock: %w", e)
}
defer func() { _, _, _ = procDestroyEnvironmentBlock.Call(envBlock) }()
// Inject the auth token into the environment block so it doesn't appear
// in the process command line (visible via tasklist/wmic). injectedBlock
// must stay alive until CreateProcessAsUser returns.
injectedBlock := injectEnvVar(envBlock, agentTokenEnvVar, authToken)
exePath, err := os.Executable()
if err != nil {
return 0, fmt.Errorf("get executable path: %w", err)
}
cmdLine := fmt.Sprintf(`"%s" %s --socket %q`, exePath, vncAgentSubcommand, socketPath)
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
if err != nil {
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
}
// Create an inheritable pipe for the agent's stderr so we can relog
// its output in the service process.
var sa windows.SecurityAttributes
sa.Length = uint32(unsafe.Sizeof(sa))
sa.InheritHandle = 1
var stderrRead, stderrWrite windows.Handle
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
return 0, fmt.Errorf("create stderr pipe: %w", err)
}
// The read end must NOT be inherited by the child.
_ = windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
si := windows.StartupInfo{
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
Desktop: desktop,
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
ShowWindow: 0,
StdErr: stderrWrite,
StdOutput: stderrWrite,
}
var pi windows.ProcessInformation
var envPtr *uint16
if len(injectedBlock) > 0 {
envPtr = &injectedBlock[0]
} else if envBlock != 0 {
envPtr = (*uint16)(unsafe.Pointer(envBlock))
}
// CREATE_SUSPENDED so we can assign the process to our Job Object
// before it executes. Without this the agent could spawn its own child
// processes and have them inherit the SCM service job (not ours), or
// briefly listen on the agent socket before we tear it down on rollback.
// CREATE_BREAKAWAY_FROM_JOB lets the child leave the SCM-managed
// service job; harmless if that job allows breakaway, and is required
// before AssignProcessToJobObject can succeed in the no-nested-jobs case.
err = windows.CreateProcessAsUser(
token, nil, cmdLineW,
nil, nil, true, // inheritHandles=true for the pipe
createUnicodeEnvironment|createNoWindow|createSuspended|createBreakawayFromJob,
envPtr, nil, &si, &pi,
)
runtime.KeepAlive(injectedBlock)
// Close the write end in the parent so reads will get EOF when the child exits.
_ = windows.CloseHandle(stderrWrite)
if err != nil {
_ = windows.CloseHandle(stderrRead)
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
}
if jobHandle != 0 {
r, _, e := procAssignProcessToJobObject.Call(uintptr(jobHandle), uintptr(pi.Process))
if r == 0 {
log.Warnf("assign agent to job object: %v (orphan possible on service crash)", e)
}
}
if _, err := windows.ResumeThread(pi.Thread); err != nil {
_ = windows.CloseHandle(pi.Thread)
_ = windows.TerminateProcess(pi.Process, 1)
_ = windows.CloseHandle(pi.Process)
_ = windows.CloseHandle(stderrRead)
return 0, fmt.Errorf("ResumeThread: %w", err)
}
_ = windows.CloseHandle(pi.Thread)
// Relog agent output in the service with a [vnc-agent] prefix.
go relogAgentOutput(stderrRead)
log.Infof("spawned agent PID=%d in session %d on %s", pi.ProcessId, sessionID, socketPath)
return pi.Process, nil
}
// sessionManager monitors the active console session and ensures a VNC agent
// process is running in it. When the session changes (e.g., user switch, RDP
// connect/disconnect), it kills the old agent and spawns a new one. Each
// spawn picks a per-session Unix-socket path the agent binds and the
// daemon dials over local IPC.
type sessionManager struct {
mu sync.Mutex
agentProc windows.Handle
everSpawned bool
agentStartedAt time.Time
spawnFailures int
nextSpawnAt time.Time
sessionID uint32
authToken string
socketPath string
done chan struct{}
// jobHandle owns the agent processes via a Windows Job Object with
// JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE. When the service exits or crashes,
// the OS closes the handle and terminates every assigned agent: no
// orphaned agent processes holding a socket across restarts.
jobHandle windows.Handle
}
// agentSocketPathFmt parameterizes the per-session agent socket path by
// the Windows session ID. C:\Windows\Temp is writable by both the daemon
// (SYSTEM) and the spawned agent (a SYSTEM token retargeted to the session).
const agentSocketPathFmt = `C:\Windows\Temp\netbird-vnc-%d.sock`
func newSessionManager() *sessionManager {
m := &sessionManager{sessionID: ^uint32(0), done: make(chan struct{})}
if h, err := createKillOnCloseJob(); err != nil {
log.Warnf("create job object for vnc-agent (orphan agents possible after crash): %v", err)
} else {
m.jobHandle = h
}
return m
}
// createKillOnCloseJob returns a Job Object configured so that closing its
// handle (process exit or explicit Close) terminates every process assigned
// to it. Used to keep orphaned vnc-agent processes from outliving the service.
func createKillOnCloseJob() (windows.Handle, error) {
r, _, e := procCreateJobObjectW.Call(0, 0)
if r == 0 {
return 0, fmt.Errorf("CreateJobObject: %w", e)
}
job := windows.Handle(r)
// JOBOBJECT_EXTENDED_LIMIT_INFORMATION on amd64 = 144 bytes.
//
// JOBOBJECT_BASIC_LIMIT_INFORMATION (64 bytes with alignment padding)
// PerProcessUserTimeLimit LARGE_INTEGER off 0
// PerJobUserTimeLimit LARGE_INTEGER off 8
// LimitFlags DWORD off 16
// [4 byte pad to align SIZE_T]
// MinimumWorkingSetSize SIZE_T off 24
// MaximumWorkingSetSize SIZE_T off 32
// ActiveProcessLimit DWORD off 40
// [4 byte pad to align ULONG_PTR]
// Affinity ULONG_PTR off 48
// PriorityClass DWORD off 56
// SchedulingClass DWORD off 60
// Followed by IO_COUNTERS (48) and 4 * SIZE_T (32): 64 + 48 + 32 = 144 total.
//
// We only set LimitFlags; the rest stays zero.
const sizeofExtended = 144
const offsetLimitFlags = 16
const jobObjectExtendedLimitInformation = 9
const jobObjectLimitKillOnJobClose = 0x00002000
var info [sizeofExtended]byte
binary.LittleEndian.PutUint32(info[offsetLimitFlags:offsetLimitFlags+4], jobObjectLimitKillOnJobClose)
r, _, e = procSetInformationJobObject.Call(
uintptr(job),
uintptr(jobObjectExtendedLimitInformation),
uintptr(unsafe.Pointer(&info[0])),
uintptr(sizeofExtended),
)
if r == 0 {
_ = windows.CloseHandle(job)
return 0, fmt.Errorf("SetInformationJobObject(KILL_ON_JOB_CLOSE): %w", e)
}
return job, nil
}
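// The hard-coded size and offset above can be cross-checked by mirroring
// the struct in Go. This is a sketch only: the types below are hypothetical
// mirrors written for this check, uint64 stands in for SIZE_T/ULONG_PTR to
// pin the 64-bit layout, and the numbers hold only on a 64-bit Go target.

```go
package main

import (
	"fmt"
	"unsafe"
)

// basicLimitInfo mirrors JOBOBJECT_BASIC_LIMIT_INFORMATION on amd64.
type basicLimitInfo struct {
	PerProcessUserTimeLimit int64
	PerJobUserTimeLimit     int64
	LimitFlags              uint32
	MinimumWorkingSetSize   uint64
	MaximumWorkingSetSize   uint64
	ActiveProcessLimit      uint32
	Affinity                uint64
	PriorityClass           uint32
	SchedulingClass         uint32
}

// ioCounters mirrors IO_COUNTERS: six 64-bit counters.
type ioCounters struct {
	ReadOps, WriteOps, OtherOps       uint64
	ReadBytes, WriteBytes, OtherBytes uint64
}

// extendedLimitInfo mirrors JOBOBJECT_EXTENDED_LIMIT_INFORMATION.
type extendedLimitInfo struct {
	Basic                 basicLimitInfo
	IO                    ioCounters
	ProcessMemoryLimit    uint64
	JobMemoryLimit        uint64
	PeakProcessMemoryUsed uint64
	PeakJobMemoryUsed     uint64
}

// extendedSize returns the size of the mirrored struct in bytes.
func extendedSize() int { return int(unsafe.Sizeof(extendedLimitInfo{})) }

// limitFlagsOffset returns the byte offset of LimitFlags from the start
// of the extended struct.
func limitFlagsOffset() int {
	var e extendedLimitInfo
	return int(unsafe.Offsetof(e.Basic) + unsafe.Offsetof(e.Basic.LimitFlags))
}

func main() {
	fmt.Println("size:", extendedSize(), "LimitFlags offset:", limitFlagsOffset())
}
```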
// Resolve returns the current agent socket path, shared token, and the
// uid the agent runs under (0 on Windows since the agent runs as
// SYSTEM in the interactive session; validateAgentPeer is a no-op
// there). When no agent is spawned yet (initial boot, between session
// switches, or permanently disabled when SE_TCB_NAME is missing) it
// surfaces a distinct error so the daemon can reject the connection
// with a meaningful message instead of timing out the proxy dial.
func (m *sessionManager) Resolve(_ context.Context) (string, string, uint32, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.socketPath == "" {
return "", "", 0, errAgentNotReady
}
return m.socketPath, m.authToken, 0, nil
}
var errAgentNotReady = errors.New("VNC agent not running yet")
// Stop signals the session manager to exit its polling loop and closes the
// Job Object handle, which Windows uses as the trigger to terminate every
// agent process this manager spawned.
func (m *sessionManager) Stop() {
select {
case <-m.done:
default:
close(m.done)
}
m.mu.Lock()
if m.jobHandle != 0 {
_ = windows.CloseHandle(m.jobHandle)
m.jobHandle = 0
}
m.mu.Unlock()
}
func (m *sessionManager) run() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
if !m.tick() {
return
}
select {
case <-m.done:
m.mu.Lock()
m.killAgent()
m.mu.Unlock()
return
case <-ticker.C:
}
}
}
// tick performs one session/agent-state update. Returns false if the manager
// should permanently stop (e.g. missing SYSTEM privileges).
func (m *sessionManager) tick() bool {
sid := getActiveSessionID()
m.mu.Lock()
defer m.mu.Unlock()
m.handleSessionChange(sid)
m.reapExitedAgent()
return m.maybeSpawnAgent(sid)
}
func (m *sessionManager) handleSessionChange(sid uint32) {
if sid == m.sessionID {
return
}
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
m.killAgent()
m.sessionID = sid
}
func (m *sessionManager) reapExitedAgent() {
if m.agentProc == 0 {
return
}
var code uint32
if err := windows.GetExitCodeProcess(m.agentProc, &code); err != nil {
log.Debugf("GetExitCodeProcess: %v", err)
return
}
if code == stillActive {
return
}
m.scheduleNextSpawn(code, time.Since(m.agentStartedAt))
if err := windows.CloseHandle(m.agentProc); err != nil {
log.Debugf("close agent handle: %v", err)
}
m.agentProc = 0
m.authToken = ""
m.socketPath = ""
}
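// scheduleNextSpawn applies exponential backoff (2s, 4s, 8s, 16s, then
// capped at 30s) when the agent exits within 5s of starting. A longer-lived
// agent resets the failure count and is respawned without delay.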
func (m *sessionManager) scheduleNextSpawn(exitCode uint32, lifetime time.Duration) {
if lifetime < 5*time.Second {
m.spawnFailures++
backoff := time.Duration(1<<min(m.spawnFailures, 5)) * time.Second
if backoff > 30*time.Second {
backoff = 30 * time.Second
}
m.nextSpawnAt = time.Now().Add(backoff)
log.Warnf("agent exited (code=%d) after %v, retrying in %v (failures=%d)", exitCode, lifetime.Round(time.Millisecond), backoff, m.spawnFailures)
return
}
m.spawnFailures = 0
m.nextSpawnAt = time.Time{}
log.Infof("agent exited (code=%d) after %v, respawning", exitCode, lifetime.Round(time.Second))
}
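// The backoff schedule above, isolated as a pure function. A minimal sketch:
// backoffSeconds is a hypothetical helper that takes the failure count after
// it has been incremented, as scheduleNextSpawn does.

```go
package main

import "fmt"

// backoffSeconds returns 2^min(failures, 5) seconds, capped at 30.
func backoffSeconds(failures int) int {
	shift := failures
	if shift > 5 {
		shift = 5
	}
	secs := 1 << shift
	if secs > 30 {
		secs = 30
	}
	return secs
}

func main() {
	for f := 1; f <= 6; f++ {
		fmt.Printf("failures=%d backoff=%ds\n", f, backoffSeconds(f))
	}
}
```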
// maybeSpawnAgent spawns a new agent if there's no current one and the backoff
// window has elapsed. Returns false to permanently stop the manager when the
// service lacks the privileges needed to spawn cross-session.
func (m *sessionManager) maybeSpawnAgent(sid uint32) bool {
if m.agentProc != 0 || sid == 0xFFFFFFFF || !time.Now().After(m.nextSpawnAt) {
return true
}
// Each session gets its own socket path. Clear any stale socket file
// before spawning so the new agent can bind it; this covers a
// previous-run crash that escaped Job Object kill-on-close. If a
// kill+respawn races on socket release, the spawn-failure backoff
// handles it without forcing a synchronous wait.
socketPath := fmt.Sprintf(agentSocketPathFmt, sid)
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
log.Debugf("clear stale agent socket %s: %v", socketPath, err)
}
token, err := generateAuthToken()
if err != nil {
log.Warnf("generate agent auth token: %v", err)
return true
}
m.authToken = token
m.socketPath = socketPath
h, err := spawnAgentInSession(sid, socketPath, m.authToken, m.jobHandle)
if err != nil {
m.authToken = ""
m.socketPath = ""
if errors.Is(err, windows.ERROR_PRIVILEGE_NOT_HELD) {
// SE_TCB_NAME (token-impersonation across sessions) is only
// granted to SYSTEM. Without it spawnAgentInSession will fail
// every 2 seconds forever: log once and give up.
log.Warnf("VNC service mode disabled: agent spawn requires SYSTEM privileges (got: %v)", err)
return false
}
log.Warnf("spawn agent in session %d: %v", sid, err)
return true
}
m.agentProc = h
m.agentStartedAt = time.Now()
m.everSpawned = true
return true
}
func (m *sessionManager) killAgent() {
if m.agentProc == 0 {
return
}
_ = windows.TerminateProcess(m.agentProc, 0)
_ = windows.CloseHandle(m.agentProc)
m.agentProc = 0
m.authToken = ""
m.socketPath = ""
log.Info("killed old agent")
}
// relogAgentOutput reads log lines from the agent's stderr pipe and
// relogs them with the service's formatter. The *os.File owns the
// underlying handle, so closing it suffices.
func relogAgentOutput(pipe windows.Handle) {
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
defer func() { _ = f.Close() }()
relogAgentStream(f)
}
// logCleanupCall invokes a Windows syscall used solely as a cleanup primitive
// (CloseClipboard, ReleaseDC, etc.) and logs failures at trace level. The
// indirection lets us satisfy errcheck without scattering ignored returns at
// each call site, while still capturing diagnostic info when the OS reports
// a failure.
func logCleanupCall(name string, proc *windows.LazyProc) {
r, _, err := proc.Call()
if r == 0 && err != nil && err != windows.ERROR_SUCCESS {
log.Tracef("%s: %v", name, err)
}
}
// logCleanupCallArgs is logCleanupCall for syscalls that take arguments;
// the common case is a release-by-handle call with a single handle.
func logCleanupCallArgs(name string, proc *windows.LazyProc, args ...uintptr) {
r, _, err := proc.Call(args...)
if r == 0 && err != nil && err != windows.ERROR_SUCCESS {
log.Tracef("%s: %v", name, err)
}
}

@@ -0,0 +1,643 @@
//go:build darwin && !ios
package server
import (
"errors"
"fmt"
"hash/maphash"
"image"
"os"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus"
)
var darwinCaptureOnce sync.Once
var (
cgMainDisplayID func() uint32
cgDisplayPixelsWide func(uint32) uintptr
cgDisplayPixelsHigh func(uint32) uintptr
cgDisplayCreateImage func(uint32) uintptr
cgImageGetWidth func(uintptr) uintptr
cgImageGetHeight func(uintptr) uintptr
cgImageGetBytesPerRow func(uintptr) uintptr
cgImageGetBitsPerPixel func(uintptr) uintptr
cgImageGetDataProvider func(uintptr) uintptr
cgDataProviderCopyData func(uintptr) uintptr
cgImageRelease func(uintptr)
cfDataGetLength func(uintptr) int64
cfDataGetBytePtr func(uintptr) uintptr
cfRelease func(uintptr)
cgRequestScreenCaptureAccess func() bool
cgEventCreate func(uintptr) uintptr
cgEventGetLocation func(uintptr) cgPoint
darwinCaptureReady bool
)
// cgPoint mirrors CoreGraphics CGPoint: two doubles, 16 bytes, returned
// in registers on Darwin amd64/arm64. Used to receive cursor coordinates
// from CGEventGetLocation via purego.
type cgPoint struct {
X, Y float64
}
func initDarwinCapture() {
darwinCaptureOnce.Do(func() {
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreGraphics: %v", err)
return
}
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreFoundation: %v", err)
return
}
purego.RegisterLibFunc(&cgMainDisplayID, cg, "CGMainDisplayID")
purego.RegisterLibFunc(&cgDisplayPixelsWide, cg, "CGDisplayPixelsWide")
purego.RegisterLibFunc(&cgDisplayPixelsHigh, cg, "CGDisplayPixelsHigh")
purego.RegisterLibFunc(&cgDisplayCreateImage, cg, "CGDisplayCreateImage")
purego.RegisterLibFunc(&cgImageGetWidth, cg, "CGImageGetWidth")
purego.RegisterLibFunc(&cgImageGetHeight, cg, "CGImageGetHeight")
purego.RegisterLibFunc(&cgImageGetBytesPerRow, cg, "CGImageGetBytesPerRow")
purego.RegisterLibFunc(&cgImageGetBitsPerPixel, cg, "CGImageGetBitsPerPixel")
purego.RegisterLibFunc(&cgImageGetDataProvider, cg, "CGImageGetDataProvider")
purego.RegisterLibFunc(&cgDataProviderCopyData, cg, "CGDataProviderCopyData")
purego.RegisterLibFunc(&cgImageRelease, cg, "CGImageRelease")
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
// CGRequestScreenCaptureAccess (macOS 11+) prompts on first call and
// is a cheap no-op once granted. The Preflight companion is unreliable
// on Sequoia (returns false even when access is granted), so we drive
// the permission flow from actual capture failures instead.
if sym, err := purego.Dlsym(cg, "CGRequestScreenCaptureAccess"); err == nil {
purego.RegisterFunc(&cgRequestScreenCaptureAccess, sym)
}
// CGEventCreate / CGEventGetLocation feed the cursor position used
// by remote-cursor compositing. Optional; absence reports as a
// position-source error and disables that feature on this host.
if sym, err := purego.Dlsym(cg, "CGEventCreate"); err == nil {
purego.RegisterFunc(&cgEventCreate, sym)
}
if sym, err := purego.Dlsym(cg, "CGEventGetLocation"); err == nil {
purego.RegisterFunc(&cgEventGetLocation, sym)
}
darwinCaptureReady = true
})
}
// CGCapturer captures the macOS main display using Core Graphics.
type CGCapturer struct {
displayID uint32
w, h int
// downscale is 1 for pixel-perfect, 2 for Retina 2:1 box-filter downscale.
downscale int
hashSeed maphash.Seed
lastHash uint64
hasHash bool
// cursor lazily binds the private CGSCreateCurrentCursorImage symbol
// so we can emit the Cursor pseudo-encoding without a per-frame cost
// on builds that never query it.
cursorOnce sync.Once
cursor *cgCursor
}
// PrimeScreenCapturePermission triggers the macOS Screen Recording
// permission prompt without creating a full capturer. The platform wiring
// calls this at VNC-server enable time so the user sees the prompt the
// moment they turn the feature on. CGRequestScreenCaptureAccess is a
// no-op when the grant already exists, so calling it on every enable is
// cheap and safe.
func PrimeScreenCapturePermission() {
initDarwinCapture()
if !darwinCaptureReady {
return
}
if cgRequestScreenCaptureAccess != nil {
cgRequestScreenCaptureAccess()
}
}
// notifyScreenRecordingMissing nudges the user once per agent process to
// approve Screen Recording. The capturer init retries on backoff when the
// grant is missing; without the sync.Once we would reopen System Settings
// every tick and flood the daemon log with the same warning.
var screenRecordingNotifyOnce sync.Once
func notifyScreenRecordingMissing() {
screenRecordingNotifyOnce.Do(func() {
if cgRequestScreenCaptureAccess != nil {
cgRequestScreenCaptureAccess()
}
openPrivacyPane("Privacy_ScreenCapture")
log.Warn("Screen Recording permission not granted. " +
"Opened System Settings > Privacy & Security > Screen Recording; enable netbird and restart.")
})
}
// NewCGCapturer creates a screen capturer for the main display.
func NewCGCapturer() (*CGCapturer, error) {
initDarwinCapture()
if !darwinCaptureReady {
return nil, fmt.Errorf("CoreGraphics not available")
}
displayID := cgMainDisplayID()
c := &CGCapturer{displayID: displayID, downscale: 1, hashSeed: maphash.MakeSeed()}
img, err := c.Capture()
if err != nil {
notifyScreenRecordingMissing()
return nil, fmt.Errorf("probe capture: %w", err)
}
nativeW := img.Rect.Dx()
nativeH := img.Rect.Dy()
c.hasHash = false
if nativeW == 0 || nativeH == 0 {
return nil, errors.New("display dimensions are zero")
}
logicalW := int(cgDisplayPixelsWide(displayID))
logicalH := int(cgDisplayPixelsHigh(displayID))
// Enable 2:1 downscale on Retina unless explicitly disabled. Cuts pixel
// count 4x, shrinking convert, diff, and wire data proportionally.
if !retinaDownscaleDisabled() && nativeW >= 2*logicalW && nativeH >= 2*logicalH && nativeW%2 == 0 && nativeH%2 == 0 {
c.downscale = 2
}
c.w = nativeW / c.downscale
c.h = nativeH / c.downscale
log.Infof("macOS capturer ready: %dx%d (native %dx%d, logical %dx%d, downscale=%d, display=%d)",
c.w, c.h, nativeW, nativeH, logicalW, logicalH, c.downscale, displayID)
return c, nil
}
func retinaDownscaleDisabled() bool {
v := os.Getenv(EnvVNCDisableDownscale)
if v == "" {
return false
}
disabled, err := strconv.ParseBool(v)
if err != nil {
log.Warnf("parse %s: %v", EnvVNCDisableDownscale, err)
return false
}
return disabled
}
// Width returns the screen width.
func (c *CGCapturer) Width() int { return c.w }
// Height returns the screen height.
func (c *CGCapturer) Height() int { return c.h }
// CaptureInto writes a fresh frame directly into dst, skipping the
// per-frame image.RGBA allocation that Capture() does. Returns
// errFrameUnchanged when the screen hash matches the prior call.
func (c *CGCapturer) CaptureInto(dst *image.RGBA) error {
cgImage := cgDisplayCreateImage(c.displayID)
if cgImage == 0 {
return fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
}
defer cgImageRelease(cgImage)
w := int(cgImageGetWidth(cgImage))
h := int(cgImageGetHeight(cgImage))
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
bpp := int(cgImageGetBitsPerPixel(cgImage))
provider := cgImageGetDataProvider(cgImage)
if provider == 0 {
return fmt.Errorf("CGImageGetDataProvider returned nil")
}
cfData := cgDataProviderCopyData(provider)
if cfData == 0 {
return fmt.Errorf("CGDataProviderCopyData returned nil")
}
defer cfRelease(cfData)
dataLen := int(cfDataGetLength(cfData))
dataPtr := cfDataGetBytePtr(cfData)
if dataPtr == 0 || dataLen == 0 {
return fmt.Errorf("empty image data")
}
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
hash := maphash.Bytes(c.hashSeed, src)
if c.hasHash && hash == c.lastHash {
return errFrameUnchanged
}
c.lastHash = hash
c.hasHash = true
ds := c.downscale
if ds < 1 {
ds = 1
}
outW := w / ds
outH := h / ds
if dst.Rect.Dx() != outW || dst.Rect.Dy() != outH {
return fmt.Errorf("dst size mismatch: dst=%dx%d capturer=%dx%d",
dst.Rect.Dx(), dst.Rect.Dy(), outW, outH)
}
bytesPerPixel := bpp / 8
if bytesPerPixel == 4 && ds == 1 {
convertBGRAToRGBA(dst.Pix, dst.Stride, src, bytesPerRow, w, h)
return nil
}
if bytesPerPixel == 4 && ds == 2 {
convertBGRAToRGBADownscale2(dst.Pix, dst.Stride, src, bytesPerRow, outW, outH)
return nil
}
convertBGRAToRGBAGeneric(dst.Pix, dst.Stride, src, bytesPerRow, bgraDownscaleParams{outW: outW, outH: outH, bytesPerPixel: bytesPerPixel, ds: ds})
return nil
}
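// Capture returns the current screen as a newly allocated RGBA image.
// Returns errFrameUnchanged when the screen hash matches the prior call.
func (c *CGCapturer) Capture() (*image.RGBA, error) {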
cgImage := cgDisplayCreateImage(c.displayID)
if cgImage == 0 {
return nil, fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
}
defer cgImageRelease(cgImage)
w := int(cgImageGetWidth(cgImage))
h := int(cgImageGetHeight(cgImage))
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
bpp := int(cgImageGetBitsPerPixel(cgImage))
provider := cgImageGetDataProvider(cgImage)
if provider == 0 {
return nil, fmt.Errorf("CGImageGetDataProvider returned nil")
}
cfData := cgDataProviderCopyData(provider)
if cfData == 0 {
return nil, fmt.Errorf("CGDataProviderCopyData returned nil")
}
defer cfRelease(cfData)
dataLen := int(cfDataGetLength(cfData))
dataPtr := cfDataGetBytePtr(cfData)
if dataPtr == 0 || dataLen == 0 {
return nil, fmt.Errorf("empty image data")
}
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
hash := maphash.Bytes(c.hashSeed, src)
if c.hasHash && hash == c.lastHash {
return nil, errFrameUnchanged
}
c.lastHash = hash
c.hasHash = true
ds := c.downscale
if ds < 1 {
ds = 1
}
outW := w / ds
outH := h / ds
img := image.NewRGBA(image.Rect(0, 0, outW, outH))
bytesPerPixel := bpp / 8
switch {
case bytesPerPixel == 4 && ds == 1:
convertBGRAToRGBA(img.Pix, img.Stride, src, bytesPerRow, w, h)
case bytesPerPixel == 4 && ds == 2:
convertBGRAToRGBADownscale2(img.Pix, img.Stride, src, bytesPerRow, outW, outH)
default:
convertBGRAToRGBAGeneric(img.Pix, img.Stride, src, bytesPerRow, bgraDownscaleParams{outW: outW, outH: outH, bytesPerPixel: bytesPerPixel, ds: ds})
}
return img, nil
}
type bgraDownscaleParams struct {
outW, outH, bytesPerPixel, ds int
}
// convertBGRAToRGBAGeneric is the slow per-pixel fallback for sources that
// are not 4 bytes per pixel or whose downscale factor is neither 1 nor 2.
// It samples the top-left pixel of each ds x ds block (no averaging) and
// assumes at least 3 bytes per pixel in B, G, R order.
func convertBGRAToRGBAGeneric(dst []byte, dstStride int, src []byte, srcStride int, p bgraDownscaleParams) {
for row := 0; row < p.outH; row++ {
srcOff := row * p.ds * srcStride
dstOff := row * dstStride
for col := 0; col < p.outW; col++ {
si := srcOff + col*p.ds*p.bytesPerPixel
di := dstOff + col*4
dst[di+0] = src[si+2]
dst[di+1] = src[si+1]
dst[di+2] = src[si+0]
dst[di+3] = 0xff
}
}
}
// convertBGRAToRGBADownscale2 averages every 2x2 BGRA block into one RGBA
// output pixel, parallelised across GOMAXPROCS cores. outW and outH are the
// destination dimensions (source is 2*outW by 2*outH).
func convertBGRAToRGBADownscale2(dst []byte, dstStride int, src []byte, srcStride, outW, outH int) {
workers := runtime.GOMAXPROCS(0)
if workers > outH {
workers = outH
}
if workers < 1 || outH < 32 {
workers = 1
}
convertRows := func(y0, y1 int) {
for row := y0; row < y1; row++ {
srcRow0 := 2 * row * srcStride
srcRow1 := srcRow0 + srcStride
dstOff := row * dstStride
for col := 0; col < outW; col++ {
s0 := srcRow0 + col*8
s1 := srcRow1 + col*8
b := (uint32(src[s0]) + uint32(src[s0+4]) + uint32(src[s1]) + uint32(src[s1+4])) >> 2
g := (uint32(src[s0+1]) + uint32(src[s0+5]) + uint32(src[s1+1]) + uint32(src[s1+5])) >> 2
r := (uint32(src[s0+2]) + uint32(src[s0+6]) + uint32(src[s1+2]) + uint32(src[s1+6])) >> 2
di := dstOff + col*4
dst[di+0] = byte(r)
dst[di+1] = byte(g)
dst[di+2] = byte(b)
dst[di+3] = 0xff
}
}
}
if workers == 1 {
convertRows(0, outH)
return
}
var wg sync.WaitGroup
chunk := (outH + workers - 1) / workers
for i := 0; i < workers; i++ {
y0 := i * chunk
y1 := y0 + chunk
if y1 > outH {
y1 = outH
}
if y0 >= y1 {
break
}
wg.Add(1)
go func(y0, y1 int) {
defer wg.Done()
convertRows(y0, y1)
}(y0, y1)
}
wg.Wait()
}
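// The 2x2 box filter above, reduced to a single-threaded sketch on a tiny
// buffer. downscale2 is a hypothetical helper written for illustration; it
// allocates a tightly packed destination instead of honoring a caller-supplied
// stride.

```go
package main

import "fmt"

// downscale2 averages each 2x2 BGRA block of src into one RGBA pixel.
// src must hold 2*outH rows of srcStride bytes, each with 2*outW pixels.
func downscale2(src []byte, srcStride, outW, outH int) []byte {
	dst := make([]byte, outW*outH*4)
	for row := 0; row < outH; row++ {
		r0 := 2 * row * srcStride
		r1 := r0 + srcStride
		for col := 0; col < outW; col++ {
			s0, s1 := r0+col*8, r1+col*8
			// avg averages one channel over the four source pixels.
			avg := func(ch int) byte {
				sum := uint32(src[s0+ch]) + uint32(src[s0+4+ch]) +
					uint32(src[s1+ch]) + uint32(src[s1+4+ch])
				return byte(sum >> 2)
			}
			di := (row*outW + col) * 4
			dst[di+0] = avg(2) // R
			dst[di+1] = avg(1) // G
			dst[di+2] = avg(0) // B
			dst[di+3] = 0xff
		}
	}
	return dst
}

func main() {
	// One 2x2 BGRA block: B = 10, 20, 30, 40; G = 0, 0, 0, 4; R = 255.
	src := []byte{
		10, 0, 255, 0, 20, 0, 255, 0,
		30, 0, 255, 0, 40, 4, 255, 0,
	}
	fmt.Println(downscale2(src, 8, 1, 1))
}
```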
// convertBGRAToRGBA swaps R/B channels using uint32 word operations, and
// parallelises across GOMAXPROCS cores for large images.
func convertBGRAToRGBA(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
workers := runtime.GOMAXPROCS(0)
if workers > h {
workers = h
}
if workers < 1 || h < 64 {
workers = 1
}
convertRows := func(y0, y1 int) {
rowBytes := w * 4
for row := y0; row < y1; row++ {
dstRow := dst[row*dstStride : row*dstStride+rowBytes]
srcRow := src[row*srcStride : row*srcStride+rowBytes]
dstU := unsafe.Slice((*uint32)(unsafe.Pointer(&dstRow[0])), w)
srcU := unsafe.Slice((*uint32)(unsafe.Pointer(&srcRow[0])), w)
for i, p := range srcU {
dstU[i] = (p & 0xff00ff00) | ((p & 0x000000ff) << 16) | ((p & 0x00ff0000) >> 16) | 0xff000000
}
}
}
if workers == 1 {
convertRows(0, h)
return
}
var wg sync.WaitGroup
chunk := (h + workers - 1) / workers
for i := 0; i < workers; i++ {
y0 := i * chunk
y1 := y0 + chunk
if y1 > h {
y1 = h
}
if y0 >= y1 {
break
}
wg.Add(1)
go func(y0, y1 int) {
defer wg.Done()
convertRows(y0, y1)
}(y0, y1)
}
wg.Wait()
}
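// The word-level channel swap above, shown on a single pixel. A minimal
// sketch: swizzleBGRA and convertRow are hypothetical helpers, and
// encoding/binary replaces the unsafe.Slice cast so the little-endian
// assumption behind the masks is explicit.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// swizzleBGRA swaps the B and R bytes of a little-endian BGRA word and
// forces alpha to 0xff, yielding a little-endian RGBA word.
func swizzleBGRA(p uint32) uint32 {
	return (p & 0xff00ff00) | ((p & 0x000000ff) << 16) | ((p & 0x00ff0000) >> 16) | 0xff000000
}

// convertRow converts whole 4-byte pixels from src into dst.
func convertRow(dst, src []byte) {
	for i := 0; i+4 <= len(src) && i+4 <= len(dst); i += 4 {
		p := binary.LittleEndian.Uint32(src[i:])
		binary.LittleEndian.PutUint32(dst[i:], swizzleBGRA(p))
	}
}

func main() {
	src := []byte{0x11, 0x22, 0x33, 0x44} // B, G, R, A
	dst := make([]byte, 4)
	convertRow(dst, src)
	fmt.Printf("% x\n", dst) // R, G, B, A with alpha forced opaque
}
```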
// MacPoller wraps CGCapturer with a staleness-cached on-demand Capture:
// sessions drive captures themselves from their encoder goroutine, so we
// don't need a background ticker. The last result is cached for a short
// window so concurrent sessions coalesce into one capture.
//
// The capturer is allocated lazily on first use and released when all
// clients disconnect. Init is retried with backoff because the user may
// grant Screen Recording permission while the server is already running.
type MacPoller struct {
mu sync.Mutex
capturer *CGCapturer
w, h int
lastFrame *image.RGBA
lastAt time.Time
clients atomic.Int32
initFails int
initBackoffUntil time.Time
closed bool
}
// macInitRetryBackoffFor returns the delay we wait between init attempts
// after consecutive failures. Screen Recording permission is a one-shot
// user grant, so after several failures we back off aggressively.
func macInitRetryBackoffFor(fails int) time.Duration {
switch {
case fails > 15:
return 30 * time.Second
case fails > 5:
return 10 * time.Second
default:
return 2 * time.Second
}
}
// NewMacPoller creates a lazy on-demand capturer for the macOS display.
func NewMacPoller() *MacPoller {
return &MacPoller{}
}
// Wake is a no-op retained for API compatibility. With on-demand capture
// there is no background retry loop to kick: init happens on the next
// Capture/ClientConnect call.
func (p *MacPoller) Wake() {
// intentional no-op
}
// ClientConnect increments the active client count and eagerly initialises
// the capturer so the first FBUpdateRequest doesn't pay the init cost.
func (p *MacPoller) ClientConnect() {
if p.clients.Add(1) == 1 {
p.mu.Lock()
_ = p.ensureCapturerLocked()
p.mu.Unlock()
}
}
// ClientDisconnect decrements the active client count. On the last
// disconnect the capturer is released.
func (p *MacPoller) ClientDisconnect() {
if p.clients.Add(-1) == 0 {
p.mu.Lock()
p.capturer = nil
p.lastFrame = nil
p.mu.Unlock()
}
}
// Close releases all resources.
func (p *MacPoller) Close() {
p.mu.Lock()
p.closed = true
p.capturer = nil
p.lastFrame = nil
p.mu.Unlock()
}
// Width returns the screen width. Triggers lazy init if needed.
func (p *MacPoller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
_ = p.ensureCapturerLocked()
return p.w
}
// Height returns the screen height. Triggers lazy init if needed.
func (p *MacPoller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
_ = p.ensureCapturerLocked()
return p.h
}
// CaptureInto fills dst directly via the underlying capturer, bypassing
// the freshness cache.
func (p *MacPoller) CaptureInto(dst *image.RGBA) error {
p.mu.Lock()
defer p.mu.Unlock()
if err := p.ensureCapturerLocked(); err != nil {
return err
}
err := p.capturer.CaptureInto(dst)
if errors.Is(err, errFrameUnchanged) {
// Caller (session) treats this as "no change"; the dst buffer
// keeps its prior contents from the previous capture cycle so
// the diff stays meaningful.
return err
}
if err != nil {
p.capturer = nil
return fmt.Errorf("macos capture: %w", err)
}
return nil
}
// Capture returns a fresh frame, serving from the short-lived cache if a
// previous caller captured within freshWindow. Handles the
// errFrameUnchanged return from CGCapturer by reusing the cached frame.
func (p *MacPoller) Capture() (*image.RGBA, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.lastFrame != nil && time.Since(p.lastAt) < freshWindow {
return p.lastFrame, nil
}
if err := p.ensureCapturerLocked(); err != nil {
return nil, err
}
img, err := p.capturer.Capture()
if errors.Is(err, errFrameUnchanged) {
if p.lastFrame != nil {
p.lastAt = time.Now()
return p.lastFrame, nil
}
return nil, err
}
if err != nil {
// Drop the capturer so the next call retries init; the display stream
// can die if the session changes or permissions are revoked.
p.capturer = nil
return nil, fmt.Errorf("macos capture: %w", err)
}
p.lastFrame = img
p.lastAt = time.Now()
return img, nil
}
// ensureCapturerLocked initialises the underlying CGCapturer if needed.
// Caller must hold p.mu.
func (p *MacPoller) ensureCapturerLocked() error {
if p.closed {
return fmt.Errorf("poller closed")
}
if p.capturer != nil {
return nil
}
if time.Now().Before(p.initBackoffUntil) {
return fmt.Errorf("macOS capturer unavailable (retry scheduled)")
}
c, err := NewCGCapturer()
if err != nil {
p.initFails++
p.initBackoffUntil = time.Now().Add(macInitRetryBackoffFor(p.initFails))
if p.initFails == 1 || p.initFails%10 == 0 {
log.Warnf("macOS capturer: %v (attempt %d)", err, p.initFails)
} else {
log.Debugf("macOS capturer: %v (attempt %d)", err, p.initFails)
}
return err
}
p.initFails = 0
p.capturer = c
p.w, p.h = c.Width(), c.Height()
return nil
}
var _ ScreenCapturer = (*MacPoller)(nil)

@@ -0,0 +1,99 @@
//go:build windows
package server
import (
"errors"
"fmt"
"image"
"github.com/kirides/go-d3d/d3d11"
"github.com/kirides/go-d3d/outputduplication"
)
// dxgiCapturer captures the desktop using DXGI Desktop Duplication.
// Provides GPU-accelerated capture with native dirty rect tracking.
// Only works from the interactive user session, not Session 0.
//
// Uses a double-buffer: DXGI writes into img, then we copy to the current
// output buffer and hand it out. Alternating between two output buffers
// avoids allocating a new image.RGBA per frame (~8MB at 1080p, 30fps).
type dxgiCapturer struct {
dup *outputduplication.OutputDuplicator
device *d3d11.ID3D11Device
ctx *d3d11.ID3D11DeviceContext
img *image.RGBA
out [2]*image.RGBA
outIdx int
width int
height int
}
func newDXGICapturer() (*dxgiCapturer, error) {
device, deviceCtx, err := d3d11.NewD3D11Device()
if err != nil {
return nil, fmt.Errorf("create D3D11 device: %w", err)
}
dup, err := outputduplication.NewIDXGIOutputDuplication(device, deviceCtx, 0)
if err != nil {
device.Release()
deviceCtx.Release()
return nil, fmt.Errorf("create output duplication: %w", err)
}
w, h := screenSize()
if w == 0 || h == 0 {
dup.Release()
device.Release()
deviceCtx.Release()
return nil, fmt.Errorf("screen dimensions are zero")
}
rect := image.Rect(0, 0, w, h)
c := &dxgiCapturer{
dup: dup,
device: device,
ctx: deviceCtx,
img: image.NewRGBA(rect),
out: [2]*image.RGBA{image.NewRGBA(rect), image.NewRGBA(rect)},
width: w,
height: h,
}
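// Try to grab an initial frame with a longer timeout so the first frame
// handed out is likely to contain valid pixels. Best effort: a failure
// here is ignored and the next capture() call tries again.
_ = dup.GetImage(c.img, 2000)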
return c, nil
}
func (c *dxgiCapturer) capture() (*image.RGBA, error) {
err := c.dup.GetImage(c.img, 100)
if err != nil && !errors.Is(err, outputduplication.ErrNoImageYet) {
return nil, err
}
// Copy into the next output buffer. The DesktopCapturer hands out the
// returned pointer to VNC sessions that read pixels concurrently, so we
// alternate between two pre-allocated buffers instead of allocating per frame.
out := c.out[c.outIdx]
c.outIdx ^= 1
copy(out.Pix, c.img.Pix)
return out, nil
}
func (c *dxgiCapturer) close() {
if c.dup != nil {
c.dup.Release()
c.dup = nil
}
if c.ctx != nil {
c.ctx.Release()
c.ctx = nil
}
if c.device != nil {
c.device.Release()
c.device = nil
}
}
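
Review note: the buffer alternation in dxgiCapturer.capture() can be sketched without any Windows dependency. The pingPong type below is a hypothetical stand-in for the img/out/outIdx fields, not the type from the diff.

```go
package main

import (
	"fmt"
	"image"
)

// pingPong is a hypothetical reduction of dxgiCapturer: one scratch image
// the capture API writes into, plus two output buffers handed out in turn.
type pingPong struct {
	scratch *image.RGBA
	out     [2]*image.RGBA
	idx     int
}

func newPingPong(w, h int) *pingPong {
	r := image.Rect(0, 0, w, h)
	return &pingPong{
		scratch: image.NewRGBA(r),
		out:     [2]*image.RGBA{image.NewRGBA(r), image.NewRGBA(r)},
	}
}

// publish copies the scratch frame into the next output buffer and flips
// the index, so no image is allocated per frame.
func (p *pingPong) publish() *image.RGBA {
	out := p.out[p.idx]
	p.idx ^= 1
	copy(out.Pix, p.scratch.Pix)
	return out
}

func main() {
	p := newPingPong(2, 1)

	p.scratch.Pix[0] = 10
	a := p.publish()

	p.scratch.Pix[0] = 20
	b := p.publish()

	// The second publish used the other buffer, so a still holds frame 1.
	fmt.Println(a != b, a.Pix[0], b.Pix[0])

	// The third publish wraps around and overwrites the first buffer.
	c := p.publish()
	fmt.Println(c == a, a.Pix[0])
}
```

With two buffers, a returned frame stays intact for exactly one further capture. A reader that holds a frame across two captures sees it overwritten.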

@@ -0,0 +1,148 @@
//go:build freebsd
package server
import (
"fmt"
"image"
"sync"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// FreeBSD vt(4) framebuffer ioctl numbers from sys/fbio.h.
//
// #define FBIOGTYPE _IOR('F', 0, struct fbtype)
//
// _IOR(g, n, t) on FreeBSD: IOC_OUT (0x40000000) | (sizeof(t) & 0x1fff)<<16
// | (g<<8) | n. sizeof(struct fbtype)=24 → 0x40184600.
const fbioGType = 0x40184600
func defaultFBPath() string { return "/dev/ttyv0" }
// fbType mirrors FreeBSD's struct fbtype.
type fbType struct {
FbType int32
FbHeight int32
FbWidth int32
FbDepth int32
FbCMSize int32
FbSize int32
}
// FBCapturer reads pixels from FreeBSD's vt(4) framebuffer device. The
// vt(4) console exposes the active framebuffer via ttyv0 with FBIOGTYPE
// for geometry and mmap for backing memory. Pixel layout is assumed to
// be 32bpp BGRA (the common case for KMS-backed vt); fbtype doesn't
// expose channel offsets, so we don't try to handle exotic layouts here.
type FBCapturer struct {
mu sync.Mutex
path string
fd int
mmap []byte
w, h int
bpp int
stride int
closeOnce sync.Once
}
// NewFBCapturer opens the given vt(4) device and queries its geometry.
func NewFBCapturer(path string) (*FBCapturer, error) {
if path == "" {
path = defaultFBPath()
}
fd, err := unix.Open(path, unix.O_RDWR, 0)
if err != nil {
return nil, fmt.Errorf("open %s: %w", path, err)
}
var fbt fbType
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGType, uintptr(unsafe.Pointer(&fbt))); e != 0 {
unix.Close(fd)
return nil, fmt.Errorf("FBIOGTYPE: %v", e)
}
if fbt.FbDepth != 16 && fbt.FbDepth != 24 && fbt.FbDepth != 32 {
unix.Close(fd)
return nil, fmt.Errorf("unsupported framebuffer depth: %d", fbt.FbDepth)
}
if fbt.FbWidth <= 0 || fbt.FbHeight <= 0 || fbt.FbSize <= 0 {
unix.Close(fd)
return nil, fmt.Errorf("invalid framebuffer geometry: %dx%d size=%d", fbt.FbWidth, fbt.FbHeight, fbt.FbSize)
}
mm, err := unix.Mmap(fd, 0, int(fbt.FbSize), unix.PROT_READ, unix.MAP_SHARED)
if err != nil {
unix.Close(fd)
return nil, fmt.Errorf("mmap %s: %w (vt may not support mmap on this driver, e.g. virtio_gpu)", path, err)
}
bpp := int(fbt.FbDepth)
stride := int(fbt.FbWidth) * (bpp / 8)
c := &FBCapturer{
path: path,
fd: fd, // valid fd >= 0; we use -1 as the closed sentinel
mmap: mm,
w: int(fbt.FbWidth),
h: int(fbt.FbHeight),
bpp: bpp,
stride: stride,
}
log.Infof("framebuffer capturer ready: %s %dx%d bpp=%d (freebsd vt)", path, c.w, c.h, c.bpp)
return c, nil
}
// Width returns the framebuffer width.
func (c *FBCapturer) Width() int { return c.w }
// Height returns the framebuffer height.
func (c *FBCapturer) Height() int { return c.h }
// Capture allocates a fresh image and fills it with the current
// framebuffer contents.
func (c *FBCapturer) Capture() (*image.RGBA, error) {
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
if err := c.CaptureInto(img); err != nil {
return nil, err
}
return img, nil
}
// CaptureInto reads the framebuffer directly into dst.Pix. Assumes BGRA
// for 32bpp; the FreeBSD fbtype struct doesn't expose channel offsets.
func (c *FBCapturer) CaptureInto(dst *image.RGBA) error {
c.mu.Lock()
defer c.mu.Unlock()
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
return fmt.Errorf("dst size mismatch: dst=%dx%d fb=%dx%d",
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
}
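// Guard against a closed capturer (nil mapping) or a driver whose
// reported FbSize is smaller than height*stride; slicing would panic.
if len(c.mmap) < c.h*c.stride {
return fmt.Errorf("framebuffer mapping too small or closed: have %d bytes, need %d", len(c.mmap), c.h*c.stride)
}
switch c.bpp {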
case 32:
// vt(4) on KMS framebuffers is BGRA: byte 0=B, 1=G, 2=R.
swizzleBGRAtoRGBA(dst.Pix, c.mmap[:c.h*c.stride])
case 24:
swizzleFB24(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
case 16:
swizzleFB16RGB565(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
}
return nil
}
// Close releases the framebuffer mmap and file descriptor. Serialized with
// CaptureInto via c.mu so an in-flight capture can't read freed memory.
func (c *FBCapturer) Close() {
c.closeOnce.Do(func() {
c.mu.Lock()
defer c.mu.Unlock()
if c.mmap != nil {
_ = unix.Munmap(c.mmap)
c.mmap = nil
}
if c.fd >= 0 {
_ = unix.Close(c.fd)
c.fd = -1
}
})
}

@@ -0,0 +1,229 @@
//go:build linux && !android
package server
import (
"encoding/binary"
"fmt"
"image"
"sync"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// Linux framebuffer ioctls (linux/fb.h).
const (
fbioGetVScreenInfo = 0x4600
fbioGetFScreenInfo = 0x4602
)
func defaultFBPath() string { return "/dev/fb0" }
// fbVarScreenInfo mirrors the kernel's fb_var_screeninfo field for field so
// the ioctl writes into a correctly sized buffer. We only read the geometry
// and channel layout; the trailing reserved words are absorbed into _pad.
type fbVarScreenInfo struct {
Xres, Yres uint32
XresVirtual, YresVirtual uint32
XOffset, YOffset uint32
BitsPerPixel uint32
Grayscale uint32
RedOffset, RedLen, RedMSBR uint32
GreenOffset, GreenLen, GreenMSBR uint32
BlueOffset, BlueLen, BlueMSBR uint32
TranspOffset, TranspLen, TranspM uint32
NonStd uint32
Activate uint32
Height, Width uint32
AccelFlags uint32
PixClock uint32
LeftMargin, RightMargin uint32
UpperMargin, LowerMargin uint32
HsyncLen, VsyncLen uint32
Sync uint32
Vmode uint32
Rotate uint32
Colorspace uint32
_pad [4]uint32
}
// fbFixScreenInfo mirrors fb_fix_screeninfo. We only need LineLength.
type fbFixScreenInfo struct {
IDStr [16]byte
SmemStart uint64
SmemLen uint32
Type uint32
TypeAux uint32
Visual uint32
XPanStep uint16
YPanStep uint16
YWrapStep uint16
_pad0 uint16
LineLength uint32
MmioStart uint64
MmioLen uint32
Accel uint32
Capabilities uint16
_reserved [2]uint16
}
// FBCapturer reads pixels straight from the Linux framebuffer device.
// Used as a fallback when X11 isn't available, e.g. on a headless box at
// the kernel console or the display manager's pre-login screen on machines
// without an Xorg server. The framebuffer must be mmap()-able under our
// process privileges (typically the netbird service runs as root).
type FBCapturer struct {
mu sync.Mutex
path string
fd int
mmap []byte
w, h int
bpp int
stride int
rOff uint32
gOff uint32
bOff uint32
rLen uint32
gLen uint32
bLen uint32
closeOnce sync.Once
}
// NewFBCapturer opens the given framebuffer device (/dev/fbN) and
// queries its current geometry + pixel format.
func NewFBCapturer(path string) (*FBCapturer, error) {
if path == "" {
path = defaultFBPath()
}
fd, err := unix.Open(path, unix.O_RDONLY, 0)
if err != nil {
return nil, fmt.Errorf("open %s: %w", path, err)
}
var vinfo fbVarScreenInfo
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGetVScreenInfo, uintptr(unsafe.Pointer(&vinfo))); e != 0 {
unix.Close(fd)
return nil, fmt.Errorf("FBIOGET_VSCREENINFO: %v", e)
}
var finfo fbFixScreenInfo
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGetFScreenInfo, uintptr(unsafe.Pointer(&finfo))); e != 0 {
unix.Close(fd)
return nil, fmt.Errorf("FBIOGET_FSCREENINFO: %v", e)
}
bpp := int(vinfo.BitsPerPixel)
if bpp != 16 && bpp != 24 && bpp != 32 {
unix.Close(fd)
return nil, fmt.Errorf("unsupported framebuffer bpp: %d", bpp)
}
size := int(finfo.LineLength) * int(vinfo.Yres)
if size <= 0 {
unix.Close(fd)
return nil, fmt.Errorf("invalid framebuffer dimensions: stride=%d h=%d", finfo.LineLength, vinfo.Yres)
}
mm, err := unix.Mmap(fd, 0, size, unix.PROT_READ, unix.MAP_SHARED)
if err != nil {
unix.Close(fd)
return nil, fmt.Errorf("mmap %s: %w", path, err)
}
c := &FBCapturer{
path: path,
fd: fd,
mmap: mm,
w: int(vinfo.Xres),
h: int(vinfo.Yres),
bpp: bpp,
stride: int(finfo.LineLength),
rOff: vinfo.RedOffset,
gOff: vinfo.GreenOffset,
bOff: vinfo.BlueOffset,
rLen: vinfo.RedLen,
gLen: vinfo.GreenLen,
bLen: vinfo.BlueLen,
}
log.Infof("framebuffer capturer ready: %s %dx%d bpp=%d r=%d/%d g=%d/%d b=%d/%d",
path, c.w, c.h, c.bpp, c.rOff, c.rLen, c.gOff, c.gLen, c.bOff, c.bLen)
return c, nil
}
// Width returns the framebuffer width in pixels.
func (c *FBCapturer) Width() int { return c.w }
// Height returns the framebuffer height in pixels.
func (c *FBCapturer) Height() int { return c.h }
// Capture allocates a fresh image and fills it with the current
// framebuffer contents.
func (c *FBCapturer) Capture() (*image.RGBA, error) {
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
if err := c.CaptureInto(img); err != nil {
return nil, err
}
return img, nil
}
// CaptureInto reads the framebuffer directly into dst.Pix.
func (c *FBCapturer) CaptureInto(dst *image.RGBA) error {
c.mu.Lock()
defer c.mu.Unlock()
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
return fmt.Errorf("dst size mismatch: dst=%dx%d fb=%dx%d",
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
}
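// Guard against a closed capturer (nil mapping); the swizzle helpers
// would otherwise panic when slicing the source rows.
if len(c.mmap) < c.h*c.stride {
return fmt.Errorf("framebuffer mapping too small or closed: have %d bytes, need %d", len(c.mmap), c.h*c.stride)
}
switch c.bpp {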
case 32:
swizzleFB32(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h, channelShifts{R: c.rOff, G: c.gOff, B: c.bOff})
case 24:
swizzleFB24(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
case 16:
swizzleFB16RGB565(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
}
return nil
}
// Close releases the framebuffer mmap and file descriptor. Serialized with
// CaptureInto via c.mu so an in-flight capture can't read freed memory.
func (c *FBCapturer) Close() {
c.closeOnce.Do(func() {
c.mu.Lock()
defer c.mu.Unlock()
if c.mmap != nil {
_ = unix.Munmap(c.mmap)
c.mmap = nil
}
if c.fd >= 0 {
_ = unix.Close(c.fd)
c.fd = -1
}
})
}
// channelShifts groups the bit offsets for the R/G/B channels in a packed
// uint32 framebuffer pixel. Bundling avoids drowning per-row callers in a
// 9-parameter signature.
type channelShifts struct {
R, G, B uint32
}
// swizzleFB32 handles 32-bit framebuffers with arbitrary R/G/B channel
// offsets. Pulls one pixel per uint32, then masks each channel into the
// destination RGBA byte order.
func swizzleFB32(dst []byte, dstStride int, src []byte, srcStride, w, h int, shifts channelShifts) {
for y := 0; y < h; y++ {
srcRow := src[y*srcStride : y*srcStride+w*4]
dstRow := dst[y*dstStride:]
for x := 0; x < w; x++ {
pix := binary.LittleEndian.Uint32(srcRow[x*4 : x*4+4])
dstRow[x*4+0] = byte(pix >> shifts.R)
dstRow[x*4+1] = byte(pix >> shifts.G)
dstRow[x*4+2] = byte(pix >> shifts.B)
dstRow[x*4+3] = 0xff
}
}
}
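
Review note: a single-pixel reduction of swizzleFB32 shows how the channel offsets from FBIOGET_VSCREENINFO select bytes. The names shifts and toRGBA are hypothetical. The offsets 16/8/0 are an assumed XRGB8888 layout, not values read from a real device.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// shifts holds the bit offsets of R, G and B inside a packed 32-bit pixel.
// Hypothetical stand-in for channelShifts.
type shifts struct{ R, G, B uint32 }

// toRGBA converts one little-endian 32-bit framebuffer pixel to RGBA bytes.
// Converting to byte keeps the low 8 bits after the shift, which is the
// channel value when each channel is 8 bits wide.
func toRGBA(src []byte, s shifts) [4]byte {
	pix := binary.LittleEndian.Uint32(src)
	return [4]byte{byte(pix >> s.R), byte(pix >> s.G), byte(pix >> s.B), 0xff}
}

func main() {
	// Memory order B, G, R, X: the packed value is 0x00332211.
	px := []byte{0x11, 0x22, 0x33, 0x00}

	// Assumed XRGB8888 layout: red at bit 16, green at 8, blue at 0.
	fmt.Println(toRGBA(px, shifts{R: 16, G: 8, B: 0}))

	// Same bytes with red and blue offsets swapped.
	fmt.Println(toRGBA(px, shifts{R: 0, G: 8, B: 16}))
}
```

The byte truncation is only correct for 8-bit channels. A layout with narrower channels would need an explicit mask built from the RedLen/GreenLen/BlueLen fields.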

@@ -0,0 +1,149 @@
//go:build unix && !darwin && !ios && !android
package server
import (
"image"
"sync"
)
// FBPoller wraps FBCapturer with the same lifecycle (ClientConnect /
// ClientDisconnect, lazy init) as X11Poller, so it slots into the same
// session plumbing without code changes upstream. The concrete
// FBCapturer is platform-specific (capture_fb_linux.go / _freebsd.go);
// this file owns the cross-platform glue.
type FBPoller struct {
mu sync.Mutex
path string
capturer *FBCapturer
w, h int
clients int32
}
// NewFBPoller returns a poller that opens path on first use. Empty path
// defaults to /dev/fb0 on Linux and /dev/ttyv0 on FreeBSD.
func NewFBPoller(path string) *FBPoller {
if path == "" {
path = defaultFBPath()
}
return &FBPoller{path: path}
}
// ClientConnect eagerly initialises the capturer on first connect.
func (p *FBPoller) ClientConnect() {
p.mu.Lock()
defer p.mu.Unlock()
p.clients++
if p.clients == 1 {
_ = p.ensureCapturerLocked()
}
}
// ClientDisconnect closes the capturer when the last client leaves.
func (p *FBPoller) ClientDisconnect() {
p.mu.Lock()
defer p.mu.Unlock()
p.clients--
if p.clients <= 0 && p.capturer != nil {
p.capturer.Close()
p.capturer = nil
}
}
// Width returns the framebuffer width, doing lazy init if needed.
func (p *FBPoller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
_ = p.ensureCapturerLocked()
return p.w
}
// Height returns the framebuffer height, doing lazy init if needed.
func (p *FBPoller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
_ = p.ensureCapturerLocked()
return p.h
}
// Capture takes a fresh frame.
func (p *FBPoller) Capture() (*image.RGBA, error) {
p.mu.Lock()
defer p.mu.Unlock()
if err := p.ensureCapturerLocked(); err != nil {
return nil, err
}
return p.capturer.Capture()
}
// CaptureInto fills dst directly.
func (p *FBPoller) CaptureInto(dst *image.RGBA) error {
p.mu.Lock()
defer p.mu.Unlock()
if err := p.ensureCapturerLocked(); err != nil {
return err
}
return p.capturer.CaptureInto(dst)
}
// Close releases all framebuffer resources.
func (p *FBPoller) Close() {
p.mu.Lock()
defer p.mu.Unlock()
if p.capturer != nil {
p.capturer.Close()
p.capturer = nil
}
}
func (p *FBPoller) ensureCapturerLocked() error {
if p.capturer != nil {
return nil
}
c, err := NewFBCapturer(p.path)
if err != nil {
return err
}
p.capturer = c
p.w, p.h = c.Width(), c.Height()
return nil
}
var _ ScreenCapturer = (*FBPoller)(nil)
var _ captureIntoer = (*FBPoller)(nil)
// swizzleFB24 handles 24-bit packed framebuffers (B,G,R triplets).
// Shared between Linux and FreeBSD framebuffer paths.
func swizzleFB24(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
for y := 0; y < h; y++ {
srcRow := src[y*srcStride : y*srcStride+w*3]
dstRow := dst[y*dstStride:]
for x := 0; x < w; x++ {
b := srcRow[x*3+0]
g := srcRow[x*3+1]
r := srcRow[x*3+2]
dstRow[x*4+0] = r
dstRow[x*4+1] = g
dstRow[x*4+2] = b
dstRow[x*4+3] = 0xff
}
}
}
// swizzleFB16RGB565 handles 16bpp RGB 565 framebuffers.
func swizzleFB16RGB565(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
for y := 0; y < h; y++ {
srcRow := src[y*srcStride : y*srcStride+w*2]
dstRow := dst[y*dstStride:]
for x := 0; x < w; x++ {
pix := uint16(srcRow[x*2]) | uint16(srcRow[x*2+1])<<8
r := byte((pix >> 11) & 0x1f)
g := byte((pix >> 5) & 0x3f)
b := byte(pix & 0x1f)
dstRow[x*4+0] = (r << 3) | (r >> 2)
dstRow[x*4+1] = (g << 2) | (g >> 4)
dstRow[x*4+2] = (b << 3) | (b >> 2)
dstRow[x*4+3] = 0xff
}
}
}
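
Review note: the bit replication in swizzleFB16RGB565 is easiest to check on single pixels. The helper expand565 is a hypothetical per-pixel extraction of the same arithmetic.

```go
package main

import "fmt"

// expand565 converts one RGB565 pixel to 8-bit channels. The top bits of
// each channel are replicated into the low bits so that the maximum 5-bit
// or 6-bit value maps to 255 rather than 248 or 252.
func expand565(pix uint16) (r, g, b byte) {
	r5 := byte((pix >> 11) & 0x1f)
	g6 := byte((pix >> 5) & 0x3f)
	b5 := byte(pix & 0x1f)
	return (r5 << 3) | (r5 >> 2), (g6 << 2) | (g6 >> 4), (b5 << 3) | (b5 >> 2)
}

func main() {
	for _, p := range []uint16{0x0000, 0xF800, 0x07E0, 0x001F, 0xFFFF, 0x8410} {
		r, g, b := expand565(p)
		fmt.Printf("%#04x -> r=%d g=%d b=%d\n", p, r, g, b)
	}
}
```

For example 0xF800 is pure red: the 5-bit value 31 becomes (31<<3)|(31>>2) = 248|7 = 255. A plain left shift would leave full-intensity channels slightly dark.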

@@ -0,0 +1,586 @@
//go:build windows
package server
import (
"fmt"
"image"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
gdi32 = windows.NewLazySystemDLL("gdi32.dll")
user32 = windows.NewLazySystemDLL("user32.dll")
procGetDC = user32.NewProc("GetDC")
procReleaseDC = user32.NewProc("ReleaseDC")
procCreateCompatDC = gdi32.NewProc("CreateCompatibleDC")
procCreateDIBSection = gdi32.NewProc("CreateDIBSection")
procSelectObject = gdi32.NewProc("SelectObject")
procDeleteObject = gdi32.NewProc("DeleteObject")
procDeleteDC = gdi32.NewProc("DeleteDC")
procBitBlt = gdi32.NewProc("BitBlt")
procGetSystemMetrics = user32.NewProc("GetSystemMetrics")
// Desktop switching for service/Session 0 capture.
procOpenInputDesktop = user32.NewProc("OpenInputDesktop")
procSetThreadDesktop = user32.NewProc("SetThreadDesktop")
procCloseDesktop = user32.NewProc("CloseDesktop")
procOpenWindowStation = user32.NewProc("OpenWindowStationW")
procSetProcessWindowStation = user32.NewProc("SetProcessWindowStation")
procCloseWindowStation = user32.NewProc("CloseWindowStation")
procGetUserObjectInformationW = user32.NewProc("GetUserObjectInformationW")
)
const uoiName = 2
const (
smCxScreen = 0
smCyScreen = 1
srccopy = 0x00CC0020
captureBlt = 0x40000000
dibRgbColors = 0
)
type bitmapInfoHeader struct {
Size uint32
Width int32
Height int32
Planes uint16
BitCount uint16
Compression uint32
SizeImage uint32
XPelsPerMeter int32
YPelsPerMeter int32
ClrUsed uint32
ClrImportant uint32
}
type bitmapInfo struct {
Header bitmapInfoHeader
}
// setupInteractiveWindowStation associates the current process with WinSta0,
// the interactive window station. This is required for a SYSTEM service in
// Session 0 to call OpenInputDesktop for screen capture and input injection.
func setupInteractiveWindowStation() error {
name, err := windows.UTF16PtrFromString("WinSta0")
if err != nil {
return fmt.Errorf("UTF16 WinSta0: %w", err)
}
hWinSta, _, err := procOpenWindowStation.Call(
uintptr(unsafe.Pointer(name)),
0,
uintptr(windows.MAXIMUM_ALLOWED),
)
if hWinSta == 0 {
return fmt.Errorf("OpenWindowStation(WinSta0): %w", err)
}
r, _, err := procSetProcessWindowStation.Call(hWinSta)
if r == 0 {
_, _, _ = procCloseWindowStation.Call(hWinSta)
return fmt.Errorf("SetProcessWindowStation: %w", err)
}
log.Info("process window station set to WinSta0 (interactive)")
return nil
}
func screenSize() (int, int) {
w, _, _ := procGetSystemMetrics.Call(uintptr(smCxScreen))
h, _, _ := procGetSystemMetrics.Call(uintptr(smCyScreen))
return int(w), int(h)
}
func getDesktopName(hDesk uintptr) string {
var buf [256]uint16
var needed uint32
_, _, _ = procGetUserObjectInformationW.Call(hDesk, uoiName,
uintptr(unsafe.Pointer(&buf[0])), 512,
uintptr(unsafe.Pointer(&needed)))
return windows.UTF16ToString(buf[:])
}
// switchToInputDesktop opens the desktop currently receiving user input
// and sets it as the calling OS thread's desktop. Must be called from a
// goroutine locked to its OS thread via runtime.LockOSThread().
func switchToInputDesktop() (bool, string) {
hDesk, _, _ := procOpenInputDesktop.Call(0, 0, uintptr(windows.MAXIMUM_ALLOWED))
if hDesk == 0 {
return false, ""
}
name := getDesktopName(hDesk)
ret, _, _ := procSetThreadDesktop.Call(hDesk)
_, _, _ = procCloseDesktop.Call(hDesk)
return ret != 0, name
}
// gdiCapturer captures the desktop screen using GDI BitBlt.
// GDI objects (DC, DIBSection) are allocated once and reused across frames.
type gdiCapturer struct {
mu sync.Mutex
width int
height int
// Pre-allocated GDI resources, reused across captures.
memDC uintptr
bmp uintptr
bits uintptr
}
func newGDICapturer() (*gdiCapturer, error) {
w, h := screenSize()
if w == 0 || h == 0 {
return nil, fmt.Errorf("screen dimensions are zero")
}
c := &gdiCapturer{width: w, height: h}
if err := c.allocGDI(); err != nil {
return nil, err
}
return c, nil
}
// allocGDI pre-allocates the compatible DC and DIB section for reuse.
func (c *gdiCapturer) allocGDI() error {
screenDC, _, _ := procGetDC.Call(0)
if screenDC == 0 {
return fmt.Errorf("GetDC returned 0")
}
defer func() { _, _, _ = procReleaseDC.Call(0, screenDC) }()
memDC, _, _ := procCreateCompatDC.Call(screenDC)
if memDC == 0 {
return fmt.Errorf("CreateCompatibleDC returned 0")
}
bi := bitmapInfo{
Header: bitmapInfoHeader{
Size: uint32(unsafe.Sizeof(bitmapInfoHeader{})),
Width: int32(c.width),
Height: -int32(c.height), // negative = top-down DIB
Planes: 1,
BitCount: 32,
},
}
var bits uintptr
bmp, _, _ := procCreateDIBSection.Call(
screenDC,
uintptr(unsafe.Pointer(&bi)),
dibRgbColors,
uintptr(unsafe.Pointer(&bits)),
0, 0,
)
if bmp == 0 || bits == 0 {
_, _, _ = procDeleteDC.Call(memDC)
return fmt.Errorf("CreateDIBSection returned 0")
}
_, _, _ = procSelectObject.Call(memDC, bmp)
c.memDC = memDC
c.bmp = bmp
c.bits = bits
return nil
}
func (c *gdiCapturer) close() { c.freeGDI() }
// freeGDI releases pre-allocated GDI resources.
func (c *gdiCapturer) freeGDI() {
if c.bmp != 0 {
_, _, _ = procDeleteObject.Call(c.bmp)
c.bmp = 0
}
if c.memDC != 0 {
_, _, _ = procDeleteDC.Call(c.memDC)
c.memDC = 0
}
c.bits = 0
}
func (c *gdiCapturer) capture() (*image.RGBA, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.memDC == 0 {
return nil, fmt.Errorf("GDI resources not allocated")
}
screenDC, _, _ := procGetDC.Call(0)
if screenDC == 0 {
return nil, fmt.Errorf("GetDC returned 0")
}
defer func() { _, _, _ = procReleaseDC.Call(0, screenDC) }()
// SRCCOPY|CAPTUREBLT: CAPTUREBLT forces inclusion of layered/topmost
// windows in the capture and is required for GDI BitBlt to return live
// pixels when the session is rendered through RDP / DWM-composited
// surfaces. Without it BitBlt reads the backing-store DIB which is
// often empty (all-black) on RDP and headless sessions.
ret, _, _ := procBitBlt.Call(c.memDC, 0, 0, uintptr(c.width), uintptr(c.height),
screenDC, 0, 0, srccopy|captureBlt)
if ret == 0 {
return nil, fmt.Errorf("BitBlt returned 0")
}
n := c.width * c.height * 4
raw := unsafe.Slice((*byte)(unsafe.Pointer(c.bits)), n)
// GDI gives BGRA, the RFB encoder expects RGBA (img.Pix layout).
// Swap R and B in bulk using uint32 operations (one load + mask + shift
// per pixel instead of three separate byte assignments).
img := image.NewRGBA(image.Rect(0, 0, c.width, c.height))
swizzleBGRAtoRGBA(img.Pix, raw)
return img, nil
}
// DesktopCapturer captures the interactive desktop, handling desktop transitions
// (login screen, UAC prompts). Frames are captured on demand by a dedicated
// OS-locked goroutine (required because DXGI's D3D11 device context is not
// thread-safe). Sessions drive timing by calling Capture(); a short staleness
// cache coalesces concurrent requests. Capture pauses automatically when no
// clients are connected.
type DesktopCapturer struct {
mu sync.Mutex
w, h int
// lastFrame/lastAt implement a small staleness cache so multiple
// near-simultaneous Capture calls share one DXGI round-trip.
lastFrame *image.RGBA
lastAt time.Time
// clients tracks the number of active VNC sessions. When zero, the
// worker goroutine releases the underlying capturer.
clients atomic.Int32
// reqCh carries capture requests from sessions to the OS-locked worker.
reqCh chan captureReq
// wake is signaled when a client connects and the worker should resume.
wake chan struct{}
// done is closed when Close is called, terminating the worker.
done chan struct{}
// cursorState holds the latest cursor sprite sampled by the worker.
// The worker calls GetCursorInfo every capture and decodes a new
// sprite only when the HCURSOR changes.
cursorState cursorState
}
// captureReq is a single capture request awaiting a reply. Reply channel is
// buffered to size 1 so the worker never blocks on a sender that's gone.
type captureReq struct {
reply chan captureReply
}
type captureReply struct {
img *image.RGBA
err error
}
// NewDesktopCapturer creates an on-demand capturer for the active desktop.
func NewDesktopCapturer() *DesktopCapturer {
c := &DesktopCapturer{
wake: make(chan struct{}, 1),
done: make(chan struct{}),
reqCh: make(chan captureReq),
}
go c.worker()
return c
}
// ClientConnect increments the active client count, resuming capture if needed.
func (c *DesktopCapturer) ClientConnect() {
c.clients.Add(1)
select {
case c.wake <- struct{}{}:
default:
}
}
// ClientDisconnect decrements the active client count.
func (c *DesktopCapturer) ClientDisconnect() {
c.clients.Add(-1)
}
// Close stops the capture loop and releases resources.
func (c *DesktopCapturer) Close() {
select {
case <-c.done:
default:
close(c.done)
}
}
// Width returns the current screen width, triggering a capture if the
// worker hasn't initialised yet. validateCapturer depends on Width/Height
// becoming non-zero promptly after ClientConnect so it doesn't reject
// brand-new sessions.
func (c *DesktopCapturer) Width() int {
c.mu.Lock()
w := c.w
c.mu.Unlock()
if w == 0 && c.clients.Load() > 0 {
_, _ = c.Capture()
c.mu.Lock()
w = c.w
c.mu.Unlock()
}
return w
}
// Height returns the current screen height, triggering a capture if the
// worker hasn't initialised yet (see Width). Returns 0 while no client is
// connected so callers don't deadlock against a parked worker.
func (c *DesktopCapturer) Height() int {
c.mu.Lock()
h := c.h
c.mu.Unlock()
if h == 0 && c.clients.Load() > 0 {
_, _ = c.Capture()
c.mu.Lock()
h = c.h
c.mu.Unlock()
}
return h
}
// Capture returns a freshly captured frame, serving from a short staleness
// cache when multiple sessions ask within freshWindow of each other. All
// real DXGI/GDI work happens on the OS-locked worker goroutine.
func (c *DesktopCapturer) Capture() (*image.RGBA, error) {
c.mu.Lock()
if c.lastFrame != nil && time.Since(c.lastAt) < freshWindow {
img := c.lastFrame
c.mu.Unlock()
return img, nil
}
c.mu.Unlock()
reply := make(chan captureReply, 1)
select {
case c.reqCh <- captureReq{reply: reply}:
case <-c.done:
return nil, fmt.Errorf("capturer closed")
}
select {
case r := <-reply:
if r.err != nil {
return nil, r.err
}
c.mu.Lock()
c.lastFrame = r.img
c.lastAt = time.Now()
c.mu.Unlock()
return r.img, nil
case <-c.done:
return nil, fmt.Errorf("capturer closed")
}
}
// waitForClient blocks until a client connects or the capturer is closed.
func (c *DesktopCapturer) waitForClient() bool {
if c.clients.Load() > 0 {
return true
}
select {
case <-c.wake:
return true
case <-c.done:
return false
}
}
// worker owns DXGI/GDI state on its OS-locked thread and services capture
// requests from sessions. No background ticker: a capture happens only when
// a session asks for one (throttled by Capture()'s staleness cache).
func (c *DesktopCapturer) worker() {
runtime.LockOSThread()
// When running as a Windows service (Session 0), we need to attach to the
// interactive window station before OpenInputDesktop will succeed.
if err := setupInteractiveWindowStation(); err != nil {
log.Warnf("attach to interactive window station: %v", err)
}
w := &captureWorker{c: c}
defer w.closeCapturer()
for {
if !c.waitForClient() {
return
}
// Drop the capturer when all clients have disconnected so we don't
// hold the DXGI duplication or GDI DC on an idle peer.
if c.clients.Load() <= 0 {
w.closeCapturer()
continue
}
if !w.handleNextRequest() {
return
}
}
}
// frameCapturer is the per-backend interface used by the worker. DXGI and
// GDI implementations both satisfy it.
type frameCapturer interface {
capture() (*image.RGBA, error)
close()
}
// captureWorker owns the worker goroutine's mutable state. Extracted into a
// struct so the request/desktop/init logic can live on small methods and the
// outer worker() stays a thin loop.
type captureWorker struct {
c *DesktopCapturer
cap frameCapturer
desktopFails int
lastDesktop string
nextInitRetry time.Time
cursor cursorSampler
// lastBackend records the last capturer kind that came out of
// createCapturer ("dxgi" or "gdi"); used to demote repeat "using X"
// and DXGI-unavailable logs to debug when nothing changed.
lastBackend string
// lastDXGIErr is the textual DXGI failure printed in the most recent
// fallback warning; suppresses repeat warns when DXGI keeps failing
// the same way across desktop changes (login -> lock -> login).
lastDXGIErr string
}
// handleNextRequest waits for either shutdown or a capture request and runs
// the request through prepCapturer/capture. Returns false when the worker
// should exit.
func (w *captureWorker) handleNextRequest() bool {
select {
case <-w.c.done:
return false
case req := <-w.c.reqCh:
w.serveRequest(req)
return true
}
}
func (w *captureWorker) serveRequest(req captureReq) {
fc, err := w.prepCapturer()
if err != nil {
req.reply <- captureReply{err: err}
return
}
img, err := fc.capture()
if err != nil {
log.Debugf("capture: %v", err)
w.closeCapturer()
w.nextInitRetry = time.Now().Add(100 * time.Millisecond)
req.reply <- captureReply{err: err}
return
}
if snap, err := w.cursor.sample(); err != nil {
w.c.cursorState.store(&cursorSnapshot{err: err})
} else {
w.c.cursorState.store(snap)
}
req.reply <- captureReply{img: img}
}
// prepCapturer switches to the input desktop, handles desktop-change
// teardown, and creates the underlying capturer on demand. Backoff state is
// tracked across calls via w.nextInitRetry.
func (w *captureWorker) prepCapturer() (frameCapturer, error) {
if err := w.refreshDesktop(); err != nil {
return nil, err
}
if w.cap != nil {
return w.cap, nil
}
if time.Now().Before(w.nextInitRetry) {
return nil, fmt.Errorf("capturer init backing off")
}
fc, err := w.createCapturer()
if err != nil {
w.nextInitRetry = time.Now().Add(500 * time.Millisecond)
return nil, err
}
w.cap = fc
sw, sh := screenSize()
w.c.mu.Lock()
sizeChanged := w.c.w != sw || w.c.h != sh
w.c.w, w.c.h = sw, sh
w.c.mu.Unlock()
if sizeChanged {
log.Infof("screen capturer ready: %dx%d", sw, sh)
} else {
log.Debugf("screen capturer ready: %dx%d", sw, sh)
}
return w.cap, nil
}
// refreshDesktop tracks the active input desktop. When it changes (lock
// screen, fast-user-switch) the existing capturer is dropped so the next
// call rebuilds one against the new desktop.
func (w *captureWorker) refreshDesktop() error {
ok, desk := switchToInputDesktop()
if !ok {
w.desktopFails++
if w.desktopFails == 1 || w.desktopFails%100 == 0 {
log.Warnf("switchToInputDesktop failed (count=%d), no interactive desktop session?", w.desktopFails)
}
return fmt.Errorf("no interactive desktop")
}
if w.desktopFails > 0 {
log.Infof("switchToInputDesktop recovered after %d failures, desktop=%q", w.desktopFails, desk)
w.desktopFails = 0
}
if desk != w.lastDesktop {
log.Infof("desktop changed: %q -> %q", w.lastDesktop, desk)
w.lastDesktop = desk
w.closeCapturer()
}
return nil
}
func (w *captureWorker) createCapturer() (frameCapturer, error) {
dc, err := newDXGICapturer()
if err == nil {
if w.lastBackend != "dxgi" {
log.Info("using DXGI Desktop Duplication for capture")
} else {
log.Debug("using DXGI Desktop Duplication for capture")
}
w.lastBackend = "dxgi"
w.lastDXGIErr = ""
return dc, nil
}
errStr := err.Error()
if errStr != w.lastDXGIErr {
log.Warnf("DXGI Desktop Duplication unavailable, falling back to slower GDI BitBlt: %v", err)
w.lastDXGIErr = errStr
} else {
log.Debugf("DXGI Desktop Duplication still unavailable, falling back to slower GDI BitBlt: %v", err)
}
gc, err := newGDICapturer()
if err != nil {
return nil, err
}
if w.lastBackend != "gdi" {
log.Info("using GDI BitBlt for capture")
} else {
log.Debug("using GDI BitBlt for capture")
}
w.lastBackend = "gdi"
return gc, nil
}
func (w *captureWorker) closeCapturer() {
if w.cap != nil {
w.cap.close()
w.cap = nil
}
}
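
Review note: prepCapturer above gates re-initialisation on w.nextInitRetry, pushing the deadline out by 500 ms after a failed createCapturer and by 100 ms after a failed capture. The sketch below isolates that time-gated retry pattern. The retryGate type, its method names, and demo are hypothetical and do not appear in this change; the clock is passed in explicitly only to keep the example deterministic.

```go
package main

import (
	"fmt"
	"time"
)

// retryGate is a hypothetical stand-in for the nextInitRetry field: it
// remembers the earliest time at which another init attempt is allowed.
type retryGate struct {
	next time.Time
}

// ready reports whether an attempt may run at time now. The zero value is
// always ready, matching a worker that has never failed.
func (g *retryGate) ready(now time.Time) bool {
	return !now.Before(g.next)
}

// failed records a failure at time now and blocks attempts for delay.
func (g *retryGate) failed(now time.Time, delay time.Duration) {
	g.next = now.Add(delay)
}

// demo walks one failure through the gate and returns what ready reported
// before the failure, during the backoff, and once the backoff has elapsed.
func demo() [3]bool {
	var g retryGate
	t0 := time.Unix(1000, 0)
	before := g.ready(t0)
	g.failed(t0, 500*time.Millisecond)
	during := g.ready(t0.Add(100 * time.Millisecond))
	after := g.ready(t0.Add(500 * time.Millisecond))
	return [3]bool{before, during, after}
}

func main() {
	fmt.Println(demo())
}
```

Keeping the deadline as a plain time.Time, as the worker does, means no timer goroutine is needed: the gate is only consulted when a request actually arrives.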


@@ -0,0 +1,533 @@
//go:build unix && !darwin && !ios && !android
package server
import (
"fmt"
"image"
"os"
"os/exec"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"github.com/jezek/xgb"
"github.com/jezek/xgb/xproto"
)
const (
// x11SocketDir is the well-known directory where X servers create
// their filesystem UNIX-domain sockets, named "X<display>". Used both
// for auto-detecting an existing display and for placing/probing
// sockets of virtual sessions we spawn.
x11SocketDir = "/tmp/.X11-unix"
// envDisplay is the X11 display selector environment variable.
envDisplay = "DISPLAY"
// envXAuthority points X clients at the cookie file used to
// authenticate against the running X server.
envXAuthority = "XAUTHORITY"
)
// X11Capturer captures the screen from an X11 display using the MIT-SHM extension.
type X11Capturer struct {
mu sync.Mutex
conn *xgb.Conn
screen *xproto.ScreenInfo
w, h int
shmID int
shmAddr []byte
shmSeg uint32
useSHM bool
// bufs double-buffers output images so the X11Poller's capture loop can
// overwrite one while the session is still encoding the other; a single
// reused buffer would race with the reader. Allocation happens on first
// use and on geometry change.
bufs [2]*image.RGBA
cur int
// cursor is the XFixes binding used to report the current sprite.
// Allocated lazily on the first Cursor call. cursorInitErr latches
// a permanent init failure so we stop retrying every frame.
cursor *xfixesCursor
cursorInitErr error
}
// detectX11Display finds the active X11 display and sets DISPLAY/XAUTHORITY
// environment variables if needed. This is required when running as a system
// service where these vars aren't set.
func detectX11Display() {
if os.Getenv(envDisplay) != "" {
return
}
// Try /proc first (Linux), then fall back to scanning the X11 socket
// directory, with ps supplying the auth file (FreeBSD and others).
if detectX11FromProc() {
return
}
detectX11FromSockets()
}
// detectX11FromProc scans /proc/*/cmdline for Xorg (Linux).
func detectX11FromProc() bool {
entries, err := os.ReadDir("/proc")
if err != nil {
return false
}
for _, e := range entries {
if !e.IsDir() {
continue
}
cmdline, err := os.ReadFile("/proc/" + e.Name() + "/cmdline")
if err != nil {
continue
}
if display, auth := parseXorgArgs(splitCmdline(cmdline)); display != "" {
setDisplayEnv(display, auth)
return true
}
}
return false
}
// detectX11FromSockets checks /tmp/.X11-unix/ for X sockets and uses ps
// to find the auth file. Works on FreeBSD and other systems without /proc.
func detectX11FromSockets() bool {
entries, err := os.ReadDir(x11SocketDir)
if err != nil {
return false
}
// Pick the lowest numeric display rather than the lexically first
// entry, so X10 doesn't win over X2.
minDisplay := -1
for _, e := range entries {
name := e.Name()
if len(name) < 2 || name[0] != 'X' {
continue
}
n, err := strconv.Atoi(name[1:])
if err != nil {
continue
}
if minDisplay < 0 || n < minDisplay {
minDisplay = n
}
}
if minDisplay < 0 {
return false
}
display := ":" + strconv.Itoa(minDisplay)
os.Setenv(envDisplay, display)
auth := findXorgAuthFromPS()
if auth != "" {
os.Setenv(envXAuthority, auth)
log.Infof("auto-detected DISPLAY=%s (from socket) XAUTHORITY=%s (from ps)", display, auth)
} else {
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
}
return true
}
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.
func findXorgAuthFromPS() string {
out, err := exec.Command("ps", "auxww").Output()
if err != nil {
return ""
}
for _, line := range strings.Split(string(out), "\n") {
if !strings.Contains(line, "Xorg") && !strings.Contains(line, "/X ") {
continue
}
fields := strings.Fields(line)
for i, f := range fields {
if f == "-auth" && i+1 < len(fields) {
return fields[i+1]
}
}
}
return ""
}
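// parseXorgArgs extracts the display (":N") and the -auth file from an X
// server command line. It returns empty strings when args[0] does not look
// like an X server binary.
func parseXorgArgs(args []string) (display, auth string) {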
if len(args) == 0 {
return "", ""
}
base := args[0]
if !(base == "Xorg" || base == "X" || len(base) > 0 && base[len(base)-1] == 'X' ||
strings.Contains(base, "/Xorg") || strings.Contains(base, "/X")) {
return "", ""
}
for i, arg := range args[1:] {
if len(arg) > 0 && arg[0] == ':' {
display = arg
}
if arg == "-auth" && i+2 < len(args) {
auth = args[i+2]
}
}
return display, auth
}
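// setDisplayEnv exports DISPLAY, and XAUTHORITY when auth is non-empty, and
// logs what was detected.
func setDisplayEnv(display, auth string) {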
os.Setenv(envDisplay, display)
if auth != "" {
os.Setenv(envXAuthority, auth)
log.Infof("auto-detected DISPLAY=%s XAUTHORITY=%s", display, auth)
return
}
log.Infof("auto-detected DISPLAY=%s", display)
}
// splitCmdline splits a NUL-separated /proc/<pid>/cmdline buffer into its
// non-empty arguments.
func splitCmdline(data []byte) []string {
var args []string
for _, b := range splitNull(data) {
if len(b) > 0 {
args = append(args, string(b))
}
}
return args
}
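// splitNull splits data on NUL bytes, keeping empty fields between
// consecutive separators and dropping only the empty tail after a final NUL.
func splitNull(data []byte) [][]byte {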
var parts [][]byte
start := 0
for i, b := range data {
if b == 0 {
parts = append(parts, data[start:i])
start = i + 1
}
}
if start < len(data) {
parts = append(parts, data[start:])
}
return parts
}
// NewX11Capturer connects to the X11 display and sets up shared memory capture.
func NewX11Capturer(display string) (*X11Capturer, error) {
if display == "" {
detectX11Display()
display = os.Getenv(envDisplay)
}
if display == "" {
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
}
conn, err := xgb.NewConnDisplay(display)
if err != nil {
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
}
setup := xproto.Setup(conn)
if len(setup.Roots) == 0 {
conn.Close()
return nil, fmt.Errorf("no X11 screens")
}
screen := setup.Roots[0]
c := &X11Capturer{
conn: conn,
screen: &screen,
w: int(screen.WidthInPixels),
h: int(screen.HeightInPixels),
}
if err := c.initSHM(); err != nil {
log.Debugf("X11 SHM not available, using slow GetImage: %v", err)
}
log.Infof("X11 capturer ready: %dx%d (display=%s, shm=%v)", c.w, c.h, display, c.useSHM)
return c, nil
}
// initSHM is implemented in capture_x11_shm_linux.go (requires SysV SHM).
// On FreeBSD, where that path is not built, a stub returns an error and
// the capturer falls back to GetImage.
// Width returns the screen width.
func (c *X11Capturer) Width() int { return c.w }
// Height returns the screen height.
func (c *X11Capturer) Height() int { return c.h }
// Capture returns the current screen as an RGBA image.
func (c *X11Capturer) Capture() (*image.RGBA, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.useSHM {
return c.captureSHM()
}
return c.captureGetImage()
}
// CaptureInto fills the caller's destination buffer in one pass. The
// source path (SHM or fallback GetImage) writes directly into dst.Pix
// instead of going through the X11Capturer's internal double-buffer,
// saving one full-frame memcpy per capture.
func (c *X11Capturer) CaptureInto(dst *image.RGBA) error {
c.mu.Lock()
defer c.mu.Unlock()
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
return fmt.Errorf("dst size mismatch: dst=%dx%d capturer=%dx%d",
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
}
if c.useSHM {
return c.captureSHMInto(dst)
}
return c.captureGetImageInto(dst)
}
func (c *X11Capturer) captureGetImageInto(dst *image.RGBA) error {
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
reply, err := cookie.Reply()
if err != nil {
return fmt.Errorf("GetImage: %w", err)
}
n := c.w * c.h * 4
if len(reply.Data) < n {
return fmt.Errorf("GetImage returned %d bytes, expected %d", len(reply.Data), n)
}
swizzleBGRAtoRGBA(dst.Pix, reply.Data)
return nil
}
// captureSHM is implemented in capture_x11_shm_linux.go.
func (c *X11Capturer) captureGetImage() (*image.RGBA, error) {
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
reply, err := cookie.Reply()
if err != nil {
return nil, fmt.Errorf("GetImage: %w", err)
}
data := reply.Data
n := c.w * c.h * 4
if len(data) < n {
return nil, fmt.Errorf("GetImage returned %d bytes, expected %d", len(data), n)
}
img := c.nextBuffer()
swizzleBGRAtoRGBA(img.Pix, data)
return img, nil
}
// nextBuffer returns the *image.RGBA the next capture should fill, advancing
// the double-buffer index. Reallocates on geometry change.
func (c *X11Capturer) nextBuffer() *image.RGBA {
c.cur ^= 1
b := c.bufs[c.cur]
if b == nil || b.Rect.Dx() != c.w || b.Rect.Dy() != c.h {
b = image.NewRGBA(image.Rect(0, 0, c.w, c.h))
c.bufs[c.cur] = b
}
return b
}
// Close releases X11 resources.
func (c *X11Capturer) Close() {
c.closeSHM()
c.conn.Close()
}
// closeSHM is implemented in capture_x11_shm_linux.go.
// X11Poller wraps X11Capturer with a staleness-cached on-demand Capture:
// sessions drive captures themselves through the encoder goroutine, so we
// don't need a background ticker. The last result is cached for a short
// window so concurrent sessions coalesce into one capture.
//
// The capturer is allocated lazily on first use and released when all
// clients disconnect, so an idle peer holds no X connection or SHM segment.
type X11Poller struct {
mu sync.Mutex
capturer *X11Capturer
w, h int
// closed at Close so callers can stop waiting on retry backoff.
done chan struct{}
// lastFrame/lastAt implement a small cache: multiple near-simultaneous
// Capture calls (multi-client, or input-coalesced) return the same
// frame instead of hammering the X server.
lastFrame *image.RGBA
lastAt time.Time
// initBackoffUntil throttles capturer re-init when the X server is
// unavailable or flapping.
initBackoffUntil time.Time
clients atomic.Int32
display string
}
// initRetryBackoff gates capturer re-init attempts after a failure so we
// don't spin on X server errors.
const initRetryBackoff = 2 * time.Second
// NewX11Poller creates a lazy on-demand capturer for the given X display.
func NewX11Poller(display string) *X11Poller {
return &X11Poller{
display: display,
done: make(chan struct{}),
}
}
// ClientConnect increments the active client count. The first client triggers
// eager capturer initialisation so that the first FBUpdateRequest doesn't
// pay the X11 connect + SHM attach latency.
func (p *X11Poller) ClientConnect() {
if p.clients.Add(1) == 1 {
p.mu.Lock()
_ = p.ensureCapturerLocked()
p.mu.Unlock()
}
}
// ClientDisconnect decrements the active client count. On the last
// disconnect we close the underlying capturer so idle peers cost nothing.
func (p *X11Poller) ClientDisconnect() {
if p.clients.Add(-1) == 0 {
p.mu.Lock()
if p.capturer != nil {
p.capturer.Close()
p.capturer = nil
p.lastFrame = nil
}
p.mu.Unlock()
}
}
// Close releases all resources. Subsequent Capture calls will fail.
func (p *X11Poller) Close() {
p.mu.Lock()
defer p.mu.Unlock()
select {
case <-p.done:
default:
close(p.done)
}
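if p.capturer != nil {
p.capturer.Close()
p.capturer = nil
}
// Drop the cached frame so a Capture right after Close fails instead of
// returning a stale image from the freshness cache.
p.lastFrame = nil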
}
// Width returns the screen width. Triggers lazy init if needed.
func (p *X11Poller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
_ = p.ensureCapturerLocked()
return p.w
}
// Height returns the screen height. Triggers lazy init if needed.
func (p *X11Poller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
_ = p.ensureCapturerLocked()
return p.h
}
// Cursor satisfies cursorSource by forwarding to the lazily-initialised
// X11Capturer. Asking for the cursor on an idle poller triggers the same
// lazy X11 connection setup as a capture would.
func (p *X11Poller) Cursor() (*image.RGBA, int, int, uint64, error) {
p.mu.Lock()
defer p.mu.Unlock()
if err := p.ensureCapturerLocked(); err != nil {
return nil, 0, 0, 0, err
}
return p.capturer.Cursor()
}
// CursorPos satisfies cursorPositionSource by forwarding to the X11Capturer.
func (p *X11Poller) CursorPos() (int, int, error) {
p.mu.Lock()
defer p.mu.Unlock()
if err := p.ensureCapturerLocked(); err != nil {
return 0, 0, err
}
return p.capturer.CursorPos()
}
// Capture returns a fresh frame, serving from the short-lived cache if a
// previous caller captured within freshWindow.
func (p *X11Poller) Capture() (*image.RGBA, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.lastFrame != nil && time.Since(p.lastAt) < freshWindow {
return p.lastFrame, nil
}
if err := p.ensureCapturerLocked(); err != nil {
return nil, err
}
img, err := p.capturer.Capture()
if err != nil {
// Drop the capturer so the next call re-inits; the X connection may
// have died (e.g. Xorg restart).
p.capturer.Close()
p.capturer = nil
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
return nil, fmt.Errorf("x11 capture: %w", err)
}
p.lastFrame = img
p.lastAt = time.Now()
return img, nil
}
// CaptureInto fills dst directly via the underlying capturer, bypassing
// the freshness cache. The session's prevFrame/curFrame swap means each
// session needs its own buffer anyway, so caching wouldn't help.
func (p *X11Poller) CaptureInto(dst *image.RGBA) error {
p.mu.Lock()
defer p.mu.Unlock()
if err := p.ensureCapturerLocked(); err != nil {
return err
}
if err := p.capturer.CaptureInto(dst); err != nil {
p.capturer.Close()
p.capturer = nil
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
return fmt.Errorf("x11 capture: %w", err)
}
return nil
}
// ensureCapturerLocked initialises the underlying X11Capturer if not
// already open. Caller must hold p.mu.
func (p *X11Poller) ensureCapturerLocked() error {
if p.capturer != nil {
return nil
}
select {
case <-p.done:
return fmt.Errorf("x11 capturer closed")
default:
}
if time.Now().Before(p.initBackoffUntil) {
return fmt.Errorf("x11 capturer unavailable (retry scheduled)")
}
c, err := NewX11Capturer(p.display)
if err != nil {
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
log.Debugf("X11 capturer: %v", err)
return err
}
p.capturer = c
p.w, p.h = c.Width(), c.Height()
return nil
}
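
Review note: X11Poller.Capture above serves a cached frame when the previous capture is younger than freshWindow, so near-simultaneous callers coalesce into one X server round trip. A minimal sketch of that staleness cache follows. The frameCache type and demo are hypothetical, the 33 ms window is an assumed value (freshWindow is not defined in the files shown here), and the clock is injected rather than read from time.Now so the behaviour is reproducible.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// frameCache is a hypothetical reduction of X11Poller's lastFrame/lastAt
// fields to a standalone type.
type frameCache struct {
	mu     sync.Mutex
	window time.Duration
	last   []byte
	lastAt time.Time
}

// get returns the cached frame if it is younger than window at time now;
// otherwise it calls capture and caches the result. Errors are not cached.
func (c *frameCache) get(now time.Time, capture func() ([]byte, error)) ([]byte, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.last != nil && now.Sub(c.lastAt) < c.window {
		return c.last, nil
	}
	frame, err := capture()
	if err != nil {
		return nil, err
	}
	c.last, c.lastAt = frame, now
	return frame, nil
}

// demo issues requests at 0 ms, 10 ms and 50 ms against the assumed 33 ms
// window and returns how many of them reached the capture function.
func demo() int {
	c := &frameCache{window: 33 * time.Millisecond}
	captures := 0
	capture := func() ([]byte, error) {
		captures++
		return []byte{byte(captures)}, nil
	}
	t0 := time.Unix(1000, 0)
	offsets := []time.Duration{0, 10 * time.Millisecond, 50 * time.Millisecond}
	for _, off := range offsets {
		if _, err := c.get(t0.Add(off), capture); err != nil {
			panic(err)
		}
	}
	return captures
}

func main() {
	fmt.Println("captures:", demo())
}
```

The request at 10 ms is answered from the cache, while the one at 50 ms falls outside the window and triggers a second capture. As in the poller, callers share the returned buffer, so they must treat it as read-only.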


@@ -0,0 +1,96 @@
//go:build linux && !android
package server
import (
"fmt"
"image"
"github.com/jezek/xgb/shm"
"github.com/jezek/xgb/xproto"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
func (c *X11Capturer) initSHM() error {
if err := shm.Init(c.conn); err != nil {
return fmt.Errorf("init SHM extension: %w", err)
}
size := c.w * c.h * 4
id, err := unix.SysvShmGet(unix.IPC_PRIVATE, size, unix.IPC_CREAT|0600)
if err != nil {
return fmt.Errorf("shmget: %w", err)
}
addr, err := unix.SysvShmAttach(id, 0, 0)
if err != nil {
if _, ctlErr := unix.SysvShmCtl(id, unix.IPC_RMID, nil); ctlErr != nil {
log.Debugf("shmctl IPC_RMID on attach failure: %v", ctlErr)
}
return fmt.Errorf("shmat: %w", err)
}
if _, err := unix.SysvShmCtl(id, unix.IPC_RMID, nil); err != nil {
log.Debugf("shmctl IPC_RMID: %v", err)
}
seg, err := shm.NewSegId(c.conn)
if err != nil {
if detachErr := unix.SysvShmDetach(addr); detachErr != nil {
log.Debugf("shmdt on new-seg failure: %v", detachErr)
}
return fmt.Errorf("new SHM seg: %w", err)
}
if err := shm.AttachChecked(c.conn, seg, uint32(id), false).Check(); err != nil {
if detachErr := unix.SysvShmDetach(addr); detachErr != nil {
log.Debugf("shmdt on attach-checked failure: %v", detachErr)
}
return fmt.Errorf("SHM attach to X: %w", err)
}
c.shmID = id
c.shmAddr = addr
c.shmSeg = uint32(seg)
c.useSHM = true
return nil
}
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
if err := c.fillSHM(); err != nil {
return nil, err
}
img := c.nextBuffer()
swizzleBGRAtoRGBA(img.Pix, c.shmAddr[:c.w*c.h*4])
return img, nil
}
// captureSHMInto runs a single SHM GetImage and swizzles directly into the
// caller-provided destination, skipping the internal double-buffer.
func (c *X11Capturer) captureSHMInto(dst *image.RGBA) error {
if err := c.fillSHM(); err != nil {
return err
}
swizzleBGRAtoRGBA(dst.Pix, c.shmAddr[:c.w*c.h*4])
return nil
}
func (c *X11Capturer) fillSHM() error {
cookie := shm.GetImage(c.conn, xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF,
xproto.ImageFormatZPixmap, shm.Seg(c.shmSeg), 0)
if _, err := cookie.Reply(); err != nil {
return fmt.Errorf("SHM GetImage: %w", err)
}
return nil
}
func (c *X11Capturer) closeSHM() {
if c.useSHM {
shm.Detach(c.conn, shm.Seg(c.shmSeg))
if err := unix.SysvShmDetach(c.shmAddr); err != nil {
log.Debugf("shmdt on close: %v", err)
}
}
}
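
Review note: both SHM paths above hand the raw segment to swizzleBGRAtoRGBA, which is defined elsewhere and not shown in this part of the diff. For readers following the data flow, here is a sketch of what such a conversion could look like. The name swizzleSketch is hypothetical, and two things are assumptions rather than facts about the real helper: that the server delivers 32-bit pixels in B, G, R, X byte order (as the helper's name suggests), and that the fourth byte is padding to be forced opaque.

```go
package main

import "fmt"

// swizzleSketch is a hypothetical stand-in for swizzleBGRAtoRGBA. It copies
// whole 4-byte pixels from src to dst, swapping the blue and red channels
// and forcing alpha to 0xFF. Trailing bytes that do not form a full pixel,
// or that do not fit in dst, are ignored.
func swizzleSketch(dst, src []byte) {
	n := len(src)
	if len(dst) < n {
		n = len(dst)
	}
	n -= n % 4
	for i := 0; i < n; i += 4 {
		// Read into locals first so the function also works in place.
		b, g, r := src[i], src[i+1], src[i+2]
		dst[i], dst[i+1], dst[i+2], dst[i+3] = r, g, b, 0xFF
	}
}

// demo converts two BGRX pixels and returns the RGBA result.
func demo() []byte {
	src := []byte{
		0x10, 0x20, 0x30, 0x00, // B=0x10 G=0x20 R=0x30, padding
		0xAA, 0xBB, 0xCC, 0x7F, // B=0xAA G=0xBB R=0xCC, padding
	}
	dst := make([]byte, len(src))
	swizzleSketch(dst, src)
	return dst
}

func main() {
	fmt.Printf("% x\n", demo())
}
```

Slicing the segment to c.w*c.h*4 before the call, as captureSHM and captureSHMInto do, keeps a conversion like this from reading past the frame.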


@@ -0,0 +1,24 @@
//go:build freebsd
package server
import (
"fmt"
"image"
)
func (c *X11Capturer) initSHM() error {
return fmt.Errorf("SysV SHM not available on this platform")
}
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
return nil, fmt.Errorf("SHM capture not available on this platform")
}
func (c *X11Capturer) captureSHMInto(_ *image.RGBA) error {
return fmt.Errorf("SHM capture not available on this platform")
}
func (c *X11Capturer) closeSHM() {
// no SHM to close on this platform
}


@@ -0,0 +1,77 @@
//go:build !js && !ios && !android
package server
import (
"reflect"
"testing"
)
func TestCoalesceRects(t *testing.T) {
cases := []struct {
name string
in [][4]int
want [][4]int
}{
{
name: "empty",
in: nil,
want: nil,
},
{
name: "single",
in: [][4]int{{0, 0, 64, 64}},
want: [][4]int{{0, 0, 64, 64}},
},
{
name: "horizontal_run",
in: [][4]int{{0, 0, 64, 64}, {64, 0, 64, 64}, {128, 0, 64, 64}},
want: [][4]int{{0, 0, 192, 64}},
},
{
name: "vertical_run",
in: [][4]int{{0, 0, 64, 64}, {0, 64, 64, 64}, {0, 128, 64, 64}},
want: [][4]int{{0, 0, 64, 192}},
},
{
name: "block_2x2",
in: [][4]int{
{0, 0, 64, 64}, {64, 0, 64, 64},
{0, 64, 64, 64}, {64, 64, 64, 64},
},
want: [][4]int{{0, 0, 128, 128}},
},
{
name: "no_merge_gap",
in: [][4]int{{0, 0, 64, 64}, {192, 0, 64, 64}},
want: [][4]int{{0, 0, 64, 64}, {192, 0, 64, 64}},
},
{
name: "two_disjoint_columns",
in: [][4]int{
{0, 0, 64, 64}, {192, 0, 64, 64},
{0, 64, 64, 64}, {192, 64, 64, 64},
},
want: [][4]int{{0, 0, 64, 128}, {192, 0, 64, 128}},
},
{
name: "misaligned_widths_no_vertical_merge",
in: [][4]int{
{0, 0, 128, 64},
{0, 64, 64, 64},
},
want: [][4]int{{0, 0, 128, 64}, {0, 64, 64, 64}},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := coalesceRects(tc.in)
if len(got) == 0 && len(tc.want) == 0 {
return
}
if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("got %v want %v", got, tc.want)
}
})
}
}
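
Review note: the table above pins down the observable behaviour of coalesceRects (rects are {x, y, w, h}), but the function itself is not part of the files shown here. The sketch below is one hypothetical two-pass merge that is consistent with every case in the table; coalesceSketch, rect, and equal are names invented for illustration, and the sketch assumes the input arrives in row-major order as it does in the test data. It is not the implementation under review.

```go
package main

import "fmt"

// rect is {x, y, w, h}, matching the [4]int layout used by the test table.
type rect = [4]int

// coalesceSketch is a hypothetical merge that satisfies the table in
// TestCoalesceRects. It assumes row-major input order.
func coalesceSketch(in []rect) []rect {
	// Pass 1: extend the previous rect when the next one sits directly to
	// its right with the same y and height.
	var rows []rect
	for _, r := range in {
		if n := len(rows); n > 0 {
			p := &rows[n-1]
			if p[1] == r[1] && p[3] == r[3] && p[0]+p[2] == r[0] {
				p[2] += r[2]
				continue
			}
		}
		rows = append(rows, r)
	}
	// Pass 2: stack a row run under an earlier rect that has the same x
	// and width and whose bottom edge touches the run's top edge.
	var out []rect
	for _, r := range rows {
		merged := false
		for i := range out {
			o := &out[i]
			if o[0] == r[0] && o[2] == r[2] && o[1]+o[3] == r[1] {
				o[3] += r[3]
				merged = true
				break
			}
		}
		if !merged {
			out = append(out, r)
		}
	}
	return out
}

// equal reports whether two rect slices hold the same rects in order.
func equal(a, b []rect) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}

func main() {
	block := []rect{{0, 0, 64, 64}, {64, 0, 64, 64}, {0, 64, 64, 64}, {64, 64, 64, 64}}
	fmt.Println(coalesceSketch(block))
}
```

Requiring equal x and width in the second pass is what keeps the misaligned_widths_no_vertical_merge case from collapsing into a rectangle that would cover undamaged pixels.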


@@ -0,0 +1,10 @@
package server
// interactiveUserError returns nil when a user is logged into the console
// (i.e. an Aqua session is active). At the loginwindow there is nobody to
// display an approval prompt to, so callers can decline without waiting on
// the broker. Any error (including errNoConsoleUser) is treated as decline.
func interactiveUserError() error {
_, err := consoleUserID()
return err
}


@@ -0,0 +1,7 @@
//go:build !darwin && !windows
package server
// interactiveUserError is unused outside service mode (darwin/windows) but
// the symbol must exist so gateApproval compiles on all platforms.
func interactiveUserError() error { return nil }
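
Review note: the two interactiveUserError variants above exist so a caller can decline an approval request before waiting on a prompt nobody can see. gateApproval itself is not shown in this part of the diff, so the sketch below uses a hypothetical gate function, a hypothetical errNoUser sentinel, and injected callbacks purely to illustrate the early-decline control flow; the real signatures may differ.

```go
package main

import (
	"errors"
	"fmt"
)

// errNoUser is a hypothetical sentinel standing in for errNoConsoleUser.
var errNoUser = errors.New("no console user")

// gate is a hypothetical caller. It consults interactive first and only
// invokes prompt when a user is available to answer it. Any error from
// interactive is treated as a decline and returned to the caller to log.
func gate(interactive func() error, prompt func() (bool, error)) (bool, error) {
	if err := interactive(); err != nil {
		return false, fmt.Errorf("declined without prompting: %w", err)
	}
	return prompt()
}

// demo runs gate with no interactive user and reports whether the prompt
// was reached, whether access was approved, and the resulting error.
func demo() (prompted, approved bool, err error) {
	prompt := func() (bool, error) {
		prompted = true
		return true, nil
	}
	approved, err = gate(func() error { return errNoUser }, prompt)
	return prompted, approved, err
}

func main() {
	prompted, approved, err := demo()
	fmt.Println(prompted, approved, err)
}
```

On platforms where the stub always returns nil, a caller shaped like this simply falls through to the prompt, which is why the symbol has to exist everywhere.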

Some files were not shown because too many files have changed in this diff.